From 60db76c93f2b9d7cf6cfa31fbfbd202e7ad3eddb Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Thu, 18 Jul 2019 18:17:02 -0500 Subject: [PATCH 1/8] Introduce a SQL statement framework. The main goal is to make it easier to automatically prepare statements and get rid of the nasty ad-hoc way we've been doing it. Secondary goals are better readability, better refactoring (since we'll get compile-time errors if the schema changes) and easier cross-database and cross-history-state queries. --- src/relstorage/_compat.py | 2 + src/relstorage/_util.py | 39 + src/relstorage/adapters/_sql.py | 901 ++++++++++++++++++ src/relstorage/adapters/locker.py | 59 +- src/relstorage/adapters/mover.py | 105 +- src/relstorage/adapters/mysql/adapter.py | 36 +- .../adapters/mysql/drivers/__init__.py | 20 + src/relstorage/adapters/mysql/locker.py | 15 - src/relstorage/adapters/mysql/mover.py | 24 - src/relstorage/adapters/mysql/txncontrol.py | 5 +- src/relstorage/adapters/oracle/adapter.py | 21 +- src/relstorage/adapters/oracle/mover.py | 6 +- .../adapters/oracle/scriptrunner.py | 33 +- src/relstorage/adapters/oracle/txncontrol.py | 4 +- src/relstorage/adapters/poller.py | 77 +- src/relstorage/adapters/postgresql/adapter.py | 82 +- .../adapters/postgresql/drivers/pg8000.py | 14 + src/relstorage/adapters/postgresql/mover.py | 107 +-- .../adapters/postgresql/tests/test_mover.py | 28 +- .../postgresql/tests/test_txncontrol.py | 13 - .../adapters/postgresql/txncontrol.py | 36 +- src/relstorage/adapters/schema.py | 60 ++ src/relstorage/adapters/tests/test__sql.py | 189 ++++ .../adapters/tests/test_txncontrol.py | 39 +- src/relstorage/adapters/txncontrol.py | 35 +- src/relstorage/storage/__init__.py | 4 +- src/relstorage/tests/__init__.py | 18 +- 27 files changed, 1466 insertions(+), 506 deletions(-) create mode 100644 src/relstorage/adapters/_sql.py create mode 100644 src/relstorage/adapters/tests/test__sql.py diff --git a/src/relstorage/_compat.py b/src/relstorage/_compat.py index 3ad91d99..338f2511 100644 --- a/src/relstorage/_compat.py +++ b/src/relstorage/_compat.py @@ -65,9 +65,11 @@ def iteroiditems(d): if PY3: string_types = (str,) unicode = str + from io import StringIO as NStringIO else: string_types = (basestring,) unicode = unicode + from io import BytesIO as NStringIO try: from abc import ABC diff --git a/src/relstorage/_util.py b/src/relstorage/_util.py index 1d224a21..2a737f63 100644 --- a/src/relstorage/_util.py +++ b/src/relstorage/_util.py @@ -143,6 +143,45 @@ def __get__(self, inst, class_): inst.__dict__[name] = value return value +class CachedIn(object): + """Cached method with given cache attribute.""" + + def __init__(self, attribute_name, factory=dict): + self.attribute_name = attribute_name + self.factory = factory + + def __call__(self, func): + + @functools.wraps(func) + def decorated(instance): + cache = self.cache(instance) + key = () # We don't support arguments right now, so only one key. 
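+ # Compute the value on the first call; afterwards return the cached copy.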
+ try: + v = cache[key] + except KeyError: + v = cache[key] = func(instance) + return v + + decorated.invalidate = self.invalidate + return decorated + + def invalidate(self, instance): + cache = self.cache(instance) + key = () + try: + del cache[key] + except KeyError: + pass + + def cache(self, instance): + try: + cache = getattr(instance, self.attribute_name) + except AttributeError: + cache = self.factory() + setattr(instance, self.attribute_name, cache) + return cache + + def to_utf8(data): if data is None or isinstance(data, bytes): return data diff --git a/src/relstorage/adapters/_sql.py b/src/relstorage/adapters/_sql.py new file mode 100644 index 00000000..e74d5686 --- /dev/null +++ b/src/relstorage/adapters/_sql.py @@ -0,0 +1,901 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +A small abstraction layer for SQL queries. + +Features: + + - Simple, readable syntax for writing queries. + + - Automatic switching between history-free and history-preserving + schemas. + + - Support for automatically preparing statements (depending on the + driver; some allow parameters, some do not, for example) + + - Always use bind parameters + + - Take care of minor database syntax issues. + +This is inspired by the SQLAlchemy Core, but we don't use it because +we don't want to take a dependency that can conflict with applications +using RelStorage. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from copy import copy as stdlib_copy +from operator import attrgetter +from weakref import WeakKeyDictionary + +from relstorage._compat import NStringIO +from relstorage._util import CachedIn + +def copy(obj): + new = stdlib_copy(obj) + volatile = [k for k in vars(new) if k.startswith('_v')] + for k in volatile: + delattr(new, k) + return new + +class Column(object): + """ + Defines a column in a table. + """ + + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + + def __eq__(self, other): + return _EqualExpression(self, other) + + def __gt__(self, other): + return _GreaterExpression(self, other) + + def __ge__(self, other): + return _GreaterEqualExpression(self, other) + + def __ne__(self, other): + return _NotEqualExpression(self, other) + + def __le__(self, other): + return _LessEqualExpression(self, other) + + def __compile_visit__(self, compiler): + compiler.emit_identifier(self.name) + +class _TextNode(object): + + def __init__(self, raw): + self.raw = raw + + def __compile_visit__(self, compiler): + compiler.emit(self.raw) + +class _Columns(object): + """ + Grab bag of columns. + """ + + def __init__(self, columns): + cs = [] + for c in columns: + try: + setattr(self, c.name, c) + except AttributeError: + # Must be a string. 
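+ # Wrap the raw string so the compiler can visit it like a column.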
+ c = _TextNode(c) + cs.append(c) + self._columns = tuple(cs) + + def __bool__(self): + return bool(self._columns) + + __nonzero__ = __bool__ + + def __getattr__(self, name): + # Here only so that pylint knows this class has a set of + # dynamic attributes. + raise AttributeError("Column list %s does not include %s" % ( + self._col_list(), + name + )) + + def __getitem__(self, ix): + return self._columns[ix] + + def _col_list(self): + return ','.join(str(c) for c in self._columns) + + def __compile_visit__(self, compiler): + compiler.visit_csv(self._columns) + + +_ColumnList = _Columns + +class Table(object): + """ + A table relation. + """ + + def __init__(self, name, *columns): + self.name = name + self.columns = columns + self.c = _Columns(columns) + + def __str__(self): + return self.name + + def select(self, *args, **kwargs): + return Select(self, *args, **kwargs) + + def __compile_visit__(self, compiler): + compiler.emit_identifier(self.name) + + def bindparam(self, key): + return bindparam(key) + + def orderedbindparam(self): + return orderedbindparam() + + def natural_join(self, other_table): + return NaturalJoinedTable(self, other_table) + + def insert(self, *args, **kwargs): + return Insert(self, *args, **kwargs) + +class TemporaryTable(Table): + """ + A temporary table. + """ + +class _CompositeTableMixin(object): + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + common_columns = [] + for col in lhs.columns: + if hasattr(rhs.c, col.name): + common_columns.append(col) + self.columns = common_columns + self.c = _Columns(common_columns) + + def select(self, *args, **kwargs): + return Select(self, *args, **kwargs) + + def bindparam(self, key): + return bindparam(key) + + def orderedbindparam(self): + return orderedbindparam() + +class _DefaultDialect(object): + + def __init__(self, base): + self._base = base + + def __getattr__(self, name): + return getattr(self._base, name) + + def __bool__(self): + return False + + __nonzero__ = __bool__ + + _driver_locations = ( + attrgetter('_base.driver'), + attrgetter('_base.poller.driver'), + attrgetter('_base.connmanager.driver'), + attrgetter('_base.adapter.driver') + ) + + def compiler_class(self): + # We want to find *something* with a driver. + # Preferably the object we're attached to, but if not that, + # we'll look at some common attributes for adapter objects + # for it. + for getter in self._driver_locations: + try: + return getter(self).sql_compiler_class + except AttributeError: + pass + return _Compiler + + def compiler(self, root): + return self.compiler_class()(root) + + def __eq__(self, other): + if isinstance(other, _DefaultDialect): + return self._base == other._base + return NotImplemented + + def __repr__(self): + return '<%s at %x base=%r>' % ( + type(self), + id(self), + self._base + ) + +class _Bindable(object): + + context = _DefaultDialect(None) + + def _find_dialect(self, context): + # Look up the database type, find the right dialect. + # Ordinarily we want to use the database driver. 
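+ # Reuse an already-resolved dialect; otherwise wrap the context in one.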
+ if isinstance(context, _DefaultDialect): + return context + return _DefaultDialect(context) + + def bind(self, context): + new = copy(self) + context = self._find_dialect(context) + new.context = context + bound_replacements = { + k: v.bind(context) + for k, v + in vars(new).items() + if isinstance(v, _Bindable) + } + for k, v in bound_replacements.items(): + setattr(new, k, v) + return new + +class NaturalJoinedTable(_Bindable, + _CompositeTableMixin): + + def __init__(self, lhs, rhs): + super(NaturalJoinedTable, self).__init__(lhs, rhs) + assert self.c + self._join_columns = self.c + + # TODO: Check for data type mismatches, etc? + self.columns = list(self.lhs.columns) + for col in self.rhs.columns: + if not hasattr(self.lhs.c, col.name): + self.columns.append(col) + self.columns = tuple(self.columns) + self.c = _Columns(self.columns) + + def __compile_visit__(self, compiler): + compiler.visit(self.lhs) + compiler.emit(' JOIN ') + compiler.visit(self.rhs) + # careful with USING clause in a join: Oracle doesn't allow such + # columns to have a prefix. + compiler.emit_keyword(' USING') + compiler.visit_grouped(self._join_columns) + + +class HistoryVariantTable(_Bindable, + _CompositeTableMixin): + """ + A table that can be one of two tables, depending on whether + the instance is keeping history or not. + """ + + @property + def history_preserving(self): + return self.lhs + + @property + def history_free(self): + return self.rhs + + def __compile_visit__(self, compiler): + keep_history = getattr(self.context, 'keep_history', True) + node = self.history_preserving if keep_history else self.history_free + return compiler.visit(node) + + +class _CompiledQuery(object): + """ + Represents a completed query. + """ + + stmt = None + params = None + _raw_stmt = None + _prepare_stmt = None + _prepare_converter = None + + def __init__(self, root): + self.root = root + # We do not keep a reference to the context; + # it's likely to be an instance object that's + # going to have us stored in its dictionary. + context = root.context + + compiler = context.compiler(root) + self.stmt, self.params = compiler.compile() + self._raw_stmt = self.stmt # for debugging + if compiler.can_prepare(): + self._prepare_stmt, self.stmt, self._prepare_converter = compiler.prepare() + + def __repr__(self): + if self._prepare_stmt: + return "%s (%s)" % ( + self.stmt, + self._prepare_stmt + ) + return self.stmt + + def __str__(self): + return self.stmt + + _cursor_cache = WeakKeyDictionary() + + def _stmt_cache_for_cursor(self, cursor): + """Returns a dictionary.""" + # If we can't store it directly on the cursor, as happens for + # types implemented in C, we use a weakkey dictionary. + try: + cursor_prep_stmts = cursor._rs_prepared_statements + except AttributeError: + try: + cursor_prep_stmts = cursor._rs_prepared_statements = {} + except AttributeError: + cursor_prep_stmts = self._cursor_cache.get(cursor) + if cursor_prep_stmts is None: + cursor_prep_stmts = self._cursor_cache[cursor] = {} + return cursor_prep_stmts + + def execute(self, cursor, params=None): + # (Any, dict) -> None + # TODO: Include literals from self.params. + # TODO: Syntax transformation if they don't support names. + # TODO: Validate given params match expected ones, nothing missing? + stmt = self.stmt + if self._prepare_stmt: + # Prepare on demand. + + # In all databases, prepared statements + # persist past COMMIT/ROLLBACK (though in PostgreSQL + # preparing them takes locks that only go away at + # COMMIT/ROLLBACK). 
But they don't persist past a session + # restart (new connection) (obviously). + # + # Thus we keep a cache of statements we have prepared for + # this particular connection/cursor. + # + cursor_prep_stmts = self._stmt_cache_for_cursor(cursor) + try: + stmt = cursor_prep_stmts[self._prepare_stmt] + except KeyError: + stmt = cursor_prep_stmts[self._prepare_stmt] = self.stmt + __traceback_info__ = self._prepare_stmt, self, self.root.context.compiler(self.root) + cursor.execute(self._prepare_stmt) + params = self._prepare_converter(params) + + __traceback_info__ = stmt, params + if params: + cursor.execute(stmt, params) + else: + cursor.execute(stmt) + +class _Compiler(object): + + def __init__(self, root): + self.buf = NStringIO() + self.placeholders = {} + self.root = root + + + def __repr__(self): + return "" % ( + self.buf.getvalue(), + self.placeholders + ) + + def compile(self): + self.visit(self.root) + return self.finalize() + + def can_prepare(self): + # Obviously this needs to be more general. + # Some drivers, for example, can't deal with parameters + # in a prepared statement. + return self.root.prepare and isinstance(self.root, _Query) + + _prepared_stmt_counter = 0 + + @classmethod + def _next_prepared_stmt_name(cls): + cls._prepared_stmt_counter += 1 + return 'rs_prep_stmt_%d' % (cls._prepared_stmt_counter,) + + def _prepared_param(self, number): + return '$' + str(number) + + _PREPARED_CONJUNCTION = 'AS' + + def _quote_query_for_prepare(self, query): + return query + + def prepare(self): + # This is correct for PostgreSQL. This needs moved to a dialect specific + # spot. + + # TODO: Deduce the datatypes based on the types of the columns + # we're sending as params. + datatypes = {} + query = self.buf.getvalue() + name = self._next_prepared_stmt_name() + + if datatypes: + assert isinstance(datatypes, (list, tuple)) + datatypes = ', '.join(datatypes) + datatypes = ' (%s)' % (datatypes,) + else: + datatypes = '' + + q = query.strip() + + # PREPARE needs the query string to use $1, $2, $3, etc, + # as placeholders. + # In MySQL, it's a plain question mark. + placeholder_to_number = {} + counter = 0 + for placeholder_name in self.placeholders.values(): + counter += 1 + placeholder = self._placeholder(placeholder_name) + placeholder_to_number[placeholder_name] = counter + param = self._prepared_param(counter) + q = q.replace(placeholder, param, 1) + + q = self._quote_query_for_prepare(q) + + stmt = 'PREPARE {name}{datatypes} {conjunction} {query}'.format( + name=name, datatypes=datatypes, + query=q, + conjunction=self._PREPARED_CONJUNCTION, + ) + + if placeholder_to_number: + execute = 'EXECUTE {name}({params})'.format( + name=name, + params=','.join(['%s'] * len(self.placeholders)), + ) + else: + # Neither MySQL nor PostgreSQL like a set of empty parens: () + execute = 'EXECUTE {name}'.format(name=name) + + if '%s' in placeholder_to_number: + # There was an ordered param. If there was one, + # they must all be ordered, so there's no need to convert anything. + assert len(placeholder_to_number) == 1 + def convert(p): + return p + else: + def convert(d): + # TODO: This may not actually be needed, since we issue a regular + # cursor.execute(), it may be able to handle named? 
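+ # Convert the named-parameter dict into a positional list matching $1..$n.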
+ params = [None] * len(placeholder_to_number) + for placeholder_name, ix in placeholder_to_number.items(): + params[ix - 1] = d[placeholder_name] + return params + + return stmt, execute, convert + + def finalize(self): + return self.buf.getvalue().strip(), {v: k for k, v in self.placeholders.items()} + + def visit(self, node): + node.__compile_visit__(self) + + visit_clause = visit + + def emit(self, *contents): + for content in contents: + self.buf.write(content) + + def emit_w_padding_space(self, value): + ended_in_space = self.buf.getvalue().endswith(' ') + value = value.strip() + if not ended_in_space: + self.buf.write(' ') + self.emit(value, ' ') + + emit_keyword = emit_w_padding_space + + def emit_identifier(self, identifier): + last_char = self.buf.getvalue()[-1] + if last_char not in ('(', ' '): + self.emit(' ', identifier) + else: + self.emit(identifier) + + def visit_column_list(self, column_list): + clist = column_list.c if hasattr(column_list, 'c') else column_list + self.visit(clist) + + def visit_csv(self, nodes): + self.visit(nodes[0]) + for node in nodes[1:]: + self.emit(', ') + self.visit(node) + + def visit_from(self, from_): + self.emit_keyword('FROM') + self.visit(from_) + + def visit_grouped(self, clause): + self.emit('(') + self.visit(clause) + self.emit(')') + + def visit_op(self, op): + self.emit(' ' + op + ' ') + + def _next_placeholder_name(self): + return 'param_%d' % (len(self.placeholders),) + + def _placeholder(self, key): + # Write things in `pyformat` style by default, assuming a + # dictionary of params; this is supported by most drivers. + if key == '%s': + return key + return '%%(%s)s' % (key,) + + def _placeholder_for_literal_param_value(self, value): + placeholder = self.placeholders.get(value) + if not placeholder: + placeholder_name = self._next_placeholder_name() + placeholder = self._placeholder(placeholder_name) + self.placeholders[value] = placeholder_name + return placeholder + + def visit_literal_param(self, value): + placeholder = self._placeholder_for_literal_param_value(value) + self.emit(placeholder) + + def visit_bind_param(self, bind_param): + self.placeholders[bind_param] = bind_param.key + self.emit(self._placeholder(bind_param.key)) + + def visit_ordered_bind_param(self, bind_param): + self.placeholders[bind_param] = '%s' + self.emit('%s') + +class _Expression(_Bindable): + """ + A SQL expression. + """ +class _BindParam(_Expression): + + def __init__(self, key): + self.key = key + + def __compile_visit__(self, compiler): + compiler.visit_bind_param(self) + + +def bindparam(key): + return _BindParam(key) + +class _LiteralExpression(_Expression): + + def __init__(self, value): + self.value = value + + def __compile_visit__(self, compiler): + compiler.visit_literal_param(self.value) + +class _OrderedBindParam(_Expression): + + name = '%s' + + def __compile_visit__(self, compiler): + compiler.visit_ordered_bind_param(self) + +def orderedbindparam(): + return _OrderedBindParam() + +class _BinaryExpression(_Expression): + """ + Expresses a comparison. 
+ """ + + def __init__(self, op, lhs, rhs): + self.op = op + self.lhs = lhs # type: Column + # rhs is either a literal or a column + if not hasattr(rhs, '__compile_visit__'): + rhs = _LiteralExpression(rhs) + self.rhs = rhs + + def __str__(self): + return '%s %s %s' % ( + self.lhs, + self.op, + self.rhs + ) + + def __compile_visit__(self, compiler): + compiler.visit(self.lhs) + compiler.visit_op(self.op) + compiler.visit(self.rhs) + + +class _EmptyExpression(_Expression): + """ + No comparison at all. + """ + + def __bool__(self): + return False + + __nonzero__ = __bool__ + + def __str__(self): + return '' + + def and_(self, expression): + return expression + + def __compile_visit__(self, compiler): + "Does nothing" + +class _EqualExpression(_BinaryExpression): + + def __init__(self, lhs, rhs): + _BinaryExpression.__init__(self, '=', lhs, rhs) + +class _NotEqualExpression(_BinaryExpression): + + def __init__(self, lhs, rhs): + _BinaryExpression.__init__(self, '<>', lhs, rhs) + + +class _GreaterExpression(_BinaryExpression): + + def __init__(self, lhs, rhs): + _BinaryExpression.__init__(self, '>', lhs, rhs) + +class _GreaterEqualExpression(_BinaryExpression): + + def __init__(self, lhs, rhs): + _BinaryExpression.__init__(self, '>=', lhs, rhs) + +class _LessEqualExpression(_BinaryExpression): + + def __init__(self, lhs, rhs): + _BinaryExpression.__init__(self, '<=', lhs, rhs) + + +class _Clause(_Bindable): + """ + A portion of a SQL statement. + """ + +class _And(_Expression): + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def __compile_visit__(self, compiler): + compiler.visit_grouped(_BinaryExpression('AND', self.lhs, self.rhs)) + +class _WhereClause(_Clause): + + def __init__(self, expression): + self.expression = expression + + def and_(self, expression): + expression = _And(self.expression, expression) + new = copy(self) + new.expression = expression + return new + + def __compile_visit__(self, compiler): + compiler.emit_keyword(' WHERE') + compiler.visit_grouped(self.expression) + +class _OrderBy(_Clause): + + def __init__(self, expression, dir): + self.expression = expression + self.dir = dir + + def __compile_visit__(self, compiler): + compiler.emit(' ORDER BY ') + compiler.visit(self.expression) + if self.dir: + compiler.emit(' ' + self.dir) + + +def _where(expression): + if expression: + return _WhereClause(expression) + +class _Query(_Bindable): + __name__ = None + prepare = False + + def __str__(self): + return str(self.compiled()) + + def __get__(self, inst, klass): + if inst is None: + return self + # We need to set this into the instance's dictionary. + # Otherwise we'll be rebinding and recompiling each time we're + # accessed which is not good. On Python 3.6+, there's the + # `__set_name__(klass, name)` called on the descriptor which + # does the job perfectly. In earlier versions, we're on our + # own. + # TODO: In test cases, we spend a lot of time binding and compiling. + # Can we find another layer of caching somewhere? + result = self.bind(inst).compiled() + if not self.__name__: + # Go through the class hierarchy, find out what we're called. 
+ for base in klass.mro(): + for k, v in vars(base).items(): + if v is self: + self.__name__ = k + break + assert self.__name__ + + vars(inst)[self.__name__] = result + + return result + + def __set_name__(self, owner, name): + self.__name__ = name + + @CachedIn('_v_compiled') + def compiled(self): + return _CompiledQuery(self) + + def prepared(self): + """ + Note that it's good to prepare this query, if + supported by the driver. + """ + s = copy(self) + s.prepare = True + return s + +class Select(_Query): + """ + A Select query. + + When instances of this class are stored in a class dictionary, + they function as non-data descriptors: The first time they are + accessed, they *bind* themselves to the instance and select the + appropriate SQL syntax and compile themselves into a string. + """ + + _where = _EmptyExpression() + _order_by = _EmptyExpression() + _limit = None + _for_update = None + _nowait = None + + def __init__(self, table, *columns): + self.table = table + if columns: + self.column_list = _ColumnList(columns) + else: + self.column_list = table + + def where(self, expression): + s = copy(self) + s._where = _where(expression) + return s + + def and_(self, expression): + s = copy(self) + s._where = self._where.and_(expression) + return s + + def order_by(self, expression, dir=None): + s = copy(self) + s._order_by = _OrderBy(expression, dir) + return s + + def limit(self, literal): + s = copy(self) + s._limit = literal + return s + + def for_update(self): + s = copy(self) + s._for_update = 'FOR UPDATE' + return s + + def nowait(self): + s = copy(self) + s._nowait = 'NOWAIT' + return s + + def __compile_visit__(self, compiler): + compiler.emit_keyword('SELECT') + compiler.visit_column_list(self.column_list) + compiler.visit_from(self.table) + compiler.visit_clause(self._where) + compiler.visit_clause(self._order_by) + if self._limit: + compiler.emit_keyword('LIMIT') + compiler.emit(str(self._limit)) + if self._for_update: + compiler.emit_keyword(self._for_update) + if self._nowait: + compiler.emit_keyword(self._nowait) + +class _Functions(object): + + def max(self, column): + return _Function('max', column) + +class _Function(_Expression): + + def __init__(self, name, expression): + self.name = name + self.expression = expression + + def __compile_visit__(self, compiler): + compiler.emit_identifier(self.name) + compiler.visit_grouped(self.expression) + +func = _Functions() + +class Insert(_Query): + + column_list = None + select = None + epilogue = '' + values = None + + def __init__(self, table, *columns): + self.table = table + if columns: + self.column_list = _Columns(columns) + # TODO: Probably want a different type, like a ValuesList + self.values = _Columns([orderedbindparam() for _ in columns]) + + def from_select(self, names, select): + i = copy(self) + i.column_list = _Columns(names) + i.select = select + return i + + def __compile_visit__(self, compiler): + compiler.emit_keyword('INSERT INTO') + compiler.visit(self.table) + compiler.visit_grouped(self.column_list) + if self.select: + compiler.visit(self.select) + else: + compiler.emit_keyword('VALUES') + compiler.visit_grouped(self.values) + compiler.emit(self.epilogue) + + def __add__(self, extension): + # This appends a textual epilogue. It's a temporary + # measure until we have more nodes and can model what + # we're trying to accomplish. 
+ assert isinstance(extension, str) + i = copy(self) + i.epilogue += extension + return i diff --git a/src/relstorage/adapters/locker.py b/src/relstorage/adapters/locker.py index 336d15fd..acc34ed3 100644 --- a/src/relstorage/adapters/locker.py +++ b/src/relstorage/adapters/locker.py @@ -30,6 +30,7 @@ from ._util import query_property as _query_property from ._util import DatabaseHelpersMixin +from .schema import Schema from .interfaces import UnableToAcquireCommitLockError logger = __import__('logging').getLogger(__name__) @@ -198,39 +199,31 @@ def lock_current_objects(self, cursor, current_oids): consume(rows) - _commit_lock_queries = ( - # MySQL allows aggregates in the top level to use FOR UPDATE, - # but PostgreSQL does not, so we have to use the second form. - # - # 'SELECT MAX(tid) FROM transaction FOR UPDATE', - # 'SELECT tid FROM transaction WHERE tid = (SELECT MAX(tid) FROM transaction) FOR UPDATE', - - # Note that using transaction in history-preserving databases - # can still lead to deadlock in older versions of MySQL (test - # checkPackWhileWriting), and the above lock statement can - # lead to duplicate transaction ids being inserted on older - # versions (5.7.12, PyMySQL: - # https://ci.appveyor.com/project/jamadden/relstorage/builds/25748619/job/cyio3w54uqi026lr#L923). - # So both HF and HP use an artificial lock row. - # - # TODO: Figure out exactly the best way to lock just the rows - # in the transaction table we care about that works - # everywhere, or a better way to choose the next TID. - # gap/intention locks might be a clue. - - 'SELECT tid FROM commit_row_lock FOR UPDATE', - 'SELECT tid FROM commit_row_lock FOR UPDATE' - ) - - _commit_lock_query = _query_property('_commit_lock') - - _commit_lock_nowait_queries = ( - _commit_lock_queries[0] + ' NOWAIT', - _commit_lock_queries[1] + ' NOWAIT', + # MySQL allows aggregates in the top level to use FOR UPDATE, + # but PostgreSQL does not, so we have to use the second form. + # + # 'SELECT MAX(tid) FROM transaction FOR UPDATE', + # 'SELECT tid FROM transaction WHERE tid = (SELECT MAX(tid) FROM transaction) FOR UPDATE', + + # Note that using transaction in history-preserving databases + # can still lead to deadlock in older versions of MySQL (test + # checkPackWhileWriting), and the above lock statement can + # lead to duplicate transaction ids being inserted on older + # versions (5.7.12, PyMySQL: + # https://ci.appveyor.com/project/jamadden/relstorage/builds/25748619/job/cyio3w54uqi026lr#L923). + # So both HF and HP use an artificial lock row. + # + # TODO: Figure out exactly the best way to lock just the rows + # in the transaction table we care about that works + # everywhere, or a better way to choose the next TID. + # gap/intention locks might be a clue. 
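+ # The lock statement is now built with the SQL framework, so it is
+ # compiled per-dialect and prepared lazily on first use.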
+ _commit_lock_query = Schema.commit_row_lock.select( + Schema.commit_row_lock.c.tid + ).for_update( + ).prepared( ) - _commit_lock_nowait_query = _query_property('_commit_lock_nowait') - + _commit_lock_nowait_query = _commit_lock_query.nowait() @metricmethod def hold_commit_lock(self, cursor, ensure_current=False, nowait=False): @@ -241,9 +234,9 @@ def hold_commit_lock(self, cursor, ensure_current=False, nowait=False): lock_stmt = self._commit_lock_nowait_query else: self._set_row_lock_nowait(cursor) - __traceback_info__ = lock_stmt + try: - cursor.execute(lock_stmt) + lock_stmt.execute(cursor) rows = cursor.fetchall() if not rows or not rows[0]: raise UnableToAcquireCommitLockError("No row returned from commit_row_lock") diff --git a/src/relstorage/adapters/mover.py b/src/relstorage/adapters/mover.py index 85b603a0..9d25d125 100644 --- a/src/relstorage/adapters/mover.py +++ b/src/relstorage/adapters/mover.py @@ -28,6 +28,10 @@ from .._compat import ABC from .batch import RowBatcher from .interfaces import IObjectMover +from .schema import Schema + +objects = Schema.all_current_object_state +object_state = Schema.object_state metricmethod_sampled = Metric(method=True, rate=0.1) @@ -54,20 +58,11 @@ def _compute_md5sum(self, data): return None return md5(data).hexdigest() - _load_current_queries = ( - """ - SELECT state, tid - FROM current_object - JOIN object_state USING(zoid, tid) - WHERE zoid = %s - """, - """ - SELECT state, tid - FROM object_state - WHERE zoid = %s - """) - - _load_current_query = _query_property('_load_current') + _load_current_query = objects.select( + objects.c.state, objects.c.tid + ).where( + objects.c.zoid == objects.orderedbindparam() + ).prepared() @metricmethod_sampled def load_current(self, cursor, oid): @@ -76,8 +71,7 @@ def load_current(self, cursor, oid): oid is an integer. Returns (None, None) if object does not exist. """ stmt = self._load_current_query - - cursor.execute(stmt, (oid,)) + stmt.execute(cursor, (oid,)) # Note that we cannot rely on cursor.rowcount being # a valid indicator. The DB-API doesn't require it, and # some implementations, like MySQL Connector/Python are @@ -110,12 +104,13 @@ def load_currents(self, cursor, oids): oid, state, tid = row yield oid, binary_column_as_state_type(state), tid - _load_revision_query = """ - SELECT state - FROM object_state - WHERE zoid = %s - AND tid = %s - """ + _load_revision_query = object_state.select( + object_state.c.state + ).where( + object_state.c.zoid == object_state.orderedbindparam() + ).and_( + object_state.c.tid == object_state.orderedbindparam() + ).prepared() @metricmethod_sampled def load_revision(self, cursor, oid, tid): @@ -124,25 +119,24 @@ def load_revision(self, cursor, oid, tid): Returns None if no such state exists. 
""" stmt = self._load_revision_query - cursor.execute(stmt, (oid, tid)) + stmt.execute(cursor, (oid, tid)) row = cursor.fetchone() if row: (state,) = row return self.driver.binary_column_as_state_type(state) return None - _exists_queries = ( - "SELECT 1 FROM current_object WHERE zoid = %s", - "SELECT 1 FROM object_state WHERE zoid = %s" + _exists_query = Schema.all_current_object.select( + Schema.all_current_object.c.zoid + ).where( + Schema.all_current_object.c.zoid == Schema.all_current_object.orderedbindparam() ) - _exists_query = _query_property('_exists') - @metricmethod_sampled def exists(self, cursor, oid): """Returns a true value if the given object exists.""" stmt = self._exists_query - cursor.execute(stmt, (oid,)) + stmt.execute(cursor, (oid,)) row = cursor.fetchone() return row @@ -326,24 +320,15 @@ def restore(self, cursor, batcher, oid, tid, data): # careful with USING clause in a join: Oracle doesn't allow such # columns to have a prefix. - _detect_conflict_queries = ( - """ - SELECT zoid, current_object.tid, temp_store.prev_tid - FROM temp_store - JOIN current_object USING (zoid) - WHERE temp_store.prev_tid != current_object.tid - ORDER BY zoid - """, - """ - SELECT zoid, object_state.tid, temp_store.prev_tid - FROM temp_store - JOIN object_state USING (zoid) - WHERE temp_store.prev_tid != object_state.tid - ORDER BY zoid - """ - ) - - _detect_conflict_query = _query_property('_detect_conflict') + _detect_conflict_query = Schema.temp_store.natural_join( + Schema.all_current_object + ).select( + Schema.temp_store.c.zoid, Schema.all_current_object.c.tid, Schema.temp_store.c.prev_tid + ).where( + Schema.temp_store.c.prev_tid != Schema.all_current_object.c.tid + ).order_by( + Schema.temp_store.c.zoid + ).prepared() @metricmethod_sampled def detect_conflict(self, cursor): @@ -351,7 +336,7 @@ def detect_conflict(self, cursor): # passed to tryToResolveConflict, saving extra queries. # OTOH, using extra memory. 
stmt = self._detect_conflict_query - cursor.execute(stmt) + stmt.execute(cursor) rows = cursor.fetchall() return rows @@ -389,12 +374,21 @@ def replace_temp(self, cursor, oid, prev_tid, data): WHERE zoid IN (SELECT zoid FROM temp_store) """ - _move_from_temp_hf_insert_query = """ - INSERT INTO object_state (zoid, tid, state_size, state) - SELECT zoid, %s, COALESCE(LENGTH(state), 0), state - FROM temp_store - ORDER BY zoid - """ + _move_from_temp_hf_insert_query = Schema.object_state.insert( + ).from_select( + (Schema.object_state.c.zoid, + Schema.object_state.c.tid, + Schema.object_state.c.state_size, + Schema.object_state.c.state), + Schema.temp_store.select( + Schema.temp_store.c.zoid, + Schema.temp_store.orderedbindparam(), + 'COALESCE(LENGTH(state), 0)', + Schema.temp_store.c.state + ).order_by( + Schema.temp_store.c.zoid + ) + ).prepared() _move_from_temp_copy_blob_query = """ INSERT INTO blob_chunk (zoid, tid, chunk_num, chunk) @@ -435,7 +429,8 @@ def _move_from_temp_object_state(self, cursor, tid): cursor.execute(stmt) stmt = self._move_from_temp_hf_insert_query - cursor.execute(stmt, (tid,)) + __traceback_info__ = stmt + stmt.execute(cursor, (tid,)) @metricmethod_sampled diff --git a/src/relstorage/adapters/mysql/adapter.py b/src/relstorage/adapters/mysql/adapter.py index a4fc247b..a3e19e89 100644 --- a/src/relstorage/adapters/mysql/adapter.py +++ b/src/relstorage/adapters/mysql/adapter.py @@ -58,7 +58,7 @@ from relstorage.options import Options from .._abstract_drivers import _select_driver -from .._util import query_property + from ..dbiter import HistoryFreeDatabaseIterator from ..dbiter import HistoryPreservingDatabaseIterator from ..interfaces import IRelStorageAdapter @@ -69,7 +69,7 @@ from .connmanager import MySQLdbConnectionManager from .locker import MySQLLocker from .mover import MySQLObjectMover -from .mover import to_prepared_queries + from .oidallocator import MySQLOIDAllocator from .packundo import MySQLHistoryFreePackUndo from .packundo import MySQLHistoryPreservingPackUndo @@ -122,21 +122,20 @@ def __init__(self, options=None, **params): self.connmanager.add_on_store_opened(self.mover.on_store_opened) self.connmanager.add_on_load_opened(self.mover.on_load_opened) self.oidallocator = MySQLOIDAllocator(driver) - self.txncontrol = MySQLTransactionControl( - connmanager=self.connmanager, - keep_history=self.keep_history, - Binary=driver.Binary, - ) self.poller = Poller( - poll_query='EXECUTE get_latest_tid', + self.driver, keep_history=self.keep_history, runner=self.runner, revert_when_stale=options.revert_when_stale, ) - self.connmanager.add_on_load_opened(self._prepare_get_latest_tid) - self.connmanager.add_on_store_opened(self._prepare_get_latest_tid) + self.txncontrol = MySQLTransactionControl( + connmanager=self.connmanager, + poller=self.poller, + keep_history=self.keep_history, + Binary=driver.Binary, + ) if self.keep_history: self.packundo = MySQLHistoryPreservingPackUndo( @@ -168,23 +167,6 @@ def __init__(self, options=None, **params): keep_history=self.keep_history ) - _get_latest_tid_queries = ( - "SELECT MAX(tid) FROM transaction", - "SELECT MAX(tid) FROM object_state", - ) - - _prepare_get_latest_tid_queries = to_prepared_queries( - 'get_latest_tid', - _get_latest_tid_queries) - - _prepare_get_latest_tid_query = query_property('_prepare_get_latest_tid') - - def _prepare_get_latest_tid(self, cursor, restart=False): - if restart: - return - stmt = self._prepare_get_latest_tid_query - cursor.execute(stmt) - def new_instance(self): return 
type(self)(options=self.options, **self._params) diff --git a/src/relstorage/adapters/mysql/drivers/__init__.py b/src/relstorage/adapters/mysql/drivers/__init__.py index 7d414245..1110f7e8 100644 --- a/src/relstorage/adapters/mysql/drivers/__init__.py +++ b/src/relstorage/adapters/mysql/drivers/__init__.py @@ -20,9 +20,26 @@ from ..._abstract_drivers import AbstractModuleDriver from ..._abstract_drivers import implement_db_driver_options +from ..._sql import _Compiler database_type = 'mysql' +class MySQLCompiler(_Compiler): + + def can_prepare(self): + # If there are params, we can't prepare unless we're using + # the binary protocol; otherwise we have to SET user variables + # with extra round trips, which is worse. + return not self.placeholders and super(MySQLCompiler, self).can_prepare() + + _PREPARED_CONJUNCTION = 'FROM' + + def _prepared_param(self, number): + return '?' + + def _quote_query_for_prepare(self, query): + return '"{query}"'.format(query=query) + class AbstractMySQLDriver(AbstractModuleDriver): # Don't try to decode pickle states as UTF-8 (or whatever the @@ -75,6 +92,9 @@ def callproc_multi_result(self, cursor, proc, args=()): return multi_results + sql_compiler_class = MySQLCompiler + + implement_db_driver_options( __name__, 'mysqlconnector', 'mysqldb', 'pymysql', diff --git a/src/relstorage/adapters/mysql/locker.py b/src/relstorage/adapters/mysql/locker.py index 7b251de5..503d5b34 100644 --- a/src/relstorage/adapters/mysql/locker.py +++ b/src/relstorage/adapters/mysql/locker.py @@ -97,19 +97,8 @@ class MySQLLocker(AbstractLocker): def __init__(self, options, driver, batcher_factory): super(MySQLLocker, self).__init__(options, driver, batcher_factory) - # TODO: Back to needing a proper prepare registry. - lock_stmt_raw = self._commit_lock_query - lock_stmt_nowait_raw = self._commit_lock_nowait_query - - self._prepare_lock_stmt = 'PREPARE hold_commit_lock FROM "%s"' % ( - lock_stmt_raw) - self._prepare_lock_stmt_nowait = 'PREPARE hold_commit_lock_nowait FROM "%s"' % ( - lock_stmt_nowait_raw) self._supports_row_lock_nowait = None - self._commit_lock_query = 'EXECUTE hold_commit_lock' - self._commit_lock_nowait_query = 'EXECUTE hold_commit_lock_nowait' - # No good preparing this, mysql can't take parameters in EXECUTE, # they have to be user variables, which defeats most of the point # (Although in this case, because it's a static value, maybe not; @@ -136,10 +125,6 @@ def on_store_opened(self, cursor, restart=False): __traceback_info__ = ver, major self._supports_row_lock_nowait = (major >= 8) - cursor.execute(self._prepare_lock_stmt) - if self._supports_row_lock_nowait: - cursor.execute(self._prepare_lock_stmt_nowait) - def _on_store_opened_set_row_lock_timeout(self, cursor, restart=False): if restart: return diff --git a/src/relstorage/adapters/mysql/mover.py b/src/relstorage/adapters/mysql/mover.py index a9e533ad..38c64c5b 100644 --- a/src/relstorage/adapters/mysql/mover.py +++ b/src/relstorage/adapters/mysql/mover.py @@ -22,37 +22,13 @@ from relstorage.adapters.interfaces import IObjectMover -from .._util import query_property from ..mover import AbstractObjectMover from ..mover import metricmethod_sampled -def to_prepared_queries(name, queries, extension=''): - - return [ - 'PREPARE ' + name + ' FROM "' + x.replace('%s', '?') + extension + '"' - for x in queries - ] - @implementer(IObjectMover) class MySQLObjectMover(AbstractObjectMover): - _prepare_detect_conflict_queries = to_prepared_queries( - 'detect_conflicts', - AbstractObjectMover._detect_conflict_queries, - 
# Now that we explicitly lock the rows before we begin, - # no sense applying a locking clause here too. - ) - - _prepare_detect_conflict_query = query_property('_prepare_detect_conflict') - - _detect_conflict_query = 'EXECUTE detect_conflicts' - - on_load_opened_statement_names = () - - on_store_opened_statement_names = on_load_opened_statement_names - on_store_opened_statement_names += ('_prepare_detect_conflict_query',) - @metricmethod_sampled def on_store_opened(self, cursor, restart=False): """Create the temporary table for storing objects""" diff --git a/src/relstorage/adapters/mysql/txncontrol.py b/src/relstorage/adapters/mysql/txncontrol.py index d25c8ced..d7a01b1c 100644 --- a/src/relstorage/adapters/mysql/txncontrol.py +++ b/src/relstorage/adapters/mysql/txncontrol.py @@ -19,7 +19,4 @@ class MySQLTransactionControl(GenericTransactionControl): - - # See adapter.py for where this is prepared. - # Either history preserving or not, it's the same. - _get_tid_query = 'EXECUTE get_latest_tid' + pass diff --git a/src/relstorage/adapters/oracle/adapter.py b/src/relstorage/adapters/oracle/adapter.py index 88945968..a1d96b2d 100644 --- a/src/relstorage/adapters/oracle/adapter.py +++ b/src/relstorage/adapters/oracle/adapter.py @@ -110,23 +110,24 @@ def __init__(self, user, password, dsn, commit_lock_id=0, self.oidallocator = OracleOIDAllocator( connmanager=self.connmanager, ) + + + self.poller = Poller( + self.driver, + keep_history=self.keep_history, + runner=self.runner, + revert_when_stale=options.revert_when_stale, + ) + self.txncontrol = OracleTransactionControl( connmanager=self.connmanager, + poller=self.poller, keep_history=self.keep_history, Binary=driver.Binary, twophase=twophase, ) - if self.keep_history: - poll_query = "SELECT MAX(tid) FROM transaction" - else: - poll_query = "SELECT MAX(tid) FROM object_state" - self.poller = Poller( - poll_query=poll_query, - keep_history=self.keep_history, - runner=self.runner, - revert_when_stale=options.revert_when_stale, - ) + if self.keep_history: self.packundo = OracleHistoryPreservingPackUndo( diff --git a/src/relstorage/adapters/oracle/mover.py b/src/relstorage/adapters/oracle/mover.py index d1b54ac2..39ea25ca 100644 --- a/src/relstorage/adapters/oracle/mover.py +++ b/src/relstorage/adapters/oracle/mover.py @@ -46,7 +46,8 @@ class OracleObjectMover(AbstractObjectMover): _move_from_temp_copy_blob_query = format_to_named( AbstractObjectMover._move_from_temp_copy_blob_query) - _load_current_queries = _to_oracle_ordered(AbstractObjectMover._load_current_queries) + # XXX: This is definitely broken! + # _load_current_queries = _to_oracle_ordered(AbstractObjectMover._load_current_queries) @metricmethod_sampled def load_current(self, cursor, oid): @@ -65,7 +66,8 @@ def load_revision(self, cursor, oid, tid): return state - _exists_queries = _to_oracle_ordered(AbstractObjectMover._exists_queries) + # XXX: Def broken. 
+ #_exists_queries = _to_oracle_ordered(AbstractObjectMover._exists_queries) @metricmethod_sampled def exists(self, cursor, oid): diff --git a/src/relstorage/adapters/oracle/scriptrunner.py b/src/relstorage/adapters/oracle/scriptrunner.py index 08076a6e..ba0fe99a 100644 --- a/src/relstorage/adapters/oracle/scriptrunner.py +++ b/src/relstorage/adapters/oracle/scriptrunner.py @@ -15,9 +15,6 @@ from __future__ import absolute_import import logging -import re - -from relstorage._compat import intern from relstorage._compat import iteritems from ..scriptrunner import ScriptRunner @@ -31,18 +28,24 @@ def format_to_named(stmt): Convert '%s' pyformat strings to :n numbered strings. Intended only for static strings. """ - try: - return _stmt_cache[stmt] - except KeyError: - matches = [] - - def replace(_match): - matches.append(None) - return ':%d' % len(matches) - new_stmt = intern(re.sub('%s', replace, stmt)) - _stmt_cache[stmt] = new_stmt - - return new_stmt + # XXX: This should be part of the Compiler, as handled + # by the driver. + return stmt + # import re + # from relstorage._compat import intern + + # try: + # return _stmt_cache[stmt] + # except KeyError: + # matches = [] + + # def replace(_match): + # matches.append(None) + # return ':%d' % len(matches) + # new_stmt = intern(re.sub('%s', replace, stmt)) + # _stmt_cache[stmt] = new_stmt + + # return new_stmt class OracleScriptRunner(ScriptRunner): diff --git a/src/relstorage/adapters/oracle/txncontrol.py b/src/relstorage/adapters/oracle/txncontrol.py index 4607a180..8095f3f7 100644 --- a/src/relstorage/adapters/oracle/txncontrol.py +++ b/src/relstorage/adapters/oracle/txncontrol.py @@ -25,8 +25,8 @@ class OracleTransactionControl(GenericTransactionControl): - def __init__(self, connmanager, keep_history, Binary, twophase): - GenericTransactionControl.__init__(self, connmanager, keep_history, Binary) + def __init__(self, connmanager, poller, keep_history, Binary, twophase): + GenericTransactionControl.__init__(self, connmanager, poller, keep_history, Binary) self.twophase = twophase def commit_phase1(self, conn, cursor, tid): diff --git a/src/relstorage/adapters/poller.py b/src/relstorage/adapters/poller.py index dc31c381..184818b2 100644 --- a/src/relstorage/adapters/poller.py +++ b/src/relstorage/adapters/poller.py @@ -17,12 +17,13 @@ from zope.interface import implementer -from ._util import formatted_query_property from .interfaces import IPoller from .interfaces import StaleConnectionError -log = logging.getLogger(__name__) +from .schema import Schema +from ._sql import func +log = logging.getLogger(__name__) @implementer(IPoller) class Poller(object): @@ -31,44 +32,30 @@ class Poller(object): # The zoid is the primary key on both ``current_object`` (history # preserving) and ``object_state`` (history free), so these # queries are guaranteed to only produce an OID once. 
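+ # The history-free/history-preserving variants are now chosen at bind
+ # time by HistoryVariantTable instead of by paired raw query strings.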
- _list_changes_range_queries = ( - """ - SELECT zoid, tid - FROM current_object - WHERE tid > %(min_tid)s - AND tid <= %(max_tid)s - """, - """ - SELECT zoid, tid - FROM object_state - WHERE tid > %(min_tid)s - AND tid <= %(max_tid)s - """ - ) - - _list_changes_range_query = formatted_query_property('_list_changes_range') - - _poll_inv_queries = ( - """ - SELECT zoid, tid - FROM current_object - WHERE tid > %(tid)s - """, - """ - SELECT zoid, tid - FROM object_state - WHERE tid > %(tid)s - """ - ) - - _poll_inv_query = formatted_query_property('_poll_inv') - - _poll_inv_exc_query = formatted_query_property('_poll_inv', - extension=' AND tid != %(self_tid)s') - - - def __init__(self, poll_query, keep_history, runner, revert_when_stale): - self.poll_query = poll_query + _list_changes_range_query = Schema.all_current_object.select( + Schema.all_current_object.c.zoid, Schema.all_current_object.c.tid + ).where( + Schema.all_current_object.c.tid > Schema.all_current_object.bindparam('min_tid') + ).and_( + Schema.all_current_object.c.tid <= Schema.all_current_object.bindparam('max_tid') + ).prepared() + + _poll_inv_query = Schema.all_current_object.select( + Schema.all_current_object.c.zoid, Schema.all_current_object.c.tid + ).where( + Schema.all_current_object.c.tid > Schema.all_current_object.bindparam('tid') + ).prepared() + + _poll_inv_exc_query = _poll_inv_query.and_( + Schema.all_current_object.c.tid != Schema.all_current_object.bindparam('self_tid') + ).prepared() + + poll_query = Schema.all_transaction.select( + func.max(Schema.all_transaction.c.tid) + ).prepared() + + def __init__(self, driver, keep_history, runner, revert_when_stale): + self.driver = driver self.keep_history = keep_history self.runner = runner self.revert_when_stale = revert_when_stale @@ -92,7 +79,7 @@ def poll_invalidations(self, conn, cursor, prev_polled_tid, ignore_tid): """ # pylint:disable=unused-argument # find out the tid of the most recent transaction. - cursor.execute(self.poll_query) + self.poll_query.execute(cursor) rows = cursor.fetchall() if not rows or not rows[0][0]: # No data, must be fresh database, without even @@ -148,8 +135,8 @@ def poll_invalidations(self, conn, cursor, prev_polled_tid, ignore_tid): # all the unreachable objects will be garbage collected # anyway. # - # Thus we became convinced it was safe to remove the check in history-preserving - # databases. + # Thus we became convinced it was safe to remove the check in + # history-preserving databases. # Get the list of changed OIDs and return it. stmt = self._poll_inv_query @@ -158,7 +145,7 @@ def poll_invalidations(self, conn, cursor, prev_polled_tid, ignore_tid): stmt = self._poll_inv_exc_query params['self_tid'] = ignore_tid - cursor.execute(stmt, params) + stmt.execute(cursor, params) changes = cursor.fetchall() return changes, new_polled_tid @@ -167,5 +154,5 @@ def list_changes(self, cursor, after_tid, last_tid): See ``IPoller``. 
""" params = {'min_tid': after_tid, 'max_tid': last_tid} - cursor.execute(self._list_changes_range_query, params) + self._list_changes_range_query.execute(cursor, params) return cursor.fetchall() diff --git a/src/relstorage/adapters/postgresql/adapter.py b/src/relstorage/adapters/postgresql/adapter.py index 337576f9..f869d9d3 100644 --- a/src/relstorage/adapters/postgresql/adapter.py +++ b/src/relstorage/adapters/postgresql/adapter.py @@ -22,32 +22,44 @@ from ...options import Options from .._abstract_drivers import _select_driver -from .._util import query_property + from ..dbiter import HistoryFreeDatabaseIterator from ..dbiter import HistoryPreservingDatabaseIterator from ..interfaces import IRelStorageAdapter from ..packundo import HistoryFreePackUndo from ..packundo import HistoryPreservingPackUndo from ..poller import Poller +from ..schema import Schema from ..scriptrunner import ScriptRunner from . import drivers from .batch import PostgreSQLRowBatcher from .connmanager import Psycopg2ConnectionManager from .locker import PostgreSQLLocker -from .mover import PG8000ObjectMover from .mover import PostgreSQLObjectMover -from .mover import to_prepared_queries + from .oidallocator import PostgreSQLOIDAllocator from .schema import PostgreSQLSchemaInstaller from .stats import PostgreSQLStats from .txncontrol import PostgreSQLTransactionControl -from .txncontrol import PG8000TransactionControl + log = logging.getLogger(__name__) def select_driver(options=None): return _select_driver(options or Options(), drivers) +# TODO: Move to own file +class PGPoller(Poller): + + poll_query = Schema.all_transaction.select( + Schema.all_transaction.c.tid + ).order_by( + Schema.all_transaction.c.tid, dir='DESC' + ).limit( + 1 + ).prepared() + + @implementer(IRelStorageAdapter) class PostgreSQLAdapter(object): """PostgreSQL adapter for RelStorage.""" @@ -83,13 +95,7 @@ def __init__(self, dsn='', options=None): locker=self.locker, ) - mover_type = PostgreSQLObjectMover - txn_type = PostgreSQLTransactionControl - if driver.__name__ == 'pg8000': - mover_type = PG8000ObjectMover - txn_type = PG8000TransactionControl - - self.mover = mover_type( + self.mover = PostgreSQLObjectMover( driver, options=options, runner=self.runner, @@ -97,19 +103,21 @@ def __init__(self, dsn='', options=None): batcher_factory=PostgreSQLRowBatcher, ) self.oidallocator = PostgreSQLOIDAllocator() - self.txncontrol = txn_type( - connmanager=self.connmanager, - keep_history=self.keep_history, - driver=driver, - ) - self.poller = Poller( - poll_query="EXECUTE get_latest_tid", + self.poller = PGPoller( + self.driver, keep_history=self.keep_history, runner=self.runner, revert_when_stale=options.revert_when_stale, ) + self.txncontrol = PostgreSQLTransactionControl( + connmanager=self.connmanager, + poller=self.poller, + keep_history=self.keep_history, + Binary=driver.Binary, + ) + if self.keep_history: self.packundo = HistoryPreservingPackUndo( driver, @@ -144,49 +152,11 @@ def __init__(self, dsn='', options=None): self.connmanager.add_on_store_opened(self.mover.on_store_opened) self.connmanager.add_on_load_opened(self.mover.on_load_opened) - self.connmanager.add_on_load_opened(self.__prepare_statements) self.connmanager.add_on_store_opened(self.__prepare_store_statements) - _get_latest_tid_queries = ( - """ - SELECT tid - FROM transaction - ORDER BY tid DESC - LIMIT 1 - """, - """ - SELECT tid - FROM object_state - ORDER BY tid DESC - LIMIT 1 - """ - ) - - _prepare_get_latest_tid_queries = to_prepared_queries( - 'get_latest_tid', - 
_get_latest_tid_queries) - - _prepare_get_latest_tid_query = query_property('_prepare_get_latest_tid') - - def __prepare_statements(self, cursor, restart=False): - if restart: - return - - # TODO: Generalize all of this better. There should be a - # registry of things to prepare, or we should wrap cursors to - # detect and prepare when needed. Preparation and switching to - # EXECUTE should be automatic for drivers that don't already do that. - - # A meta-class or base class __new__ could handle proper - # history/free query selection without this mass of tuples and - # manual properties and property names. - stmt = self._prepare_get_latest_tid_query - cursor.execute(stmt) - def __prepare_store_statements(self, cursor, restart=False): if not restart: - self.__prepare_statements(cursor, restart) try: stmt = self.txncontrol._prepare_add_transaction_query except (Unsupported, AttributeError): diff --git a/src/relstorage/adapters/postgresql/drivers/pg8000.py b/src/relstorage/adapters/postgresql/drivers/pg8000.py index 57c9bcf2..a7b9c83d 100644 --- a/src/relstorage/adapters/postgresql/drivers/pg8000.py +++ b/src/relstorage/adapters/postgresql/drivers/pg8000.py @@ -25,6 +25,7 @@ from ..._abstract_drivers import AbstractModuleDriver from ...interfaces import IDBDriver +from ..._sql import _Compiler __all__ = [ 'PG8000Driver', @@ -115,6 +116,17 @@ class _tuple_deque(deque): def append(self, row): # pylint:disable=arguments-differ deque.append(self, tuple(row)) +class PG8000Compiler(_Compiler): + + def can_prepare(self): + # Important: pg8000 1.10 - 1.13, at least, can't handle prepared + # statements that take parameters but it doesn't need to because it + # prepares every statement anyway. So you must have a backup that you use + # for that driver. + # https://github.com/mfenniak/pg8000/issues/132 + return False + + @implementer(IDBDriver) class PG8000Driver(AbstractModuleDriver): __name__ = 'pg8000' @@ -229,3 +241,5 @@ def connect_with_isolation(self, dsn, cursor.execute('SET SESSION CHARACTERISTICS AS ' + transaction_stmt) conn.commit() return conn, cursor + + sql_compiler_class = PG8000Compiler diff --git a/src/relstorage/adapters/postgresql/mover.py b/src/relstorage/adapters/postgresql/mover.py index 275477dd..b3caf2ca 100644 --- a/src/relstorage/adapters/postgresql/mover.py +++ b/src/relstorage/adapters/postgresql/mover.py @@ -20,76 +20,17 @@ import struct from zope.interface import implementer -from ZODB.POSException import Unsupported -from .._util import query_property + from ..interfaces import IObjectMover from ..mover import AbstractObjectMover from ..mover import metricmethod_sampled -# Important: pg8000 1.10 - 1.13, at least, can't handle prepared -# statements that take parameters but it doesn't need to because it -# prepares every statement anyway. So you must have a backup that you use -# for that driver. -# https://github.com/mfenniak/pg8000/issues/132 - - -def to_prepared_queries(name, queries, datatypes=()): - # Give correct datatypes for the queries, wherever possible. - # The number of parameters should be the same or more than the - # number of datatypes. - # datatypes is a sequence of strings. - - # Maybe instead of having the adapter have to know about all the - # statements that need prepared, we could keep a registry? 
- if datatypes: - assert isinstance(datatypes, (list, tuple)) - datatypes = ', '.join(datatypes) - datatypes = ' (%s)' % (datatypes,) - else: - datatypes = '' - - result = [] - for q in queries: - if not isinstance(q, str): - # Unsupported marker - result.append(q) - continue - - q = q.strip() - param_count = q.count('%s') - rep_count = 0 - while rep_count < param_count: - rep_count += 1 - q = q.replace('%s', '$' + str(rep_count), 1) - stmt = 'PREPARE {name}{datatypes} AS {query}'.format( - name=name, datatypes=datatypes, query=q - ) - result.append(stmt) - return result - @implementer(IObjectMover) class PostgreSQLObjectMover(AbstractObjectMover): - _prepare_load_current_queries = to_prepared_queries( - 'load_current', - AbstractObjectMover._load_current_queries, - ['BIGINT']) - - _prepare_load_current_query = query_property('_prepare_load_current') - - _load_current_query = 'EXECUTE load_current(%s)' - - _prepare_detect_conflict_queries = to_prepared_queries( - 'detect_conflicts', - AbstractObjectMover._detect_conflict_queries) - - _prepare_detect_conflict_query = query_property('_prepare_detect_conflict') - - _detect_conflict_query = 'EXECUTE detect_conflicts' - - _move_from_temp_hf_insert_query_raw = AbstractObjectMover._move_from_temp_hf_insert_query + """ + _move_from_temp_hf_insert_query = AbstractObjectMover._move_from_temp_hf_insert_query + """ ON CONFLICT (zoid) DO UPDATE SET state_size = COALESCE(LENGTH(excluded.state), 0), @@ -107,36 +48,10 @@ class PostgreSQLObjectMover(AbstractObjectMover): """ _update_current_update_query = None - _move_from_temp_hf_insert_raw_queries = ( - Unsupported("States accumulate in history-preserving mode"), - _move_from_temp_hf_insert_query_raw, - ) - - _prepare_move_from_temp_hf_insert_queries = to_prepared_queries( - 'move_from_temp', - _move_from_temp_hf_insert_raw_queries, - ('BIGINT',) - ) - - _prepare_move_from_temp_hf_insert_query = query_property( - '_prepare_move_from_temp_hf_insert') - - _move_from_temp_hf_insert_queries = ( - Unsupported("States accumulate in history-preserving mode"), - 'EXECUTE move_from_temp(%s)' - ) - - _move_from_temp_hf_insert_query = query_property('_move_from_temp_hf_insert') # We upsert, no need _move_from_temp_hf_delete_query = '' - on_load_opened_statement_names = ('_prepare_load_current_query',) - on_store_opened_statement_names = on_load_opened_statement_names + ( - '_prepare_detect_conflict_query', - '_prepare_move_from_temp_hf_insert_query', - ) - @metricmethod_sampled def on_store_opened(self, cursor, restart=False): @@ -175,6 +90,9 @@ def on_store_opened(self, cursor, restart=False): """, ] + # XXX: we're not preparing statements anymore until just + # before we want to use them, so this is no longer needed. + # For some reason, preparing the INSERT statement also # wants to acquire a lock. If we're committing in another # transaction, this can block indefinitely (if that other @@ -193,7 +111,7 @@ def on_store_opened(self, cursor, restart=False): # this to 100, but under high concurrency (10 processes) # that turned out to be laughably optimistic. We might # actually need to go as high as the commit lock timeout. - cursor.execute('SET lock_timeout = 10000') + # cursor.execute('SET lock_timeout = 10000') for stmt in ddl_stmts: cursor.execute(stmt) @@ -334,19 +252,6 @@ def store_temps(self, cursor, state_oid_tid_iter): cursor.copy_expert(buf.COPY_COMMAND, buf) -class PG8000ObjectMover(PostgreSQLObjectMover): - # Delete the statements that need paramaters. 
- on_load_opened_statement_names = () - on_store_opened_statement_names = ('_prepare_detect_conflict_query',) - - _load_current_query = AbstractObjectMover._load_current_query - - _move_from_temp_hf_insert_queries = ( - Unsupported("States accumulate in history-preserving mode"), - PostgreSQLObjectMover._move_from_temp_hf_insert_query_raw - ) - - class TempStoreCopyBuffer(io.BufferedIOBase): """ A binary file-like object for putting data into diff --git a/src/relstorage/adapters/postgresql/tests/test_mover.py b/src/relstorage/adapters/postgresql/tests/test_mover.py index be930cc0..96f59d26 100644 --- a/src/relstorage/adapters/postgresql/tests/test_mover.py +++ b/src/relstorage/adapters/postgresql/tests/test_mover.py @@ -16,7 +16,7 @@ from __future__ import division from __future__ import print_function -from ZODB.POSException import Unsupported +import unittest from relstorage.tests import TestCase from relstorage.tests import MockOptions @@ -26,8 +26,9 @@ class MockDriver(object): pass +@unittest.skip("Needs moved to test__sql") class TestFunctions(TestCase): - + # pylint:disable=no-member def _prepare1(self, query, name='prepped', datatypes=()): return mover.to_prepared_queries(name, [query], datatypes)[0] @@ -93,23 +94,8 @@ def _makeOne(self, **options): def test_prep_statements_hf(self): inst = self._makeOne(keep_history=False) - self.assertEqual( - inst._move_from_temp_hf_insert_query, - self._expected_move_from_temp_hf_insert_query + self.assertTrue( + str(inst._move_from_temp_hf_insert_query).startswith( + 'EXECUTE rs_prep_stmt' + ) ) - - def test_prep_statements_hp(self): - inst = self._makeOne(keep_history=True) - with self.assertRaises(Unsupported): - getattr(inst, '_move_from_temp_hf_insert_query') - - -class TestPG8000ObjectMover(TestPostgreSQLObjectMover): - - def setUp(self): - super(TestPG8000ObjectMover, self).setUp() - raw = mover.PostgreSQLObjectMover._move_from_temp_hf_insert_query_raw - self._expected_move_from_temp_hf_insert_query = raw - - def _getClass(self): - return mover.PG8000ObjectMover diff --git a/src/relstorage/adapters/postgresql/tests/test_txncontrol.py b/src/relstorage/adapters/postgresql/tests/test_txncontrol.py index f556efd5..0d32b565 100644 --- a/src/relstorage/adapters/postgresql/tests/test_txncontrol.py +++ b/src/relstorage/adapters/postgresql/tests/test_txncontrol.py @@ -18,22 +18,9 @@ from relstorage.adapters.tests import test_txncontrol -class MockDriver(object): - Binary = bytes class TestPostgreSQLTransactionControl(test_txncontrol.TestTransactionControl): - def setUp(self): - super(TestPostgreSQLTransactionControl, self).setUp() - - driver = MockDriver() - driver.Binary = super(TestPostgreSQLTransactionControl, self).Binary - self.Binary = driver - - def tearDown(self): - del self.Binary - super(TestPostgreSQLTransactionControl, self).tearDown() - def _getClass(self): from ..txncontrol import PostgreSQLTransactionControl return PostgreSQLTransactionControl diff --git a/src/relstorage/adapters/postgresql/txncontrol.py b/src/relstorage/adapters/postgresql/txncontrol.py index ed9a51b4..bef46563 100644 --- a/src/relstorage/adapters/postgresql/txncontrol.py +++ b/src/relstorage/adapters/postgresql/txncontrol.py @@ -15,43 +15,15 @@ from __future__ import absolute_import -from ZODB.POSException import Unsupported from ..txncontrol import GenericTransactionControl -from .._util import query_property -from .mover import to_prepared_queries class _PostgreSQLTransactionControl(GenericTransactionControl): - - # See adapter.py for where this is prepared. 
- # Either history preserving or not, it's the same. - _get_tid_query = 'EXECUTE get_latest_tid' - - - def __init__(self, connmanager, keep_history, driver): - super(_PostgreSQLTransactionControl, self).__init__( - connmanager, - keep_history, - driver.Binary - ) + pass class PostgreSQLTransactionControl(_PostgreSQLTransactionControl): - + # TODO: Get the data types applied at compile time + # to add_transaction_query # (tid, packed, username, description, extension) - _add_transaction_query = 'EXECUTE add_transaction(%s, %s, %s, %s, %s)' - - _prepare_add_transaction_queries = to_prepared_queries( - 'add_transaction', - [ - GenericTransactionControl._add_transaction_query, - Unsupported("No transactions in HF mode"), - ], - ('BIGINT', 'BOOLEAN', 'BYTEA', 'BYTEA', 'BYTEA') - ) - - _prepare_add_transaction_query = query_property('_prepare_add_transaction') - - -class PG8000TransactionControl(_PostgreSQLTransactionControl): - # We can't handle the parameterized prepared statements. + # ('BIGINT', 'BOOLEAN', 'BYTEA', 'BYTEA', 'BYTEA') pass diff --git a/src/relstorage/adapters/schema.py b/src/relstorage/adapters/schema.py index 5c1a89f2..17819617 100644 --- a/src/relstorage/adapters/schema.py +++ b/src/relstorage/adapters/schema.py @@ -25,12 +25,72 @@ from ._util import query_property from ._util import noop_when_history_free +from ._sql import Table +from ._sql import TemporaryTable +from ._sql import Column +from ._sql import HistoryVariantTable + + log = logging.getLogger("relstorage") tmpl_property = partial(query_property, property_suffix='_TMPLS', lazy_suffix='_TMPL') +class Schema(object): + current_object = Table( + 'current_object', + Column('zoid'), + Column('tid') + ) + + object_state = Table( + 'object_state', + Column('zoid'), + Column('tid'), + Column('state'), + Column('state_size'), + ) + + # Does the right thing whether history free or preserving + all_current_object = HistoryVariantTable( + current_object, + object_state, + ) + + # Does the right thing whether history free or preserving + all_current_object_state = HistoryVariantTable( + current_object.natural_join(object_state), + object_state + ) + + temp_store = TemporaryTable( + 'temp_store', + Column('zoid'), + Column('prev_tid'), + Column('md5'), + Column('state') + ) + + transaction = Table( + 'transaction', + Column('tid'), + Column('packed'), + Column('username'), + Column('description'), + Column('extension'), + ) + + commit_row_lock = Table( + 'commit_row_lock', + Column('tid'), + ) + + all_transaction = HistoryVariantTable( + transaction, + object_state, + ) + class AbstractSchemaInstaller(DatabaseHelpersMixin, ABC): diff --git a/src/relstorage/adapters/tests/test__sql.py b/src/relstorage/adapters/tests/test__sql.py new file mode 100644 index 00000000..b56b8ac4 --- /dev/null +++ b/src/relstorage/adapters/tests/test__sql.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. 
+# +############################################################################## + +""" +Tests for the SQL abstraction layer. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .._sql import Table +from .._sql import HistoryVariantTable +from .._sql import Column +from .._sql import bindparam + +current_object = Table( + 'current_object', + Column('zoid'), + Column('tid') +) + +object_state = Table( + 'object_state', + Column('zoid'), + Column('tid'), + Column('state'), +) + +hp_object_and_state = current_object.natural_join(object_state) + +objects = HistoryVariantTable( + current_object, + object_state, +) + +object_and_state = HistoryVariantTable( + hp_object_and_state, + object_state +) + +class TestTableSelect(TestCase): + + def test_simple_eq_select(self): + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state FROM object_state WHERE (zoid = tid)' + ) + + def test_simple_eq_select_and(self): + + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state FROM object_state WHERE (zoid = tid)' + ) + + stmt = stmt.and_(table.c.zoid > 5) + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state FROM object_state WHERE ((zoid = tid AND zoid > %(param_0)s))' + ) + + def test_simple_eq_select_literal(self): + table = object_state + + # This is a useless query + stmt = table.select().where(table.c.zoid == 7) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state FROM object_state WHERE (zoid = %(param_0)s)' + ) + + self.assertEqual( + stmt.compiled().params, + {'param_0': 7}) + + def test_column_query_variant_table(self): + stmt = objects.select(objects.c.tid, objects.c.zoid).where( + objects.c.tid > bindparam('tid') + ) + + self.assertEqual( + str(stmt), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + + def test_natural_join(self): + + stmt = object_and_state.select( + object_and_state.c.zoid, object_and_state.c.state + ).where( + object_and_state.c.zoid == object_and_state.bindparam('oid') + ) + + self.assertEqual( + str(stmt), + 'SELECT zoid, state ' + 'FROM current_object ' + 'JOIN object_state ' + 'USING (zoid, tid) WHERE (zoid = %(oid)s)' + ) + + class H(object): + keep_history = False + + stmt = stmt.bind(H()) + + self.assertEqual( + str(stmt), + 'SELECT zoid, state ' + 'FROM object_state ' + 'WHERE (zoid = %(oid)s)' + ) + + def test_bind(self): + select = objects.select(objects.c.tid, objects.c.zoid).where( + objects.c.tid > bindparam('tid') + ) + # Unbound we assume history + self.assertEqual( + str(select), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + from .._sql import _DefaultDialect + select = select.bind(42) + context = _DefaultDialect(42) + + self.assertEqual(select.context, context) + self.assertEqual(select.table.context, context) + self.assertEqual(select._where.context, context) + self.assertEqual(select._where.expression.context, context) + # Bound to the wrong thing we assume history + self.assertEqual( + str(select), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + + # Bound to history-free we use history free + class H(object): + keep_history = False + select = select.bind(H()) + + self.assertEqual( + str(select), + 'SELECT tid, zoid FROM object_state WHERE (tid > %(tid)s)' + ) + + def 
test_bind_descriptor(self): + class Context(object): + keep_history = True + select = objects.select(objects.c.tid, objects.c.zoid).where( + objects.c.tid > bindparam('tid') + ) + + # Unbound we assume history + self.assertEqual( + str(Context.select), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + + context = Context() + context.keep_history = False + self.assertEqual( + str(context.select), + 'SELECT tid, zoid FROM object_state WHERE (tid > %(tid)s)' + ) diff --git a/src/relstorage/adapters/tests/test_txncontrol.py b/src/relstorage/adapters/tests/test_txncontrol.py index 97b67841..1e5d5d66 100644 --- a/src/relstorage/adapters/tests/test_txncontrol.py +++ b/src/relstorage/adapters/tests/test_txncontrol.py @@ -18,12 +18,9 @@ from relstorage.tests import TestCase from relstorage.tests import MockConnection +from relstorage.tests import MockConnectionManager from relstorage.tests import MockCursor - -class MockConnmanager(object): - - def rollback(self, conn, _cursor): - conn.rollback() +from relstorage.tests import MockPoller class TestTransactionControl(TestCase): @@ -37,29 +34,8 @@ def Binary(self, arg): return arg def _makeOne(self, keep_history=True, binary=None): - return self._getClass()(MockConnmanager(), keep_history, binary or self.Binary) - - def _get_hf_tid_query(self): - return self._getClass()._get_tid_queries[1] - - def _get_hp_tid_query(self): - return self._getClass()._get_tid_queries[0] - - def _check_get_tid_query(self, keep_history, expected_query): - inst = self._makeOne(keep_history) - cur = MockCursor() - cur.results = [(1,)] - - inst.get_tid(cur) - - self.assertEqual(cur.executed.pop(), - (expected_query, None)) - - def test_get_tid_hf(self): - self._check_get_tid_query(False, self._get_hf_tid_query()) - - def test_get_tid_hp(self): - self._check_get_tid_query(True, self._get_hp_tid_query()) + return self._getClass()(MockConnectionManager(), MockPoller(), + keep_history, binary or self.Binary) def test_get_tid_empty_db(self): inst = self._makeOne() @@ -71,12 +47,11 @@ def test_get_tid_empty_db(self): def test_add_transaction_hp(self): inst = self._makeOne() cur = MockCursor() - + __traceback_info__ = inst.__dict__ inst.add_transaction(cur, 1, u'user', u'desc', u'ext') - self.assertEqual( cur.executed.pop(), - (inst._add_transaction_query, + (str(inst._add_transaction_query), (1, False, b'user', b'desc', b'ext')) ) @@ -84,7 +59,7 @@ def test_add_transaction_hp(self): self.assertEqual( cur.executed.pop(), - (inst._add_transaction_query, + (str(inst._add_transaction_query), (1, True, b'user', b'desc', b'ext')) ) diff --git a/src/relstorage/adapters/txncontrol.py b/src/relstorage/adapters/txncontrol.py index 7952a21d..ec1067eb 100644 --- a/src/relstorage/adapters/txncontrol.py +++ b/src/relstorage/adapters/txncontrol.py @@ -21,7 +21,8 @@ from .._compat import ABC from ._util import noop_when_history_free -from ._util import query_property + +from .schema import Schema from .interfaces import ITransactionControl @@ -73,19 +74,14 @@ class GenericTransactionControl(AbstractTransactionControl): and history-preserving storages that share a common syntax. 
""" - _get_tid_queries = ( - "SELECT MAX(tid) FROM transaction", - "SELECT MAX(tid) FROM object_state" - ) - _get_tid_query = query_property('_get_tid') - - def __init__(self, connmanager, keep_history, Binary): + def __init__(self, connmanager, poller, keep_history, Binary): super(GenericTransactionControl, self).__init__(connmanager) + self.poller = poller self.keep_history = keep_history self.Binary = Binary def get_tid(self, cursor): - cursor.execute(self._get_tid_query) + self.poller.poll_query.execute(cursor) row = cursor.fetchall() if not row: # nothing has been stored yet @@ -94,15 +90,22 @@ def get_tid(self, cursor): tid = row[0][0] return tid if tid is not None else 0 - _add_transaction_query = """ - INSERT INTO transaction (tid, packed, username, description, extension) - VALUES (%s, %s, %s, %s, %s) - """ + _add_transaction_query = Schema.transaction.insert( + Schema.transaction.c.tid, + Schema.transaction.c.packed, + Schema.transaction.c.username, + Schema.transaction.c.description, + Schema.transaction.c.extension + ).prepared() @noop_when_history_free def add_transaction(self, cursor, tid, username, description, extension, packed=False): binary = self.Binary - cursor.execute(self._add_transaction_query, ( - tid, packed, binary(username), - binary(description), binary(extension))) + self._add_transaction_query.execute( + cursor, + ( + tid, packed, binary(username), + binary(description), binary(extension) + ) + ) diff --git a/src/relstorage/storage/__init__.py b/src/relstorage/storage/__init__.py index 0559fa21..fb6b1d16 100644 --- a/src/relstorage/storage/__init__.py +++ b/src/relstorage/storage/__init__.py @@ -470,7 +470,9 @@ def afterCompletion(self): # abort. # The next time we use the load connection, it will need to poll - # and will call our _on_load_activated. + # and will call our __on_first_use. + # Typically our next call from the ZODB Connection will be from its + # `newTransaction` method, a forced `sync` followed by `poll_invalidations`. # TODO: Why doesn't this use connmanager.restart_load()? # They both rollback; the difference is that restart_load checks for replicas, diff --git a/src/relstorage/tests/__init__.py b/src/relstorage/tests/__init__.py index ee954788..981d7a9d 100644 --- a/src/relstorage/tests/__init__.py +++ b/src/relstorage/tests/__init__.py @@ -251,10 +251,12 @@ class MockConnectionManager(object): disconnected_exceptions = () - def rollback(self, conn, cursor): - "Does nothing" + def rollback(self, conn, cursor): # pylint:disable=unused-argument + if hasattr(conn, 'rollback'): + conn.rollback() def rollback_and_close(self, conn, cursor): + self.rollback(conn, cursor) if conn: conn.close() if cursor: @@ -283,3 +285,15 @@ def __init__(self): self.connmanager = MockConnectionManager() self.packundo = MockPackUndo() self.oidallocator = MockOIDAllocator() + +class MockQuery(object): + + def __init__(self, raw): + self.raw = raw + + def execute(self, cursor, params=None): + cursor.execute(self.raw, params) + +class MockPoller(object): + + poll_query = MockQuery('SELECT MAX(tid) FROM object_state') From 5d13cd37d423c0e32cbfb59a30a7cba6c4777ae4 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Fri, 19 Jul 2019 07:12:48 -0500 Subject: [PATCH 2/8] Do a better job finding the dialect we want and don't silently hide errors if we can't. 
--- src/relstorage/adapters/_sql.py | 90 ++++++++++--------- src/relstorage/adapters/interfaces.py | 10 +++ .../adapters/mysql/drivers/__init__.py | 8 +- .../adapters/postgresql/drivers/__init__.py | 10 +++ .../adapters/postgresql/drivers/pg8000.py | 13 ++- .../adapters/postgresql/drivers/psycopg2.py | 4 +- .../adapters/postgresql/tests/test_mover.py | 3 +- src/relstorage/adapters/tests/test__sql.py | 32 ++++--- src/relstorage/tests/__init__.py | 24 +++-- 9 files changed, 126 insertions(+), 68 deletions(-) diff --git a/src/relstorage/adapters/_sql.py b/src/relstorage/adapters/_sql.py index e74d5686..45be641d 100644 --- a/src/relstorage/adapters/_sql.py +++ b/src/relstorage/adapters/_sql.py @@ -188,67 +188,75 @@ def bindparam(self, key): def orderedbindparam(self): return orderedbindparam() -class _DefaultDialect(object): +class DefaultDialect(object): - def __init__(self, base): - self._base = base + keep_history = True - def __getattr__(self, name): - return getattr(self._base, name) - - def __bool__(self): - return False - - __nonzero__ = __bool__ - - _driver_locations = ( - attrgetter('_base.driver'), - attrgetter('_base.poller.driver'), - attrgetter('_base.connmanager.driver'), - attrgetter('_base.adapter.driver') - ) + def bind(self, context): + # The context will reference us most likely + # (compiled statement in instance dictionary) + # so try to avoid reference cycles. + keep_history = context.keep_history + new = copy(self) + new.keep_history = keep_history + return new def compiler_class(self): - # We want to find *something* with a driver. - # Preferably the object we're attached to, but if not that, - # we'll look at some common attributes for adapter objects - # for it. - for getter in self._driver_locations: - try: - return getter(self).sql_compiler_class - except AttributeError: - pass return _Compiler def compiler(self, root): return self.compiler_class()(root) def __eq__(self, other): - if isinstance(other, _DefaultDialect): - return self._base == other._base + if isinstance(other, DefaultDialect): + return other.keep_history == self.keep_history return NotImplemented - def __repr__(self): - return '<%s at %x base=%r>' % ( - type(self), - id(self), - self._base - ) + +class _MissingDialect(DefaultDialect): + def __bool__(self): + return False + + __nonzero__ = __bool__ + class _Bindable(object): - context = _DefaultDialect(None) + context = _MissingDialect() + + _dialect_locations = ( + attrgetter('dialect'), + attrgetter('driver.dialect'), + attrgetter('poller.driver.dialect'), + attrgetter('connmanager.driver.dialect'), + attrgetter('adapter.driver.dialect') + ) def _find_dialect(self, context): - # Look up the database type, find the right dialect. - # Ordinarily we want to use the database driver. - if isinstance(context, _DefaultDialect): + # Find the dialect to use for the context. If it specifies + # one, then use it. Otherwise go hunting for the database + # driver and use *it*. Preferably the driver is attached to + # the object we're looking at, but if not that, we'll look at + # some common attributes for adapter objects for it. 
+ if isinstance(context, DefaultDialect): return context - return _DefaultDialect(context) + + for getter in self._dialect_locations: + try: + dialect = getter(context) + except AttributeError: + pass + else: + return dialect.bind(context) + __traceback_info__ = vars(context) + raise TypeError("Unable to bind to %s; no dialect found" % (context,)) def bind(self, context): - new = copy(self) context = self._find_dialect(context) + if context is None: + return self + + new = copy(self) new.context = context bound_replacements = { k: v.bind(context) @@ -302,7 +310,7 @@ def history_free(self): return self.rhs def __compile_visit__(self, compiler): - keep_history = getattr(self.context, 'keep_history', True) + keep_history = self.context.keep_history node = self.history_preserving if keep_history else self.history_free return compiler.visit(node) diff --git a/src/relstorage/adapters/interfaces.py b/src/relstorage/adapters/interfaces.py index 2b8cc7cf..543c93b6 100644 --- a/src/relstorage/adapters/interfaces.py +++ b/src/relstorage/adapters/interfaces.py @@ -50,6 +50,14 @@ def __str__(): """Return a short description of the adapter""" +class IDBDialect(Interface): + """ + Handles converting from our internal "standard" SQL queries to + something database specific. + """ + + # TODO: Fill this in. + class IDBDriver(Interface): """ An abstraction over the information needed for RelStorage to work @@ -71,6 +79,8 @@ class IDBDriver(Interface): Binary = Attribute("A callable.") + dialect = Attribute("The IDBDialect for this driver.") + def binary_column_as_state_type(db_column_data): """ Turn *db_column_data* into something that's a valid pickle diff --git a/src/relstorage/adapters/mysql/drivers/__init__.py b/src/relstorage/adapters/mysql/drivers/__init__.py index 1110f7e8..77557e8c 100644 --- a/src/relstorage/adapters/mysql/drivers/__init__.py +++ b/src/relstorage/adapters/mysql/drivers/__init__.py @@ -21,6 +21,7 @@ from ..._abstract_drivers import AbstractModuleDriver from ..._abstract_drivers import implement_db_driver_options from ..._sql import _Compiler +from ..._sql import DefaultDialect database_type = 'mysql' @@ -40,6 +41,11 @@ def _prepared_param(self, number): def _quote_query_for_prepare(self, query): return '"{query}"'.format(query=query) +class MySQLDialect(DefaultDialect): + + def compiler_class(self): + return MySQLCompiler + class AbstractMySQLDriver(AbstractModuleDriver): # Don't try to decode pickle states as UTF-8 (or whatever the @@ -92,7 +98,7 @@ def callproc_multi_result(self, cursor, proc, args=()): return multi_results - sql_compiler_class = MySQLCompiler + dialect = MySQLDialect() implement_db_driver_options( diff --git a/src/relstorage/adapters/postgresql/drivers/__init__.py b/src/relstorage/adapters/postgresql/drivers/__init__.py index 04c2e12b..ba6ad737 100644 --- a/src/relstorage/adapters/postgresql/drivers/__init__.py +++ b/src/relstorage/adapters/postgresql/drivers/__init__.py @@ -22,6 +22,16 @@ from __future__ import print_function from ..._abstract_drivers import implement_db_driver_options +from ..._abstract_drivers import AbstractModuleDriver +from ..._sql import DefaultDialect + +class PostgreSQLDialect(DefaultDialect): + """ + The defaults are setup for PostgreSQL. 
+ """ + +class AbstractPostgreSQLDriver(AbstractModuleDriver): + dialect = PostgreSQLDialect() database_type = 'postgresql' diff --git a/src/relstorage/adapters/postgresql/drivers/pg8000.py b/src/relstorage/adapters/postgresql/drivers/pg8000.py index a7b9c83d..2b5957b3 100644 --- a/src/relstorage/adapters/postgresql/drivers/pg8000.py +++ b/src/relstorage/adapters/postgresql/drivers/pg8000.py @@ -23,10 +23,12 @@ from zope.interface import implementer -from ..._abstract_drivers import AbstractModuleDriver from ...interfaces import IDBDriver from ..._sql import _Compiler +from . import AbstractPostgreSQLDriver +from . import PostgreSQLDialect + __all__ = [ 'PG8000Driver', ] @@ -126,9 +128,14 @@ def can_prepare(self): # https://github.com/mfenniak/pg8000/issues/132 return False +class PG8000Dialect(PostgreSQLDialect): + + def compiler_class(self): + return PG8000Compiler + @implementer(IDBDriver) -class PG8000Driver(AbstractModuleDriver): +class PG8000Driver(AbstractPostgreSQLDriver): __name__ = 'pg8000' MODULE_NAME = __name__ PRIORITY = 3 @@ -137,6 +144,8 @@ class PG8000Driver(AbstractModuleDriver): _GEVENT_CAPABLE = True _GEVENT_NEEDS_SOCKET_PATCH = True + dialect = PG8000Dialect() + def __init__(self): super(PG8000Driver, self).__init__() # XXX: global side-effect! diff --git a/src/relstorage/adapters/postgresql/drivers/psycopg2.py b/src/relstorage/adapters/postgresql/drivers/psycopg2.py index 30e1dca2..b7776fd0 100644 --- a/src/relstorage/adapters/postgresql/drivers/psycopg2.py +++ b/src/relstorage/adapters/postgresql/drivers/psycopg2.py @@ -22,8 +22,8 @@ from zope.interface import implementer from relstorage._compat import PY3 -from ..._abstract_drivers import AbstractModuleDriver from ...interfaces import IDBDriver +from . import AbstractPostgreSQLDriver __all__ = [ @@ -32,7 +32,7 @@ @implementer(IDBDriver) -class Psycopg2Driver(AbstractModuleDriver): +class Psycopg2Driver(AbstractPostgreSQLDriver): __name__ = 'psycopg2' MODULE_NAME = __name__ diff --git a/src/relstorage/adapters/postgresql/tests/test_mover.py b/src/relstorage/adapters/postgresql/tests/test_mover.py index 96f59d26..d2455f11 100644 --- a/src/relstorage/adapters/postgresql/tests/test_mover.py +++ b/src/relstorage/adapters/postgresql/tests/test_mover.py @@ -20,11 +20,10 @@ from relstorage.tests import TestCase from relstorage.tests import MockOptions +from relstorage.tests import MockDriver from .. 
import mover -class MockDriver(object): - pass @unittest.skip("Needs moved to test__sql") class TestFunctions(TestCase): diff --git a/src/relstorage/adapters/tests/test__sql.py b/src/relstorage/adapters/tests/test__sql.py index b56b8ac4..8f728197 100644 --- a/src/relstorage/adapters/tests/test__sql.py +++ b/src/relstorage/adapters/tests/test__sql.py @@ -27,6 +27,7 @@ from .._sql import HistoryVariantTable from .._sql import Column from .._sql import bindparam +from .._sql import DefaultDialect current_object = Table( 'current_object', @@ -108,7 +109,6 @@ def test_column_query_variant_table(self): ) def test_natural_join(self): - stmt = object_and_state.select( object_and_state.c.zoid, object_and_state.c.state ).where( @@ -125,6 +125,7 @@ def test_natural_join(self): class H(object): keep_history = False + dialect = DefaultDialect() stmt = stmt.bind(H()) @@ -144,24 +145,28 @@ def test_bind(self): str(select), 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' ) - from .._sql import _DefaultDialect - select = select.bind(42) - context = _DefaultDialect(42) - - self.assertEqual(select.context, context) - self.assertEqual(select.table.context, context) - self.assertEqual(select._where.context, context) - self.assertEqual(select._where.expression.context, context) - # Bound to the wrong thing we assume history + + class Context(object): + dialect = DefaultDialect() + keep_history = True + + context = Context() + dialect = context.dialect + select = select.bind(context) + + self.assertEqual(select.context, dialect) + self.assertEqual(select.table.context, dialect) + self.assertEqual(select._where.context, dialect) + self.assertEqual(select._where.expression.context, dialect) + # We take up its history setting self.assertEqual( str(select), 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' ) # Bound to history-free we use history free - class H(object): - keep_history = False - select = select.bind(H()) + context.keep_history = False + select = select.bind(context) self.assertEqual( str(select), @@ -171,6 +176,7 @@ class H(object): def test_bind_descriptor(self): class Context(object): keep_history = True + dialect = DefaultDialect() select = objects.select(objects.c.tid, objects.c.zoid).where( objects.c.tid > bindparam('tid') ) diff --git a/src/relstorage/tests/__init__.py b/src/relstorage/tests/__init__.py index 981d7a9d..edcff3c6 100644 --- a/src/relstorage/tests/__init__.py +++ b/src/relstorage/tests/__init__.py @@ -11,6 +11,7 @@ from relstorage._compat import ABC from relstorage.options import Options +from relstorage.adapters._sql import DefaultDialect try: from unittest import mock @@ -279,13 +280,6 @@ class MockPackUndo(object): class MockOIDAllocator(object): pass -class MockAdapter(object): - - def __init__(self): - self.connmanager = MockConnectionManager() - self.packundo = MockPackUndo() - self.oidallocator = MockOIDAllocator() - class MockQuery(object): def __init__(self, raw): @@ -297,3 +291,19 @@ def execute(self, cursor, params=None): class MockPoller(object): poll_query = MockQuery('SELECT MAX(tid) FROM object_state') + + def __init__(self, driver=None): + self.driver = driver or MockDriver() + +class MockDriver(object): + + dialect = DefaultDialect() + +class MockAdapter(object): + + def __init__(self): + self.driver = MockDriver() + self.connmanager = MockConnectionManager() + self.packundo = MockPackUndo() + self.oidallocator = MockOIDAllocator() + self.poller = MockPoller(self.driver) From 923689d379b49cec405bb38182d5e67995b26337 Mon Sep 17 00:00:00 2001 
From: Jason Madden Date: Fri, 19 Jul 2019 07:59:38 -0500 Subject: [PATCH 3/8] Add parameter types for INSERT VALUES prepared statements. --- src/relstorage/adapters/_sql.py | 80 ++++++++++++++++++- .../adapters/postgresql/txncontrol.py | 4 - src/relstorage/adapters/schema.py | 31 ++++--- src/relstorage/adapters/tests/test__sql.py | 34 ++++++-- src/relstorage/tests/__init__.py | 7 +- 5 files changed, 129 insertions(+), 27 deletions(-) diff --git a/src/relstorage/adapters/_sql.py b/src/relstorage/adapters/_sql.py index 45be641d..43ad584f 100644 --- a/src/relstorage/adapters/_sql.py +++ b/src/relstorage/adapters/_sql.py @@ -42,8 +42,11 @@ from operator import attrgetter from weakref import WeakKeyDictionary +from zope.interface import implementer + from relstorage._compat import NStringIO from relstorage._util import CachedIn +from .interfaces import IDBDialect def copy(obj): new = stdlib_copy(obj) @@ -52,13 +55,51 @@ def copy(obj): delattr(new, k) return new +class Type(object): + """ + A database type. + """ + +class Integer64(Type): + """ + A 64-bit integer. + """ + +class OID(Integer64): + """ + Type of an OID. + """ + +class TID(Integer64): + """ + Type of a TID. + """ + +class BinaryString(Type): + """ + Arbitrary sized binary string. + """ + +class State(Type): + """ + Used for storing object state. + """ + +class Boolean(Type): + """ + A two-value column. + """ + class Column(object): """ Defines a column in a table. """ - def __init__(self, name): + def __init__(self, name, type_=None, primary_key=False, nullable=True): self.name = name + self.type_ = type_ + self.primary_key = primary_key + self.nullable = False if primary_key else nullable def __str__(self): return self.name @@ -188,10 +229,19 @@ def bindparam(self, key): def orderedbindparam(self): return orderedbindparam() +@implementer(IDBDialect) class DefaultDialect(object): keep_history = True + datatype_map = { + OID: 'BIGINT', + TID: 'BIGINT', + BinaryString: 'BYTEA', + State: 'BYTEA', + Boolean: 'BOOLEAN', + } + def bind(self, context): # The context will reference us most likely # (compiled statement in instance dictionary) @@ -207,6 +257,14 @@ def compiler_class(self): def compiler(self, root): return self.compiler_class()(root) + def datatypes_for_columns(self, column_list): + columns = list(column_list) + datatypes = [] + for column in columns: + datatype = self.datatype_map[column.type_] + datatypes.append(datatype) + return datatypes + def __eq__(self, other): if isinstance(other, DefaultDialect): return other.keep_history == self.keep_history @@ -439,13 +497,27 @@ def _prepared_param(self, number): def _quote_query_for_prepare(self, query): return query + def _find_datatypes_for_prepared_query(self): + # Deduce the datatypes based on the types of the columns + # we're sending as params. + if isinstance(self.root, Insert) and self.root.values and self.root.column_list: + # If we're sending in a list of values, those have to + # exactly match the columns, so we can easily get a list + # of datatypes. + # + # TODO: We should be able to do this for an `INSERT (col) SELECT $1` too, + # by matching the parameter to the column name. + # TODO: Should probably delegate this to the node. + column_list = self.root.column_list + datatypes = self.root.context.datatypes_for_columns(column_list) + return datatypes + return () + def prepare(self): # This is correct for PostgreSQL. This needs moved to a dialect specific # spot. - # TODO: Deduce the datatypes based on the types of the columns - # we're sending as params. 
- datatypes = {} + datatypes = self._find_datatypes_for_prepared_query() query = self.buf.getvalue() name = self._next_prepared_stmt_name() diff --git a/src/relstorage/adapters/postgresql/txncontrol.py b/src/relstorage/adapters/postgresql/txncontrol.py index bef46563..93211060 100644 --- a/src/relstorage/adapters/postgresql/txncontrol.py +++ b/src/relstorage/adapters/postgresql/txncontrol.py @@ -22,8 +22,4 @@ class _PostgreSQLTransactionControl(GenericTransactionControl): class PostgreSQLTransactionControl(_PostgreSQLTransactionControl): - # TODO: Get the data types applied at compile time - # to add_transaction_query - # (tid, packed, username, description, extension) - # ('BIGINT', 'BOOLEAN', 'BYTEA', 'BYTEA', 'BYTEA') pass diff --git a/src/relstorage/adapters/schema.py b/src/relstorage/adapters/schema.py index 17819617..16cf36ec 100644 --- a/src/relstorage/adapters/schema.py +++ b/src/relstorage/adapters/schema.py @@ -29,6 +29,11 @@ from ._sql import TemporaryTable from ._sql import Column from ._sql import HistoryVariantTable +from ._sql import OID +from ._sql import TID +from ._sql import State +from ._sql import Boolean +from ._sql import BinaryString log = logging.getLogger("relstorage") @@ -40,15 +45,15 @@ class Schema(object): current_object = Table( 'current_object', - Column('zoid'), - Column('tid') + Column('zoid', OID), + Column('tid', TID) ) object_state = Table( 'object_state', - Column('zoid'), - Column('tid'), - Column('state'), + Column('zoid', OID), + Column('tid', TID), + Column('state', State), Column('state_size'), ) @@ -66,19 +71,19 @@ class Schema(object): temp_store = TemporaryTable( 'temp_store', - Column('zoid'), - Column('prev_tid'), + Column('zoid', OID), + Column('prev_tid', TID), Column('md5'), - Column('state') + Column('state', State) ) transaction = Table( 'transaction', - Column('tid'), - Column('packed'), - Column('username'), - Column('description'), - Column('extension'), + Column('tid', TID), + Column('packed', Boolean), + Column('username', BinaryString), + Column('description', BinaryString), + Column('extension', BinaryString), ) commit_row_lock = Table( diff --git a/src/relstorage/adapters/tests/test__sql.py b/src/relstorage/adapters/tests/test__sql.py index 8f728197..27557431 100644 --- a/src/relstorage/adapters/tests/test__sql.py +++ b/src/relstorage/adapters/tests/test__sql.py @@ -28,18 +28,21 @@ from .._sql import Column from .._sql import bindparam from .._sql import DefaultDialect +from .._sql import OID +from .._sql import TID +from .._sql import State current_object = Table( 'current_object', - Column('zoid'), - Column('tid') + Column('zoid', OID), + Column('tid', TID) ) object_state = Table( 'object_state', - Column('zoid'), - Column('tid'), - Column('state'), + Column('zoid', OID), + Column('tid', TID), + Column('state', State), ) hp_object_and_state = current_object.natural_join(object_state) @@ -193,3 +196,24 @@ class Context(object): str(context.select), 'SELECT tid, zoid FROM object_state WHERE (tid > %(tid)s)' ) + + def test_prepared_insert_values(self): + stmt = current_object.insert( + current_object.c.zoid + ) + + self.assertEqual( + str(stmt), + 'INSERT INTO current_object(zoid) VALUES (%s)' + ) + + stmt = stmt.prepared() + self.assertTrue( + str(stmt).startswith('EXECUTE rs_prep_stmt') + ) + + stmt = stmt.compiled() + self.assertRegex( + stmt._prepare_stmt, + r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*" + ) diff --git a/src/relstorage/tests/__init__.py b/src/relstorage/tests/__init__.py index edcff3c6..b0fffed7 100644 --- 
a/src/relstorage/tests/__init__.py +++ b/src/relstorage/tests/__init__.py @@ -30,12 +30,17 @@ class TestCase(unittest.TestCase): cleanups. """ # Avoid deprecation warnings; 2.7 doesn't have - # assertRaisesRegex + # assertRaisesRegex or assertRegex assertRaisesRegex = getattr( unittest.TestCase, 'assertRaisesRegex', None ) or getattr(unittest.TestCase, 'assertRaisesRegexp') + assertRegex = getattr( + unittest.TestCase, + 'assertRegex', + None + ) or getattr(unittest.TestCase, 'assertRegexpMatches') def setUp(self): super(TestCase, self).setUp() From 7ea3f22a69e43a5f6cade9b799a444e749aa88b2 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Fri, 19 Jul 2019 09:33:55 -0500 Subject: [PATCH 4/8] Handle types for 'INSERT SELECT %s' queries. --- src/relstorage/adapters/_sql.py | 45 +++++++++++++++------- src/relstorage/adapters/tests/test__sql.py | 27 ++++++++++++- 2 files changed, 58 insertions(+), 14 deletions(-) diff --git a/src/relstorage/adapters/_sql.py b/src/relstorage/adapters/_sql.py index 43ad584f..cbadfe95 100644 --- a/src/relstorage/adapters/_sql.py +++ b/src/relstorage/adapters/_sql.py @@ -45,6 +45,7 @@ from zope.interface import implementer from relstorage._compat import NStringIO +from relstorage._compat import intern from relstorage._util import CachedIn from .interfaces import IDBDialect @@ -168,6 +169,11 @@ def _col_list(self): def __compile_visit__(self, compiler): compiler.visit_csv(self._columns) + def has_bind_param(self): + return any( + isinstance(c, (_BindParam, _OrderedBindParam)) + for c in self._columns + ) _ColumnList = _Columns @@ -344,11 +350,11 @@ def __init__(self, lhs, rhs): def __compile_visit__(self, compiler): compiler.visit(self.lhs) - compiler.emit(' JOIN ') + compiler.emit_keyword('JOIN') compiler.visit(self.rhs) # careful with USING clause in a join: Oracle doesn't allow such # columns to have a prefix. - compiler.emit_keyword(' USING') + compiler.emit_keyword('USING') compiler.visit_grouped(self._join_columns) @@ -500,16 +506,28 @@ def _quote_query_for_prepare(self, query): def _find_datatypes_for_prepared_query(self): # Deduce the datatypes based on the types of the columns # we're sending as params. - if isinstance(self.root, Insert) and self.root.values and self.root.column_list: - # If we're sending in a list of values, those have to - # exactly match the columns, so we can easily get a list - # of datatypes. - # - # TODO: We should be able to do this for an `INSERT (col) SELECT $1` too, - # by matching the parameter to the column name. + if isinstance(self.root, Insert): + root = self.root + dialect = root.context # TODO: Should probably delegate this to the node. - column_list = self.root.column_list - datatypes = self.root.context.datatypes_for_columns(column_list) + if root.values and root.column_list: + # If we're sending in a list of values, those have to + # exactly match the columns, so we can easily get a list + # of datatypes. + column_list = root.column_list + datatypes = dialect.datatypes_for_columns(column_list) + elif root.select and root.select.column_list.has_bind_param(): + targets = root.column_list + sources = root.select.column_list + # TODO: This doesn't support bind params anywhere except the + # select list! 
+ columns_with_params = [ + target + for target, source in zip(targets, sources) + if isinstance(source, _OrderedBindParam) + ] + assert len(self.placeholders) == len(columns_with_params) + datatypes = dialect.datatypes_for_columns(columns_with_params) return datatypes return () @@ -550,6 +568,7 @@ def prepare(self): conjunction=self._PREPARED_CONJUNCTION, ) + if placeholder_to_number: execute = 'EXECUTE {name}({params})'.format( name=name, @@ -574,10 +593,10 @@ def convert(d): params[ix - 1] = d[placeholder_name] return params - return stmt, execute, convert + return intern(stmt), intern(execute), convert def finalize(self): - return self.buf.getvalue().strip(), {v: k for k, v in self.placeholders.items()} + return intern(self.buf.getvalue().strip()), {v: k for k, v in self.placeholders.items()} def visit(self, node): node.__compile_visit__(self) diff --git a/src/relstorage/adapters/tests/test__sql.py b/src/relstorage/adapters/tests/test__sql.py index 27557431..1cb14e4a 100644 --- a/src/relstorage/adapters/tests/test__sql.py +++ b/src/relstorage/adapters/tests/test__sql.py @@ -49,7 +49,7 @@ objects = HistoryVariantTable( current_object, - object_state, + object_state, ) object_and_state = HistoryVariantTable( @@ -217,3 +217,28 @@ def test_prepared_insert_values(self): stmt._prepare_stmt, r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*" ) + + def test_prepared_insert_select_with_param(self): + stmt = current_object.insert().from_select( + (current_object.c.zoid, + current_object.c.tid), + object_state.select( + object_state.c.zoid, + object_state.orderedbindparam() + ) + ) + self.assertEqual( + str(stmt), + 'INSERT INTO current_object(zoid, tid) SELECT zoid, %s FROM object_state' + ) + + stmt = stmt.prepared() + self.assertTrue( + str(stmt).startswith('EXECUTE rs_prep_stmt') + ) + + stmt = stmt.compiled() + self.assertRegex( + stmt._prepare_stmt, + r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*" + ) From 26bdcf0ef9c4e2916b5cda5ba51fa1d6f025a1b1 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Fri, 19 Jul 2019 15:20:50 -0500 Subject: [PATCH 5/8] More structured statements. Begin an Oracle dialect to cover those differences, specifically boolean handling and named vs pyformat query args. 
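For orientation before the diff below: the same logical query has to render
quite differently for cx_Oracle than for the PostgreSQL/MySQL drivers. This is
an editorial sketch of the intended difference, not captured output; the
behavior follows the OracleCompiler added in this patch, which emits 'Y'/'N'
for boolean literals and ':'-prefixed bind parameters instead of pyformat:

    # Default compiler: pyformat parameters, TRUE/FALSE boolean literals.
    pyformat_sql = (
        "SELECT tid FROM transaction "
        "WHERE packed = FALSE AND tid > %(tid)s"
    )

    # Oracle compiler: named ':' parameters, 'Y'/'N' boolean literals; boolean
    # columns in a select list come back via CASE WHEN packed = 'Y' THEN 1 ELSE 0 END.
    oracle_sql = (
        "SELECT tid FROM transaction "
        "WHERE packed = 'N' AND tid > :tid"
    )
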
--- src/relstorage/_compat.py | 2 + src/relstorage/adapters/_sql.py | 186 +++++++++++++++--- src/relstorage/adapters/dbiter.py | 167 +++++++++------- src/relstorage/adapters/mysql/adapter.py | 2 - src/relstorage/adapters/oracle/adapter.py | 2 - src/relstorage/adapters/oracle/dialect.py | 49 +++++ src/relstorage/adapters/oracle/drivers.py | 2 + src/relstorage/adapters/oracle/mover.py | 26 +-- .../adapters/oracle/scriptrunner.py | 34 ++-- .../adapters/oracle/tests/__init__.py | 1 + .../adapters/oracle/tests/test_dialect.py | 182 +++++++++++++++++ src/relstorage/adapters/postgresql/adapter.py | 2 - src/relstorage/adapters/tests/test__sql.py | 118 ++++++++++- src/relstorage/cache/interfaces.py | 8 +- src/relstorage/tests/__init__.py | 4 + 15 files changed, 633 insertions(+), 152 deletions(-) create mode 100644 src/relstorage/adapters/oracle/dialect.py create mode 100644 src/relstorage/adapters/oracle/tests/__init__.py create mode 100644 src/relstorage/adapters/oracle/tests/test_dialect.py diff --git a/src/relstorage/_compat.py b/src/relstorage/_compat.py index 338f2511..ee4dd1d9 100644 --- a/src/relstorage/_compat.py +++ b/src/relstorage/_compat.py @@ -55,6 +55,8 @@ def list_values(d): OID_OBJECT_MAP_TYPE = dict OID_SET_TYPE = set +MAX_TID = BTrees.family64.maxint + def iteroiditems(d): # Could be either a BTree, which always has 'iteritems', # or a plain dict, which may or may not have iteritems. diff --git a/src/relstorage/adapters/_sql.py b/src/relstorage/adapters/_sql.py index cbadfe95..9c39e18a 100644 --- a/src/relstorage/adapters/_sql.py +++ b/src/relstorage/adapters/_sql.py @@ -34,6 +34,8 @@ we don't want to take a dependency that can conflict with applications using RelStorage. """ +# pylint:disable=too-many-lines + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -61,6 +63,9 @@ class Type(object): A database type. """ +class _Unknown(Type): + "Unspecified." + class Integer64(Type): """ A 64-bit integer. @@ -91,17 +96,26 @@ class Boolean(Type): A two-value column. """ -class Column(object): +class _Resolvable(object): + + def resolve_against(self, table): + # pylint:disable=unused-argument + return self + +class Column(_Resolvable): """ Defines a column in a table. """ - def __init__(self, name, type_=None, primary_key=False, nullable=True): + def __init__(self, name, type_=_Unknown, primary_key=False, nullable=True): self.name = name - self.type_ = type_ + self.type_ = type_ if isinstance(type_, Type) else type_() self.primary_key = primary_key self.nullable = False if primary_key else nullable + def is_type(self, type_): + return isinstance(self.type_, type_) + def __str__(self): return self.name @@ -121,15 +135,22 @@ def __le__(self, other): return _LessEqualExpression(self, other) def __compile_visit__(self, compiler): - compiler.emit_identifier(self.name) + compiler.visit_column(self) -class _TextNode(object): +class _LiteralNode(_Resolvable): def __init__(self, raw): self.raw = raw + self.name = 'anon_%x' % (id(self),) def __compile_visit__(self, compiler): - compiler.emit(self.raw) + compiler.emit(str(self.raw)) + + def resolve_against(self, table): + return self + +class _TextNode(_LiteralNode): + pass class _Columns(object): """ @@ -139,11 +160,7 @@ class _Columns(object): def __init__(self, columns): cs = [] for c in columns: - try: - setattr(self, c.name, c) - except AttributeError: - # Must be a string. 
- c = _TextNode(c) + setattr(self, c.name, c) cs.append(c) self._columns = tuple(cs) @@ -175,8 +192,16 @@ def has_bind_param(self): for c in self._columns ) + def as_select_list(self): + return _SelectColumns(self._columns) + _ColumnList = _Columns +class _SelectColumns(_Columns): + + def __compile_visit__(self, compiler): + compiler.visit_select_list_csv(self._columns) + class Table(object): """ A table relation. @@ -258,7 +283,7 @@ def bind(self, context): return new def compiler_class(self): - return _Compiler + return Compiler def compiler(self, root): return self.compiler_class()(root) @@ -267,7 +292,7 @@ def datatypes_for_columns(self, column_list): columns = list(column_list) datatypes = [] for column in columns: - datatype = self.datatype_map[column.type_] + datatype = self.datatype_map[type(column.type_)] datatypes.append(datatype) return datatypes @@ -461,10 +486,15 @@ def execute(self, cursor, params=None): __traceback_info__ = stmt, params if params: cursor.execute(stmt, params) + elif self.params: + # XXX: This isn't really good. + # If there are both literals in the SQL and params, + # we don't handle that. + cursor.execute(stmt, self.params) else: cursor.execute(stmt) -class _Compiler(object): +class Compiler(object): def __init__(self, root): self.buf = NStringIO() @@ -473,7 +503,8 @@ def __init__(self, root): def __repr__(self): - return "" % ( + return "<%s %s %r>" % ( + type(self).__name__, self.buf.getvalue(), self.placeholders ) @@ -627,12 +658,27 @@ def visit_column_list(self, column_list): clist = column_list.c if hasattr(column_list, 'c') else column_list self.visit(clist) + def visit_select_list(self, column_list): + clist = column_list.c if hasattr(column_list, 'c') else column_list + self.visit(clist.as_select_list()) + def visit_csv(self, nodes): self.visit(nodes[0]) for node in nodes[1:]: self.emit(', ') self.visit(node) + visit_select_expression = visit + + def visit_select_list_csv(self, nodes): + self.visit_select_expression(nodes[0]) + for node in nodes[1:]: + self.emit(', ') + self.visit_select_expression(node) + + def visit_column(self, column_node): + self.emit_identifier(column_node.name) + def visit_from(self, from_): self.emit_keyword('FROM') self.visit(from_) @@ -645,8 +691,8 @@ def visit_grouped(self, clause): def visit_op(self, op): self.emit(' ' + op + ' ') - def _next_placeholder_name(self): - return 'param_%d' % (len(self.placeholders),) + def _next_placeholder_name(self, prefix='param'): + return '%s_%d' % (prefix, len(self.placeholders),) def _placeholder(self, key): # Write things in `pyformat` style by default, assuming a @@ -658,15 +704,21 @@ def _placeholder(self, key): def _placeholder_for_literal_param_value(self, value): placeholder = self.placeholders.get(value) if not placeholder: - placeholder_name = self._next_placeholder_name() + placeholder_name = self._next_placeholder_name(prefix='literal') placeholder = self._placeholder(placeholder_name) self.placeholders[value] = placeholder_name return placeholder - def visit_literal_param(self, value): + def visit_literal_expression(self, value): placeholder = self._placeholder_for_literal_param_value(value) self.emit(placeholder) + def visit_boolean_literal_expression(self, value): + # In the oracle dialect, this needs to be + # either "'Y'" or "'N'" + assert isinstance(value, bool) + self.emit(str(value).upper()) + def visit_bind_param(self, bind_param): self.placeholders[bind_param] = bind_param.key self.emit(self._placeholder(bind_param.key)) @@ -675,10 +727,14 @@ def 
visit_ordered_bind_param(self, bind_param): self.placeholders[bind_param] = '%s' self.emit('%s') -class _Expression(_Bindable): +_Compiler = Compiler # BWC. Remove + +class _Expression(_Bindable, + _Resolvable): """ A SQL expression. """ + class _BindParam(_Expression): def __init__(self, key): @@ -697,7 +753,12 @@ def __init__(self, value): self.value = value def __compile_visit__(self, compiler): - compiler.visit_literal_param(self.value) + compiler.visit_literal_expression(self.value) + +class _BooleanLiteralExpression(_LiteralExpression): + + def __compile_visit__(self, compiler): + compiler.visit_boolean_literal_expression(self.value) class _OrderedBindParam(_Expression): @@ -709,6 +770,26 @@ def __compile_visit__(self, compiler): def orderedbindparam(): return _OrderedBindParam() +def _as_node(c): + if isinstance(c, int): + return _LiteralNode(c) + if isinstance(c, str): + return _TextNode(c) + return c + +def _as_expression(stmt): + if hasattr(stmt, '__compile_visit__'): + return stmt + + if isinstance(stmt, bool): + # The values True and False are handled + # specially because their representation varies + # among databases (Oracle) + stmt = _BooleanLiteralExpression(stmt) + else: + stmt = _LiteralExpression(stmt) + return stmt + class _BinaryExpression(_Expression): """ Expresses a comparison. @@ -717,9 +798,9 @@ class _BinaryExpression(_Expression): def __init__(self, op, lhs, rhs): self.op = op self.lhs = lhs # type: Column - # rhs is either a literal or a column - if not hasattr(rhs, '__compile_visit__'): - rhs = _LiteralExpression(rhs) + # rhs is either a literal or a column; + # certain literals are handled specially. + rhs = _as_expression(rhs) self.rhs = rhs def __str__(self): @@ -734,6 +815,13 @@ def __compile_visit__(self, compiler): compiler.visit_op(self.op) compiler.visit(self.rhs) + def resolve_against(self, table): + lhs = self.lhs.resolve_against(table) + rhs = self.rhs.resolve_against(table) + new = copy(self) + new.rhs = rhs + new.lhs = lhs + return new class _EmptyExpression(_Expression): """ @@ -795,6 +883,10 @@ def __init__(self, lhs, rhs): def __compile_visit__(self, compiler): compiler.visit_grouped(_BinaryExpression('AND', self.lhs, self.rhs)) + def resolve_against(self, table): + return type(self)(self.lhs.resolve_against(table), + self.rhs.resolve_against(table)) + class _WhereClause(_Clause): def __init__(self, expression): @@ -875,6 +967,36 @@ def prepared(self): s.prepare = True return s +class _DeferredColumn(Column): + + def resolve_against(self, table): + return getattr(table.c, self.name) + +class _DeferredColumns(object): + + def __getattr__(self, name): + return _DeferredColumn(name) + +class _It(object): + """ + A proxy that select can resolve to tables in the current table. + """ + + c = _DeferredColumns() + + def bindparam(self, name): + return bindparam(name) + +it = _It() + +def _resolved_against(columns, table): + resolved = [ + _as_node(c).resolve_against(table) + for c + in columns + ] + return resolved + class Select(_Query): """ A Select query. @@ -885,6 +1007,7 @@ class Select(_Query): appropriate SQL syntax and compile themselves into a string. 
""" + _distinct = _EmptyExpression() _where = _EmptyExpression() _order_by = _EmptyExpression() _limit = None @@ -894,21 +1017,24 @@ class Select(_Query): def __init__(self, table, *columns): self.table = table if columns: - self.column_list = _ColumnList(columns) + self.column_list = _ColumnList(_resolved_against(columns, table)) else: self.column_list = table def where(self, expression): + expression = expression.resolve_against(self.table) s = copy(self) s._where = _where(expression) return s def and_(self, expression): + expression = expression.resolve_against(self.table) s = copy(self) s._where = self._where.and_(expression) return s def order_by(self, expression, dir=None): + expression = expression.resolve_against(self.table) s = copy(self) s._order_by = _OrderBy(expression, dir) return s @@ -928,9 +1054,15 @@ def nowait(self): s._nowait = 'NOWAIT' return s + def distinct(self): + s = copy(self) + s._distinct = _TextNode('DISTINCT') + return s + def __compile_visit__(self, compiler): compiler.emit_keyword('SELECT') - compiler.visit_column_list(self.column_list) + compiler.visit(self._distinct) + compiler.visit_select_list(self.column_list) compiler.visit_from(self.table) compiler.visit_clause(self._where) compiler.visit_clause(self._order_by) @@ -969,7 +1101,7 @@ class Insert(_Query): def __init__(self, table, *columns): self.table = table if columns: - self.column_list = _Columns(columns) + self.column_list = _Columns(_resolved_against(columns, table)) # TODO: Probably want a different type, like a ValuesList self.values = _Columns([orderedbindparam() for _ in columns]) diff --git a/src/relstorage/adapters/dbiter.py b/src/relstorage/adapters/dbiter.py index 9bb882f2..5f341db2 100644 --- a/src/relstorage/adapters/dbiter.py +++ b/src/relstorage/adapters/dbiter.py @@ -16,29 +16,38 @@ from zope.interface import implementer -from .interfaces import IDatabaseIterator +from relstorage._compat import MAX_TID +from .interfaces import IDatabaseIterator +from .schema import Schema +from ._sql import it class DatabaseIterator(object): - """Abstract base class for database iteration. + """ + Abstract base class for database iteration. """ - def __init__(self, database_driver, runner): - self.runner = runner + def __init__(self, database_driver): + """ + :param database_driver: Necessary to bind queries correctly. + """ self.driver = database_driver + _iter_objects_query = Schema.object_state.select( + it.c.zoid, + it.c.state + ).where( + it.c.tid == it.bindparam('tid') + ).order_by( + it.c.zoid + ) + def iter_objects(self, cursor, tid): """Iterate over object states in a transaction. - Yields (oid, prev_tid, state) for each object state. + Yields ``(oid, state)`` for each object in the transaction. """ - stmt = """ - SELECT zoid, state - FROM object_state - WHERE tid = %(tid)s - ORDER BY zoid - """ - self.runner.run_script_stmt(cursor, stmt, {'tid': tid}) + self._iter_objects_query.execute(cursor, {'tid': tid}) for oid, state in cursor: state = self.driver.binary_column_as_state_type(state) yield oid, state @@ -47,6 +56,8 @@ def iter_objects(self, cursor, tid): @implementer(IDatabaseIterator) class HistoryPreservingDatabaseIterator(DatabaseIterator): + keep_history = True + def _transaction_iterator(self, cursor): """ Iterate over a list of transactions returned from the database. 
@@ -71,22 +82,38 @@ def _transaction_iterator(self, cursor): yield (tid, username, description, ext) + tuple(row[4:]) + _iter_transactions_query = Schema.transaction.select( + it.c.tid, it.c.username, it.c.description, it.c.extension + ).where( + it.c.packed == False # pylint:disable=singleton-comparison + ).and_( + it.c.tid != 0 + ).order_by( + it.c.tid, 'DESC' + ) + def iter_transactions(self, cursor): """Iterate over the transaction log, newest first. Skips packed transactions. Yields (tid, username, description, extension) for each transaction. """ - stmt = """ - SELECT tid, username, description, extension - FROM transaction - WHERE packed = %(FALSE)s - AND tid != 0 - ORDER BY tid DESC - """ - self.runner.run_script_stmt(cursor, stmt) + self._iter_transactions_query.execute(cursor) return self._transaction_iterator(cursor) + _iter_transactions_range_query = Schema.transaction.select( + it.c.tid, + it.c.username, + it.c.description, + it.c.extension, + it.c.packed, + ).where( + it.c.tid >= it.bindparam('min_tid') + ).and_( + it.c.tid <= it.bindparam('max_tid') + ).order_by( + it.c.tid + ) def iter_transactions_range(self, cursor, start=None, stop=None): """Iterate over the transactions in the given range, oldest first. @@ -95,21 +122,31 @@ def iter_transactions_range(self, cursor, start=None, stop=None): Yields (tid, username, description, extension, packed) for each transaction. """ - stmt = """ - SELECT tid, username, description, extension, - CASE WHEN packed = %(TRUE)s THEN 1 ELSE 0 END - FROM transaction - WHERE tid >= 0 - """ - if start is not None: - stmt += " AND tid >= %(min_tid)s" - if stop is not None: - stmt += " AND tid <= %(max_tid)s" - stmt += " ORDER BY tid" - self.runner.run_script_stmt(cursor, stmt, - {'min_tid': start, 'max_tid': stop}) + params = { + 'min_tid': start if start else 0, + 'max_tid': stop if stop else MAX_TID + } + self._iter_transactions_range_query.execute(cursor, params) return self._transaction_iterator(cursor) + _object_exists_query = Schema.current_object.select( + 1 + ).where( + it.c.zoid == it.bindparam('oid') + ) + + _object_history_query = Schema.transaction.natural_join( + Schema.object_state + ).select( + it.c.tid, it.c.username, it.c.description, it.c.extension, + Schema.object_state.c.state_size + ).where( + it.c.zoid == it.bindparam("oid") + ).and_( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + it.c.tid, "DESC" + ) def iter_object_history(self, cursor, oid): """Iterate over an object's history. @@ -118,59 +155,60 @@ def iter_object_history(self, cursor, oid): Yields (tid, username, description, extension, pickle_size) for each modification. """ - stmt = """ - SELECT 1 FROM current_object WHERE zoid = %(oid)s - """ - self.runner.run_script_stmt(cursor, stmt, {'oid': oid}) + params = {'oid': oid} + self._object_exists_query.execute(cursor, params) if not cursor.fetchall(): raise KeyError(oid) - stmt = """ - SELECT tid, username, description, extension, state_size - FROM transaction - JOIN object_state USING (tid) - WHERE zoid = %(oid)s - AND packed = %(FALSE)s - ORDER BY tid DESC - """ - self.runner.run_script_stmt(cursor, stmt, {'oid': oid}) + self._object_history_query.execute(cursor, params) return self._transaction_iterator(cursor) @implementer(IDatabaseIterator) class HistoryFreeDatabaseIterator(DatabaseIterator): + keep_history = False + def iter_transactions(self, cursor): """Iterate over the transaction log, newest first. Skips packed transactions. 
- Yields (tid, username, description, extension) for each transaction. + Yields ``(tid, username, description, extension)`` for each transaction. This always returns an empty iterable. """ # pylint:disable=unused-argument return [] + _iter_transactions_range_query = Schema.object_state.select( + it.c.tid, + ).where( + it.c.tid >= it.bindparam('min_tid') + ).and_( + it.c.tid <= it.bindparam('max_tid') + ).order_by( + it.c.tid + ).distinct() + def iter_transactions_range(self, cursor, start=None, stop=None): """Iterate over the transactions in the given range, oldest first. Includes packed transactions. - Yields (tid, username, description, extension, packed) + Yields ``(tid, username, description, extension, packed)`` for each transaction. """ - stmt = """ - SELECT DISTINCT tid - FROM object_state - WHERE tid > 0 - """ - if start is not None: - stmt += " AND tid >= %(min_tid)s" - if stop is not None: - stmt += " AND tid <= %(max_tid)s" - stmt += " ORDER BY tid" - self.runner.run_script_stmt(cursor, stmt, - {'min_tid': start, 'max_tid': stop}) - return ((tid, '', '', '', True) for (tid,) in cursor) + params = { + 'min_tid': start if start else 0, + 'max_tid': stop if stop else MAX_TID + } + self._iter_transactions_range_query.execute(cursor, params) + return ((tid, b'', b'', b'', True) for (tid,) in cursor) + + _iter_object_history_query = Schema.object_state.select( + it.c.tid, it.c.state_size + ).where( + it.c.zoid == it.bindparam('oid') + ) def iter_object_history(self, cursor, oid): """ @@ -180,12 +218,7 @@ def iter_object_history(self, cursor, oid): Yields a single row, ``(tid, username, description, extension, pickle_size)`` """ - stmt = """ - SELECT tid, state_size - FROM object_state - WHERE zoid = %(oid)s - """ - self.runner.run_script_stmt(cursor, stmt, {'oid': oid}) + self._iter_object_history_query.execute(cursor, {'oid': oid}) rows = cursor.fetchall() if not rows: raise KeyError(oid) diff --git a/src/relstorage/adapters/mysql/adapter.py b/src/relstorage/adapters/mysql/adapter.py index a3e19e89..2f799893 100644 --- a/src/relstorage/adapters/mysql/adapter.py +++ b/src/relstorage/adapters/mysql/adapter.py @@ -147,7 +147,6 @@ def __init__(self, options=None, **params): ) self.dbiter = HistoryPreservingDatabaseIterator( driver, - runner=self.runner, ) else: self.packundo = MySQLHistoryFreePackUndo( @@ -159,7 +158,6 @@ def __init__(self, options=None, **params): ) self.dbiter = HistoryFreeDatabaseIterator( driver, - runner=self.runner, ) self.stats = MySQLStats( diff --git a/src/relstorage/adapters/oracle/adapter.py b/src/relstorage/adapters/oracle/adapter.py index a1d96b2d..3012e509 100644 --- a/src/relstorage/adapters/oracle/adapter.py +++ b/src/relstorage/adapters/oracle/adapter.py @@ -139,7 +139,6 @@ def __init__(self, user, password, dsn, commit_lock_id=0, ) self.dbiter = HistoryPreservingDatabaseIterator( driver, - runner=self.runner, ) else: self.packundo = OracleHistoryFreePackUndo( @@ -151,7 +150,6 @@ def __init__(self, user, password, dsn, commit_lock_id=0, ) self.dbiter = HistoryFreeDatabaseIterator( driver, - runner=self.runner, ) self.stats = OracleStats( diff --git a/src/relstorage/adapters/oracle/dialect.py b/src/relstorage/adapters/oracle/dialect.py new file mode 100644 index 00000000..3b4971cf --- /dev/null +++ b/src/relstorage/adapters/oracle/dialect.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +The Oracle dialect + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .._sql import 
DefaultDialect +from .._sql import Compiler +from .._sql import Boolean +from .._sql import Column + +class OracleCompiler(Compiler): + + def visit_boolean_literal_expression(self, value): + sql = "'Y'" if value else "'N'" + self.emit(sql) + + def can_prepare(self): + # We haven't investigated preparing statements manually + # with cx_Oracle. There's a chance that `cx_Oracle.Connection.stmtcachesize` + # will accomplish all we need. + return False + + def _placeholder(self, key): + # XXX: What's that block in the parent about key == '%s'? + # Only during prepare() I think. + return ':' + key + + def visit_ordered_bind_param(self, bind_param): + ph = self.placeholders[bind_param] = ':%d' % (len(self.placeholders) + 1,) + self.emit(ph) + + def visit_select_expression(self, column_node): + if isinstance(column_node, Column) and column_node.is_type(Boolean): + # Fancy CASE statement to get 1 or 0 into Python + self.emit_keyword('CASE WHEN') + self.emit_identifier(column_node.name) + self.emit(" = 'Y' THEN 1 ELSE 0 END") + else: + super(OracleCompiler, self).visit_select_expression(column_node) + + +class OracleDialect(DefaultDialect): + + def compiler_class(self): + return OracleCompiler diff --git a/src/relstorage/adapters/oracle/drivers.py b/src/relstorage/adapters/oracle/drivers.py index 4e34d868..cf28c6b0 100644 --- a/src/relstorage/adapters/oracle/drivers.py +++ b/src/relstorage/adapters/oracle/drivers.py @@ -26,6 +26,7 @@ from .._abstract_drivers import AbstractModuleDriver from .._abstract_drivers import implement_db_driver_options from ..interfaces import IDBDriver +from .dialect import OracleDialect database_type = 'oracle' @@ -37,6 +38,7 @@ class cx_OracleDriver(AbstractModuleDriver): __name__ = 'cx_Oracle' MODULE_NAME = __name__ + dialect = OracleDialect() def __init__(self): super(cx_OracleDriver, self).__init__() diff --git a/src/relstorage/adapters/oracle/mover.py b/src/relstorage/adapters/oracle/mover.py index 39ea25ca..7fff81d8 100644 --- a/src/relstorage/adapters/oracle/mover.py +++ b/src/relstorage/adapters/oracle/mover.py @@ -24,13 +24,6 @@ from ..interfaces import IObjectMover from ..mover import AbstractObjectMover from ..mover import metricmethod_sampled -from .scriptrunner import format_to_named - - -def _to_oracle_ordered(query_tuple): - # Replace %s with :1, :2, etc - assert len(query_tuple) == 2 - return format_to_named(query_tuple[0]), format_to_named(query_tuple[1]) @implementer(IObjectMover) @@ -39,16 +32,6 @@ class OracleObjectMover(AbstractObjectMover): # This is assigned to by the adapter. inputsizes = None - _move_from_temp_hp_insert_query = format_to_named( - AbstractObjectMover._move_from_temp_hp_insert_query) - _move_from_temp_hf_insert_query = format_to_named( - AbstractObjectMover._move_from_temp_hf_insert_query) - _move_from_temp_copy_blob_query = format_to_named( - AbstractObjectMover._move_from_temp_copy_blob_query) - - # XXX: This is definitely broken! - # _load_current_queries = _to_oracle_ordered(AbstractObjectMover._load_current_queries) - @metricmethod_sampled def load_current(self, cursor, oid): stmt = self._load_current_query @@ -56,8 +39,6 @@ def load_current(self, cursor, oid): cursor, stmt, (oid,), default=(None, None)) - _load_revision_query = format_to_named(AbstractObjectMover._load_revision_query) - @metricmethod_sampled def load_revision(self, cursor, oid, tid): stmt = self._load_revision_query @@ -66,9 +47,6 @@ def load_revision(self, cursor, oid, tid): return state - # XXX: Def broken. 
- #_exists_queries = _to_oracle_ordered(AbstractObjectMover._exists_queries) - @metricmethod_sampled def exists(self, cursor, oid): stmt = self._exists_query @@ -248,9 +226,7 @@ def replace_temp(self, cursor, oid, prev_tid, data): - _update_current_insert_query = format_to_named(AbstractObjectMover._update_current_insert_query) - _update_current_update_query = format_to_named(AbstractObjectMover._update_current_update_query) - _update_current_update_query = _update_current_update_query.replace('ORDER BY zoid', '') + # XXX: For _update_current_update_query we used to remove 'ORDER BY zoid'. Still needed? @metricmethod_sampled def download_blob(self, cursor, oid, tid, filename): diff --git a/src/relstorage/adapters/oracle/scriptrunner.py b/src/relstorage/adapters/oracle/scriptrunner.py index ba0fe99a..ee0ee5b8 100644 --- a/src/relstorage/adapters/oracle/scriptrunner.py +++ b/src/relstorage/adapters/oracle/scriptrunner.py @@ -23,29 +23,29 @@ _stmt_cache = {} -def format_to_named(stmt): +def _format_to_named(stmt): """ Convert '%s' pyformat strings to :n numbered strings. Intended only for static strings. + + This is legacy. Replace strings that use this with SQL statements + constructed from the schema. """ - # XXX: This should be part of the Compiler, as handled - # by the driver. - return stmt - # import re - # from relstorage._compat import intern + import re + from relstorage._compat import intern - # try: - # return _stmt_cache[stmt] - # except KeyError: - # matches = [] + try: + return _stmt_cache[stmt] + except KeyError: + matches = [] - # def replace(_match): - # matches.append(None) - # return ':%d' % len(matches) - # new_stmt = intern(re.sub('%s', replace, stmt)) - # _stmt_cache[stmt] = new_stmt + def replace(_match): + matches.append(None) + return ':%d' % len(matches) + new_stmt = intern(re.sub('%s', replace, stmt)) + _stmt_cache[stmt] = new_stmt - # return new_stmt + return new_stmt class OracleScriptRunner(ScriptRunner): @@ -93,7 +93,7 @@ def run_many(self, cursor, stmt, items): stmt should use '%s' parameter format. """ - cursor.executemany(format_to_named(stmt), items) + cursor.executemany(_format_to_named(stmt), items) class TrackingMap(object): diff --git a/src/relstorage/adapters/oracle/tests/__init__.py b/src/relstorage/adapters/oracle/tests/__init__.py new file mode 100644 index 00000000..68336058 --- /dev/null +++ b/src/relstorage/adapters/oracle/tests/__init__.py @@ -0,0 +1 @@ +# Oracle test package. diff --git a/src/relstorage/adapters/oracle/tests/test_dialect.py b/src/relstorage/adapters/oracle/tests/test_dialect.py new file mode 100644 index 00000000..7004f9c0 --- /dev/null +++ b/src/relstorage/adapters/oracle/tests/test_dialect.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +""" +Tests for the Oracle dialect. 
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from relstorage.tests import TestCase +from relstorage.tests import MockCursor + +from ...schema import Schema +from ..._sql import it + +from ..dialect import OracleDialect + +class Context(object): + keep_history = True + dialect = OracleDialect() + +class Driver(object): + dialect = OracleDialect() + +class TestOracleDialect(TestCase): + + def test_boolean_parameter(self): + transaction = Schema.transaction + object_state = Schema.object_state + + stmt = transaction.natural_join( + object_state + ).select( + it.c.tid, it.c.username, it.c.description, it.c.extension, + object_state.c.state_size + ).where( + it.c.zoid == it.bindparam("oid") + ).and_( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + it.c.tid, "DESC" + ) + + # By default we get the generic syntax + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, state_size ' + 'FROM transaction ' + 'JOIN object_state ' + 'USING (tid) ' + "WHERE ((zoid = %(oid)s AND packed = FALSE)) " + 'ORDER BY tid DESC' + ) + + # But once we bind to the dialect we get the expected value + stmt = stmt.bind(Context) + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, state_size ' + 'FROM transaction ' + 'JOIN object_state ' + 'USING (tid) ' + "WHERE ((zoid = :oid AND packed = 'N')) " + 'ORDER BY tid DESC' + ) + + def test_boolean_result(self): + """ + SELECT tid, username, description, extension, + + FROM transaction + WHERE tid >= 0 + """ + + transaction = Schema.transaction + stmt = transaction.select( + transaction.c.tid, + transaction.c.username, + transaction.c.description, + transaction.c.extension, + transaction.c.packed + ).where( + transaction.c.tid >= 0 + ) + + stmt = stmt.bind(Context) + + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, ' + "CASE WHEN packed = 'Y' THEN 1 ELSE 0 END " + 'FROM transaction WHERE (tid >= :literal_0)' + ) + +class TestMoverQueries(TestCase): + + def _makeOne(self): + from ..mover import OracleObjectMover + from relstorage.tests import MockOptions + + return OracleObjectMover(Driver(), MockOptions()) + + def test_move_named_query(self): + + inst = self._makeOne() + + unbound = type(inst)._move_from_temp_hf_insert_query + + self.assertRegex( + str(unbound), + r'EXECUTE rs_prep_stmt_.*' + ) + + # Bound, the py-format style escapes are replaced with + # :name params; these are simple numbers. 
+ self.assertEqual( + str(inst._move_from_temp_hf_insert_query), + 'INSERT INTO object_state(zoid, tid, state_size, state) ' + 'SELECT zoid, :1, COALESCE(LENGTH(state), 0), state ' + 'FROM temp_store ORDER BY zoid' + ) + +class TestDatabaseIteratorQueries(TestCase): + + def _makeOne(self): + from relstorage.adapters.dbiter import HistoryPreservingDatabaseIterator + return HistoryPreservingDatabaseIterator(Driver()) + + def test_query(self): + inst = self._makeOne() + + unbound = type(inst)._iter_objects_query + + self.assertEqual( + str(unbound), + 'SELECT zoid, state FROM object_state WHERE (tid = %(tid)s) ORDER BY zoid' + ) + + # Bound we get pyformat %(name)s replaced with :name + self.assertEqual( + str(inst._iter_objects_query), + 'SELECT zoid, state FROM object_state WHERE (tid = :tid) ORDER BY zoid' + ) + + cursor = MockCursor() + list(inst.iter_objects(cursor, 1)) # Iterator, must flatten + + self.assertEqual(1, len(cursor.executed)) + stmt, params = cursor.executed[0] + + self.assertEqual( + stmt, + 'SELECT zoid, state FROM object_state WHERE (tid = :tid) ORDER BY zoid' + ) + + self.assertEqual( + params, + {'tid': 1} + ) + + def test_iter_transactions(self): + inst = self._makeOne() + cursor = MockCursor() + + inst.iter_transactions(cursor) + + self.assertEqual(1, len(cursor.executed)) + stmt, params = cursor.executed[0] + + self.assertEqual( + stmt, + 'SELECT tid, username, description, extension ' + 'FROM transaction ' + "WHERE ((packed = 'N' AND tid <> :literal_0)) " + 'ORDER BY tid DESC' + ) + + self.assertEqual( + params, + {'literal_0': 0} + ) diff --git a/src/relstorage/adapters/postgresql/adapter.py b/src/relstorage/adapters/postgresql/adapter.py index f869d9d3..d2e1edc1 100644 --- a/src/relstorage/adapters/postgresql/adapter.py +++ b/src/relstorage/adapters/postgresql/adapter.py @@ -128,7 +128,6 @@ def __init__(self, dsn='', options=None): ) self.dbiter = HistoryPreservingDatabaseIterator( driver, - runner=self.runner, ) else: self.packundo = HistoryFreePackUndo( @@ -142,7 +141,6 @@ def __init__(self, dsn='', options=None): self.packundo._lock_for_share = 'FOR KEY SHARE OF object_state' self.dbiter = HistoryFreeDatabaseIterator( driver, - runner=self.runner, ) self.stats = PostgreSQLStats( diff --git a/src/relstorage/adapters/tests/test__sql.py b/src/relstorage/adapters/tests/test__sql.py index 1cb14e4a..9dc29a2a 100644 --- a/src/relstorage/adapters/tests/test__sql.py +++ b/src/relstorage/adapters/tests/test__sql.py @@ -31,6 +31,8 @@ from .._sql import OID from .._sql import TID from .._sql import State +from .._sql import Boolean +from .._sql import BinaryString current_object = Table( 'current_object', @@ -43,13 +45,14 @@ Column('zoid', OID), Column('tid', TID), Column('state', State), + Column('state_size'), ) hp_object_and_state = current_object.natural_join(object_state) objects = HistoryVariantTable( current_object, - object_state, + object_state, ) object_and_state = HistoryVariantTable( @@ -57,6 +60,15 @@ object_state ) +transaction = Table( + 'transaction', + Column('tid', TID), + Column('packed', Boolean), + Column('username', BinaryString), + Column('description', BinaryString), + Column('extension', BinaryString), +) + class TestTableSelect(TestCase): def test_simple_eq_select(self): @@ -66,7 +78,15 @@ def test_simple_eq_select(self): self.assertEqual( str(stmt), - 'SELECT zoid, tid, state FROM object_state WHERE (zoid = tid)' + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid)' + ) + + def test_distinct(self): + table = object_state + 
stmt = table.select(table.c.zoid).where(table.c.tid == table.bindparam('tid')).distinct() + self.assertEqual( + str(stmt), + 'SELECT DISTINCT zoid FROM object_state WHERE (tid = %(tid)s)' ) def test_simple_eq_select_and(self): @@ -77,13 +97,14 @@ def test_simple_eq_select_and(self): self.assertEqual( str(stmt), - 'SELECT zoid, tid, state FROM object_state WHERE (zoid = tid)' + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid)' ) stmt = stmt.and_(table.c.zoid > 5) self.assertEqual( str(stmt), - 'SELECT zoid, tid, state FROM object_state WHERE ((zoid = tid AND zoid > %(param_0)s))' + 'SELECT zoid, tid, state, state_size ' + 'FROM object_state WHERE ((zoid = tid AND zoid > %(literal_0)s))' ) def test_simple_eq_select_literal(self): @@ -94,12 +115,12 @@ def test_simple_eq_select_literal(self): self.assertEqual( str(stmt), - 'SELECT zoid, tid, state FROM object_state WHERE (zoid = %(param_0)s)' + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = %(literal_0)s)' ) self.assertEqual( stmt.compiled().params, - {'param_0': 7}) + {'literal_0': 7}) def test_column_query_variant_table(self): stmt = objects.select(objects.c.tid, objects.c.zoid).where( @@ -242,3 +263,88 @@ def test_prepared_insert_select_with_param(self): stmt._prepare_stmt, r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*" ) + + def test_it(self): + from .._sql import it + stmt = object_state.select( + it.c.zoid, + it.c.state + ).where( + it.c.tid == it.bindparam('tid') + ).order_by( + it.c.zoid + ) + + self.assertEqual( + str(stmt), + 'SELECT zoid, state FROM object_state WHERE (tid = %(tid)s) ORDER BY zoid' + ) + + # Now something that won't resolve. + col_ref = it.c.dne + + # In the column list + with self.assertRaisesRegex(AttributeError, 'does not include dne'): + object_state.select(col_ref) + + stmt = object_state.select(it.c.zoid) + + # In the where clause + with self.assertRaisesRegex(AttributeError, 'does not include dne'): + stmt.where(col_ref == object_state.c.state) + + # In order by + with self.assertRaisesRegex(AttributeError, 'does not include dne'): + stmt.order_by(col_ref == object_state.c.state) + + def test_boolean_literal(self): + from .._sql import it + stmt = transaction.select( + transaction.c.tid + ).where( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + transaction.c.tid, 'DESC' + ) + + self.assertEqual( + str(stmt), + 'SELECT tid FROM transaction WHERE (packed = FALSE) ORDER BY tid DESC' + ) + + def test_literal_in_select(self): + stmt = current_object.select( + 1 + ).where( + current_object.c.zoid == current_object.bindparam('oid') + ) + + self.assertEqual( + str(stmt), + 'SELECT 1 FROM current_object WHERE (zoid = %(oid)s)' + ) + + def test_boolean_literal_it_joined_table(self): + from .._sql import it + stmt = transaction.natural_join( + object_state + ).select( + it.c.tid, it.c.username, it.c.description, it.c.extension, + object_state.c.state_size + ).where( + it.c.zoid == it.bindparam("oid") + ).and_( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + it.c.tid, "DESC" + ) + + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, state_size ' + 'FROM transaction ' + 'JOIN object_state ' + 'USING (tid) ' + 'WHERE ((zoid = %(oid)s AND packed = FALSE)) ' + 'ORDER BY tid DESC' + ) diff --git a/src/relstorage/cache/interfaces.py b/src/relstorage/cache/interfaces.py index 8cf3a44c..68b0f184 100644 --- a/src/relstorage/cache/interfaces.py +++ b/src/relstorage/cache/interfaces.py @@ -18,17 
+18,17 @@ from zope.interface import Attribute from zope.interface import Interface -import BTrees + from transaction.interfaces import TransientError from ZODB.POSException import StorageError +# Export +from relstorage._compat import MAX_TID # pylint:disable=unused-import + # pylint: disable=inherit-non-class,no-method-argument,no-self-argument # pylint:disable=unexpected-special-method-signature # pylint:disable=signature-differs -# An LLBTree uses much less memory than a dict, and is still plenty fast on CPython; -# it's just as big and slower on PyPy, though. -MAX_TID = BTrees.family64.maxint class IStateCache(Interface): """ diff --git a/src/relstorage/tests/__init__.py b/src/relstorage/tests/__init__.py index b0fffed7..3c112deb 100644 --- a/src/relstorage/tests/__init__.py +++ b/src/relstorage/tests/__init__.py @@ -234,6 +234,10 @@ def fetchall(self): def close(self): self.closed = True + def __iter__(self): + for row in self.results: + yield row + class MockOptions(Options): cache_module_name = '' # disable cache_servers = '' From 10eddacd1b57ca9fcec078912a5d13e19bc334ee Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Fri, 19 Jul 2019 16:14:33 -0500 Subject: [PATCH 6/8] Try harder for unique prepared statement ids. --- src/relstorage/adapters/_sql.py | 18 +++++++++++++++--- src/relstorage/adapters/tests/test__sql.py | 4 ++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/relstorage/adapters/_sql.py b/src/relstorage/adapters/_sql.py index 9c39e18a..46856226 100644 --- a/src/relstorage/adapters/_sql.py +++ b/src/relstorage/adapters/_sql.py @@ -522,9 +522,21 @@ def can_prepare(self): _prepared_stmt_counter = 0 @classmethod - def _next_prepared_stmt_name(cls): + def _next_prepared_stmt_name(cls, query): + # Even with the GIL, this isn't fully safe to do; two threads + # can still get the same value. We don't want to allocate a + # lock because we might be patched by gevent later. So that's + # where `query` comes in: we add the hash as a disambiguator. + # Of course, for there to be a duplicate prepared statement + # sent to the database, that would mean that we were somehow + # using the same cursor or connection in multiple threads at + # once (or perhaps we got more than one cursor from a + # connection? We should only have one.) 
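        # Illustration only (the hash values here are invented): if two
        # threads compiling two *different* statements both read the counter
        # as 3, the query hash still keeps the names distinct, e.g.
        #   rs_prep_stmt_3_1270657893   (a SELECT)
        #   rs_prep_stmt_3_911091342    (an INSERT)
        # Threads racing on the *same* statement text would still collide,
        # but as noted above that would imply sharing a cursor or connection
        # across threads.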
cls._prepared_stmt_counter += 1 - return 'rs_prep_stmt_%d' % (cls._prepared_stmt_counter,) + return 'rs_prep_stmt_%d_%d' % ( + cls._prepared_stmt_counter, + abs(hash(query)), + ) def _prepared_param(self, number): return '$' + str(number) @@ -568,7 +580,7 @@ def prepare(self): datatypes = self._find_datatypes_for_prepared_query() query = self.buf.getvalue() - name = self._next_prepared_stmt_name() + name = self._next_prepared_stmt_name(query) if datatypes: assert isinstance(datatypes, (list, tuple)) diff --git a/src/relstorage/adapters/tests/test__sql.py b/src/relstorage/adapters/tests/test__sql.py index 9dc29a2a..5183a7b1 100644 --- a/src/relstorage/adapters/tests/test__sql.py +++ b/src/relstorage/adapters/tests/test__sql.py @@ -236,7 +236,7 @@ def test_prepared_insert_values(self): stmt = stmt.compiled() self.assertRegex( stmt._prepare_stmt, - r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*" + r"PREPARE rs_prep_stmt_[0-9]*_[0-9]* \(BIGINT\) AS.*" ) def test_prepared_insert_select_with_param(self): @@ -261,7 +261,7 @@ def test_prepared_insert_select_with_param(self): stmt = stmt.compiled() self.assertRegex( stmt._prepare_stmt, - r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*" + r"PREPARE rs_prep_stmt_[0-9]*_[0-9]* \(BIGINT\) AS.*" ) def test_it(self): From 7570606d381a5b9d58946a28c69bb324283f4531 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Fri, 19 Jul 2019 17:49:36 -0500 Subject: [PATCH 7/8] Make _sql a package; it needs broken up. --- src/relstorage/adapters/dbiter.py | 2 +- .../adapters/mysql/drivers/__init__.py | 6 ++--- src/relstorage/adapters/oracle/dialect.py | 8 +++--- .../adapters/oracle/tests/test_dialect.py | 2 +- src/relstorage/adapters/poller.py | 2 +- .../adapters/postgresql/drivers/__init__.py | 2 +- .../adapters/postgresql/drivers/pg8000.py | 4 +-- src/relstorage/adapters/schema.py | 18 ++++++------- .../adapters/{_sql.py => sql/__init__.py} | 9 ++++++- src/relstorage/adapters/sql/tests/__init__.py | 0 .../test__sql.py => sql/tests/test_sql.py} | 26 +++++++++---------- src/relstorage/tests/__init__.py | 2 +- 12 files changed, 44 insertions(+), 37 deletions(-) rename src/relstorage/adapters/{_sql.py => sql/__init__.py} (98%) create mode 100644 src/relstorage/adapters/sql/tests/__init__.py rename src/relstorage/adapters/{tests/test__sql.py => sql/tests/test_sql.py} (96%) diff --git a/src/relstorage/adapters/dbiter.py b/src/relstorage/adapters/dbiter.py index 5f341db2..f2fb4da5 100644 --- a/src/relstorage/adapters/dbiter.py +++ b/src/relstorage/adapters/dbiter.py @@ -20,7 +20,7 @@ from .interfaces import IDatabaseIterator from .schema import Schema -from ._sql import it +from .sql import it class DatabaseIterator(object): """ diff --git a/src/relstorage/adapters/mysql/drivers/__init__.py b/src/relstorage/adapters/mysql/drivers/__init__.py index 77557e8c..b83d4c01 100644 --- a/src/relstorage/adapters/mysql/drivers/__init__.py +++ b/src/relstorage/adapters/mysql/drivers/__init__.py @@ -20,12 +20,12 @@ from ..._abstract_drivers import AbstractModuleDriver from ..._abstract_drivers import implement_db_driver_options -from ..._sql import _Compiler -from ..._sql import DefaultDialect +from ...sql import Compiler +from ...sql import DefaultDialect database_type = 'mysql' -class MySQLCompiler(_Compiler): +class MySQLCompiler(Compiler): def can_prepare(self): # If there are params, we can't prepare unless we're using diff --git a/src/relstorage/adapters/oracle/dialect.py b/src/relstorage/adapters/oracle/dialect.py index 3b4971cf..dde5dc34 100644 --- 
a/src/relstorage/adapters/oracle/dialect.py +++ b/src/relstorage/adapters/oracle/dialect.py @@ -7,10 +7,10 @@ from __future__ import division from __future__ import print_function -from .._sql import DefaultDialect -from .._sql import Compiler -from .._sql import Boolean -from .._sql import Column +from ..sql import DefaultDialect +from ..sql import Compiler +from ..sql import Boolean +from ..sql import Column class OracleCompiler(Compiler): diff --git a/src/relstorage/adapters/oracle/tests/test_dialect.py b/src/relstorage/adapters/oracle/tests/test_dialect.py index 7004f9c0..8fe43559 100644 --- a/src/relstorage/adapters/oracle/tests/test_dialect.py +++ b/src/relstorage/adapters/oracle/tests/test_dialect.py @@ -12,7 +12,7 @@ from relstorage.tests import MockCursor from ...schema import Schema -from ..._sql import it +from ...sql import it from ..dialect import OracleDialect diff --git a/src/relstorage/adapters/poller.py b/src/relstorage/adapters/poller.py index 184818b2..e59bf88c 100644 --- a/src/relstorage/adapters/poller.py +++ b/src/relstorage/adapters/poller.py @@ -21,7 +21,7 @@ from .interfaces import StaleConnectionError from .schema import Schema -from ._sql import func +from .sql import func log = logging.getLogger(__name__) diff --git a/src/relstorage/adapters/postgresql/drivers/__init__.py b/src/relstorage/adapters/postgresql/drivers/__init__.py index ba6ad737..1b506045 100644 --- a/src/relstorage/adapters/postgresql/drivers/__init__.py +++ b/src/relstorage/adapters/postgresql/drivers/__init__.py @@ -23,7 +23,7 @@ from ..._abstract_drivers import implement_db_driver_options from ..._abstract_drivers import AbstractModuleDriver -from ..._sql import DefaultDialect +from ...sql import DefaultDialect class PostgreSQLDialect(DefaultDialect): """ diff --git a/src/relstorage/adapters/postgresql/drivers/pg8000.py b/src/relstorage/adapters/postgresql/drivers/pg8000.py index 2b5957b3..6a6d6a91 100644 --- a/src/relstorage/adapters/postgresql/drivers/pg8000.py +++ b/src/relstorage/adapters/postgresql/drivers/pg8000.py @@ -24,7 +24,7 @@ from zope.interface import implementer from ...interfaces import IDBDriver -from ..._sql import _Compiler +from ...sql import Compiler from . import AbstractPostgreSQLDriver from . 
import PostgreSQLDialect @@ -118,7 +118,7 @@ class _tuple_deque(deque): def append(self, row): # pylint:disable=arguments-differ deque.append(self, tuple(row)) -class PG8000Compiler(_Compiler): +class PG8000Compiler(Compiler): def can_prepare(self): # Important: pg8000 1.10 - 1.13, at least, can't handle prepared diff --git a/src/relstorage/adapters/schema.py b/src/relstorage/adapters/schema.py index 16cf36ec..3fca0f06 100644 --- a/src/relstorage/adapters/schema.py +++ b/src/relstorage/adapters/schema.py @@ -25,15 +25,15 @@ from ._util import query_property from ._util import noop_when_history_free -from ._sql import Table -from ._sql import TemporaryTable -from ._sql import Column -from ._sql import HistoryVariantTable -from ._sql import OID -from ._sql import TID -from ._sql import State -from ._sql import Boolean -from ._sql import BinaryString +from .sql import Table +from .sql import TemporaryTable +from .sql import Column +from .sql import HistoryVariantTable +from .sql import OID +from .sql import TID +from .sql import State +from .sql import Boolean +from .sql import BinaryString log = logging.getLogger("relstorage") diff --git a/src/relstorage/adapters/_sql.py b/src/relstorage/adapters/sql/__init__.py similarity index 98% rename from src/relstorage/adapters/_sql.py rename to src/relstorage/adapters/sql/__init__.py index 46856226..78812fe7 100644 --- a/src/relstorage/adapters/_sql.py +++ b/src/relstorage/adapters/sql/__init__.py @@ -49,7 +49,7 @@ from relstorage._compat import NStringIO from relstorage._compat import intern from relstorage._util import CachedIn -from .interfaces import IDBDialect +from ..interfaces import IDBDialect def copy(obj): new = stdlib_copy(obj) @@ -532,6 +532,13 @@ def _next_prepared_stmt_name(cls, query): # using the same cursor or connection in multiple threads at # once (or perhaps we got more than one cursor from a # connection? We should only have one.) + # + # TODO: Sidestep this problem by allocating this earlier; + # the SELECT or INSERT statement could pick it when it is created; + # that happens at the class level at import time, when we should be + # single-threaded. + # + # That may also help facilitate caching. cls._prepared_stmt_counter += 1 return 'rs_prep_stmt_%d_%d' % ( cls._prepared_stmt_counter, diff --git a/src/relstorage/adapters/sql/tests/__init__.py b/src/relstorage/adapters/sql/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/relstorage/adapters/tests/test__sql.py b/src/relstorage/adapters/sql/tests/test_sql.py similarity index 96% rename from src/relstorage/adapters/tests/test__sql.py rename to src/relstorage/adapters/sql/tests/test_sql.py index 5183a7b1..d318b277 100644 --- a/src/relstorage/adapters/tests/test__sql.py +++ b/src/relstorage/adapters/sql/tests/test_sql.py @@ -23,16 +23,16 @@ from relstorage.tests import TestCase -from .._sql import Table -from .._sql import HistoryVariantTable -from .._sql import Column -from .._sql import bindparam -from .._sql import DefaultDialect -from .._sql import OID -from .._sql import TID -from .._sql import State -from .._sql import Boolean -from .._sql import BinaryString +from .. import Table +from .. import HistoryVariantTable +from .. import Column +from .. import bindparam +from .. import DefaultDialect +from .. import OID +from .. import TID +from .. import State +from .. import Boolean +from .. 
import BinaryString current_object = Table( 'current_object', @@ -265,7 +265,7 @@ def test_prepared_insert_select_with_param(self): ) def test_it(self): - from .._sql import it + from .. import it stmt = object_state.select( it.c.zoid, it.c.state @@ -298,7 +298,7 @@ def test_it(self): stmt.order_by(col_ref == object_state.c.state) def test_boolean_literal(self): - from .._sql import it + from .. import it stmt = transaction.select( transaction.c.tid ).where( @@ -325,7 +325,7 @@ def test_literal_in_select(self): ) def test_boolean_literal_it_joined_table(self): - from .._sql import it + from .. import it stmt = transaction.natural_join( object_state ).select( diff --git a/src/relstorage/tests/__init__.py b/src/relstorage/tests/__init__.py index 3c112deb..9492d04f 100644 --- a/src/relstorage/tests/__init__.py +++ b/src/relstorage/tests/__init__.py @@ -11,7 +11,7 @@ from relstorage._compat import ABC from relstorage.options import Options -from relstorage.adapters._sql import DefaultDialect +from relstorage.adapters.sql import DefaultDialect try: from unittest import mock From d524bb31fb40e03505b36347c908aeb811f1c140 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 22 Jul 2019 11:12:49 -0500 Subject: [PATCH 8/8] Break the sql abstraction layer apart and add unit tests. --- src/relstorage/adapters/sql/__init__.py | 1153 +---------------- src/relstorage/adapters/sql/_util.py | 73 ++ src/relstorage/adapters/sql/ast.py | 54 + src/relstorage/adapters/sql/dialect.py | 373 ++++++ src/relstorage/adapters/sql/expressions.py | 235 ++++ src/relstorage/adapters/sql/functions.py | 28 + src/relstorage/adapters/sql/insert.py | 89 ++ src/relstorage/adapters/sql/interfaces.py | 43 + src/relstorage/adapters/sql/query.py | 177 +++ src/relstorage/adapters/sql/schema.py | 161 +++ src/relstorage/adapters/sql/select.py | 147 +++ src/relstorage/adapters/sql/tests/test_ast.py | 34 + .../adapters/sql/tests/test_dialect.py | 100 ++ .../adapters/sql/tests/test_expressions.py | 65 + .../adapters/sql/tests/test_query.py | 125 ++ .../adapters/sql/tests/test_schema.py | 29 + src/relstorage/adapters/sql/tests/test_sql.py | 81 +- src/relstorage/adapters/sql/types.py | 48 + src/relstorage/tests/__init__.py | 5 +- 19 files changed, 1896 insertions(+), 1124 deletions(-) create mode 100644 src/relstorage/adapters/sql/_util.py create mode 100644 src/relstorage/adapters/sql/ast.py create mode 100644 src/relstorage/adapters/sql/dialect.py create mode 100644 src/relstorage/adapters/sql/expressions.py create mode 100644 src/relstorage/adapters/sql/functions.py create mode 100644 src/relstorage/adapters/sql/insert.py create mode 100644 src/relstorage/adapters/sql/interfaces.py create mode 100644 src/relstorage/adapters/sql/query.py create mode 100644 src/relstorage/adapters/sql/schema.py create mode 100644 src/relstorage/adapters/sql/select.py create mode 100644 src/relstorage/adapters/sql/tests/test_ast.py create mode 100644 src/relstorage/adapters/sql/tests/test_dialect.py create mode 100644 src/relstorage/adapters/sql/tests/test_expressions.py create mode 100644 src/relstorage/adapters/sql/tests/test_query.py create mode 100644 src/relstorage/adapters/sql/tests/test_schema.py create mode 100644 src/relstorage/adapters/sql/types.py diff --git a/src/relstorage/adapters/sql/__init__.py b/src/relstorage/adapters/sql/__init__.py index 78812fe7..f2d94540 100644 --- a/src/relstorage/adapters/sql/__init__.py +++ b/src/relstorage/adapters/sql/__init__.py @@ -40,1112 +40,47 @@ from __future__ import division from __future__ import 
print_function -from copy import copy as stdlib_copy -from operator import attrgetter -from weakref import WeakKeyDictionary - -from zope.interface import implementer - -from relstorage._compat import NStringIO -from relstorage._compat import intern -from relstorage._util import CachedIn -from ..interfaces import IDBDialect - -def copy(obj): - new = stdlib_copy(obj) - volatile = [k for k in vars(new) if k.startswith('_v')] - for k in volatile: - delattr(new, k) - return new - -class Type(object): - """ - A database type. - """ - -class _Unknown(Type): - "Unspecified." - -class Integer64(Type): - """ - A 64-bit integer. - """ - -class OID(Integer64): - """ - Type of an OID. - """ - -class TID(Integer64): - """ - Type of a TID. - """ - -class BinaryString(Type): - """ - Arbitrary sized binary string. - """ - -class State(Type): - """ - Used for storing object state. - """ - -class Boolean(Type): - """ - A two-value column. - """ - -class _Resolvable(object): - - def resolve_against(self, table): - # pylint:disable=unused-argument - return self - -class Column(_Resolvable): - """ - Defines a column in a table. - """ - - def __init__(self, name, type_=_Unknown, primary_key=False, nullable=True): - self.name = name - self.type_ = type_ if isinstance(type_, Type) else type_() - self.primary_key = primary_key - self.nullable = False if primary_key else nullable - - def is_type(self, type_): - return isinstance(self.type_, type_) - - def __str__(self): - return self.name - - def __eq__(self, other): - return _EqualExpression(self, other) - - def __gt__(self, other): - return _GreaterExpression(self, other) - - def __ge__(self, other): - return _GreaterEqualExpression(self, other) - - def __ne__(self, other): - return _NotEqualExpression(self, other) - - def __le__(self, other): - return _LessEqualExpression(self, other) - - def __compile_visit__(self, compiler): - compiler.visit_column(self) - - -class _LiteralNode(_Resolvable): - def __init__(self, raw): - self.raw = raw - self.name = 'anon_%x' % (id(self),) - - def __compile_visit__(self, compiler): - compiler.emit(str(self.raw)) - - def resolve_against(self, table): - return self - -class _TextNode(_LiteralNode): - pass - -class _Columns(object): - """ - Grab bag of columns. - """ - - def __init__(self, columns): - cs = [] - for c in columns: - setattr(self, c.name, c) - cs.append(c) - self._columns = tuple(cs) - - def __bool__(self): - return bool(self._columns) - - __nonzero__ = __bool__ - - def __getattr__(self, name): - # Here only so that pylint knows this class has a set of - # dynamic attributes. - raise AttributeError("Column list %s does not include %s" % ( - self._col_list(), - name - )) - - def __getitem__(self, ix): - return self._columns[ix] - - def _col_list(self): - return ','.join(str(c) for c in self._columns) - - def __compile_visit__(self, compiler): - compiler.visit_csv(self._columns) - - def has_bind_param(self): - return any( - isinstance(c, (_BindParam, _OrderedBindParam)) - for c in self._columns - ) - - def as_select_list(self): - return _SelectColumns(self._columns) - -_ColumnList = _Columns - -class _SelectColumns(_Columns): - - def __compile_visit__(self, compiler): - compiler.visit_select_list_csv(self._columns) - -class Table(object): - """ - A table relation. 
- """ - - def __init__(self, name, *columns): - self.name = name - self.columns = columns - self.c = _Columns(columns) - - def __str__(self): - return self.name - - def select(self, *args, **kwargs): - return Select(self, *args, **kwargs) - - def __compile_visit__(self, compiler): - compiler.emit_identifier(self.name) - - def bindparam(self, key): - return bindparam(key) - - def orderedbindparam(self): - return orderedbindparam() - - def natural_join(self, other_table): - return NaturalJoinedTable(self, other_table) - - def insert(self, *args, **kwargs): - return Insert(self, *args, **kwargs) - -class TemporaryTable(Table): - """ - A temporary table. - """ - -class _CompositeTableMixin(object): - - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs - - common_columns = [] - for col in lhs.columns: - if hasattr(rhs.c, col.name): - common_columns.append(col) - self.columns = common_columns - self.c = _Columns(common_columns) - - def select(self, *args, **kwargs): - return Select(self, *args, **kwargs) - - def bindparam(self, key): - return bindparam(key) - - def orderedbindparam(self): - return orderedbindparam() - -@implementer(IDBDialect) -class DefaultDialect(object): - - keep_history = True - - datatype_map = { - OID: 'BIGINT', - TID: 'BIGINT', - BinaryString: 'BYTEA', - State: 'BYTEA', - Boolean: 'BOOLEAN', - } - - def bind(self, context): - # The context will reference us most likely - # (compiled statement in instance dictionary) - # so try to avoid reference cycles. - keep_history = context.keep_history - new = copy(self) - new.keep_history = keep_history - return new - - def compiler_class(self): - return Compiler - - def compiler(self, root): - return self.compiler_class()(root) - - def datatypes_for_columns(self, column_list): - columns = list(column_list) - datatypes = [] - for column in columns: - datatype = self.datatype_map[type(column.type_)] - datatypes.append(datatype) - return datatypes - - def __eq__(self, other): - if isinstance(other, DefaultDialect): - return other.keep_history == self.keep_history - return NotImplemented - - -class _MissingDialect(DefaultDialect): - def __bool__(self): - return False - - __nonzero__ = __bool__ - - -class _Bindable(object): - - context = _MissingDialect() - - _dialect_locations = ( - attrgetter('dialect'), - attrgetter('driver.dialect'), - attrgetter('poller.driver.dialect'), - attrgetter('connmanager.driver.dialect'), - attrgetter('adapter.driver.dialect') - ) - - def _find_dialect(self, context): - # Find the dialect to use for the context. If it specifies - # one, then use it. Otherwise go hunting for the database - # driver and use *it*. Preferably the driver is attached to - # the object we're looking at, but if not that, we'll look at - # some common attributes for adapter objects for it. 
- if isinstance(context, DefaultDialect): - return context - - for getter in self._dialect_locations: - try: - dialect = getter(context) - except AttributeError: - pass - else: - return dialect.bind(context) - __traceback_info__ = vars(context) - raise TypeError("Unable to bind to %s; no dialect found" % (context,)) - - def bind(self, context): - context = self._find_dialect(context) - if context is None: - return self - - new = copy(self) - new.context = context - bound_replacements = { - k: v.bind(context) - for k, v - in vars(new).items() - if isinstance(v, _Bindable) - } - for k, v in bound_replacements.items(): - setattr(new, k, v) - return new - -class NaturalJoinedTable(_Bindable, - _CompositeTableMixin): - - def __init__(self, lhs, rhs): - super(NaturalJoinedTable, self).__init__(lhs, rhs) - assert self.c - self._join_columns = self.c - - # TODO: Check for data type mismatches, etc? - self.columns = list(self.lhs.columns) - for col in self.rhs.columns: - if not hasattr(self.lhs.c, col.name): - self.columns.append(col) - self.columns = tuple(self.columns) - self.c = _Columns(self.columns) - - def __compile_visit__(self, compiler): - compiler.visit(self.lhs) - compiler.emit_keyword('JOIN') - compiler.visit(self.rhs) - # careful with USING clause in a join: Oracle doesn't allow such - # columns to have a prefix. - compiler.emit_keyword('USING') - compiler.visit_grouped(self._join_columns) - - -class HistoryVariantTable(_Bindable, - _CompositeTableMixin): - """ - A table that can be one of two tables, depending on whether - the instance is keeping history or not. - """ - - @property - def history_preserving(self): - return self.lhs - - @property - def history_free(self): - return self.rhs - - def __compile_visit__(self, compiler): - keep_history = self.context.keep_history - node = self.history_preserving if keep_history else self.history_free - return compiler.visit(node) - - -class _CompiledQuery(object): - """ - Represents a completed query. - """ - - stmt = None - params = None - _raw_stmt = None - _prepare_stmt = None - _prepare_converter = None - - def __init__(self, root): - self.root = root - # We do not keep a reference to the context; - # it's likely to be an instance object that's - # going to have us stored in its dictionary. - context = root.context - - compiler = context.compiler(root) - self.stmt, self.params = compiler.compile() - self._raw_stmt = self.stmt # for debugging - if compiler.can_prepare(): - self._prepare_stmt, self.stmt, self._prepare_converter = compiler.prepare() - - def __repr__(self): - if self._prepare_stmt: - return "%s (%s)" % ( - self.stmt, - self._prepare_stmt - ) - return self.stmt - - def __str__(self): - return self.stmt - - _cursor_cache = WeakKeyDictionary() - - def _stmt_cache_for_cursor(self, cursor): - """Returns a dictionary.""" - # If we can't store it directly on the cursor, as happens for - # types implemented in C, we use a weakkey dictionary. - try: - cursor_prep_stmts = cursor._rs_prepared_statements - except AttributeError: - try: - cursor_prep_stmts = cursor._rs_prepared_statements = {} - except AttributeError: - cursor_prep_stmts = self._cursor_cache.get(cursor) - if cursor_prep_stmts is None: - cursor_prep_stmts = self._cursor_cache[cursor] = {} - return cursor_prep_stmts - - def execute(self, cursor, params=None): - # (Any, dict) -> None - # TODO: Include literals from self.params. - # TODO: Syntax transformation if they don't support names. - # TODO: Validate given params match expected ones, nothing missing? 
- stmt = self.stmt - if self._prepare_stmt: - # Prepare on demand. - - # In all databases, prepared statements - # persist past COMMIT/ROLLBACK (though in PostgreSQL - # preparing them takes locks that only go away at - # COMMIT/ROLLBACK). But they don't persist past a session - # restart (new connection) (obviously). - # - # Thus we keep a cache of statements we have prepared for - # this particular connection/cursor. - # - cursor_prep_stmts = self._stmt_cache_for_cursor(cursor) - try: - stmt = cursor_prep_stmts[self._prepare_stmt] - except KeyError: - stmt = cursor_prep_stmts[self._prepare_stmt] = self.stmt - __traceback_info__ = self._prepare_stmt, self, self.root.context.compiler(self.root) - cursor.execute(self._prepare_stmt) - params = self._prepare_converter(params) - - __traceback_info__ = stmt, params - if params: - cursor.execute(stmt, params) - elif self.params: - # XXX: This isn't really good. - # If there are both literals in the SQL and params, - # we don't handle that. - cursor.execute(stmt, self.params) - else: - cursor.execute(stmt) - -class Compiler(object): - - def __init__(self, root): - self.buf = NStringIO() - self.placeholders = {} - self.root = root - - - def __repr__(self): - return "<%s %s %r>" % ( - type(self).__name__, - self.buf.getvalue(), - self.placeholders - ) - - def compile(self): - self.visit(self.root) - return self.finalize() - - def can_prepare(self): - # Obviously this needs to be more general. - # Some drivers, for example, can't deal with parameters - # in a prepared statement. - return self.root.prepare and isinstance(self.root, _Query) - - _prepared_stmt_counter = 0 - - @classmethod - def _next_prepared_stmt_name(cls, query): - # Even with the GIL, this isn't fully safe to do; two threads - # can still get the same value. We don't want to allocate a - # lock because we might be patched by gevent later. So that's - # where `query` comes in: we add the hash as a disambiguator. - # Of course, for there to be a duplicate prepared statement - # sent to the database, that would mean that we were somehow - # using the same cursor or connection in multiple threads at - # once (or perhaps we got more than one cursor from a - # connection? We should only have one.) - # - # TODO: Sidestep this problem by allocating this earlier; - # the SELECT or INSERT statement could pick it when it is created; - # that happens at the class level at import time, when we should be - # single-threaded. - # - # That may also help facilitate caching. - cls._prepared_stmt_counter += 1 - return 'rs_prep_stmt_%d_%d' % ( - cls._prepared_stmt_counter, - abs(hash(query)), - ) - - def _prepared_param(self, number): - return '$' + str(number) - - _PREPARED_CONJUNCTION = 'AS' - - def _quote_query_for_prepare(self, query): - return query - - def _find_datatypes_for_prepared_query(self): - # Deduce the datatypes based on the types of the columns - # we're sending as params. - if isinstance(self.root, Insert): - root = self.root - dialect = root.context - # TODO: Should probably delegate this to the node. - if root.values and root.column_list: - # If we're sending in a list of values, those have to - # exactly match the columns, so we can easily get a list - # of datatypes. - column_list = root.column_list - datatypes = dialect.datatypes_for_columns(column_list) - elif root.select and root.select.column_list.has_bind_param(): - targets = root.column_list - sources = root.select.column_list - # TODO: This doesn't support bind params anywhere except the - # select list! 
- columns_with_params = [ - target - for target, source in zip(targets, sources) - if isinstance(source, _OrderedBindParam) - ] - assert len(self.placeholders) == len(columns_with_params) - datatypes = dialect.datatypes_for_columns(columns_with_params) - return datatypes - return () - - def prepare(self): - # This is correct for PostgreSQL. This needs moved to a dialect specific - # spot. - - datatypes = self._find_datatypes_for_prepared_query() - query = self.buf.getvalue() - name = self._next_prepared_stmt_name(query) - - if datatypes: - assert isinstance(datatypes, (list, tuple)) - datatypes = ', '.join(datatypes) - datatypes = ' (%s)' % (datatypes,) - else: - datatypes = '' - - q = query.strip() - - # PREPARE needs the query string to use $1, $2, $3, etc, - # as placeholders. - # In MySQL, it's a plain question mark. - placeholder_to_number = {} - counter = 0 - for placeholder_name in self.placeholders.values(): - counter += 1 - placeholder = self._placeholder(placeholder_name) - placeholder_to_number[placeholder_name] = counter - param = self._prepared_param(counter) - q = q.replace(placeholder, param, 1) - - q = self._quote_query_for_prepare(q) - - stmt = 'PREPARE {name}{datatypes} {conjunction} {query}'.format( - name=name, datatypes=datatypes, - query=q, - conjunction=self._PREPARED_CONJUNCTION, - ) - - - if placeholder_to_number: - execute = 'EXECUTE {name}({params})'.format( - name=name, - params=','.join(['%s'] * len(self.placeholders)), - ) - else: - # Neither MySQL nor PostgreSQL like a set of empty parens: () - execute = 'EXECUTE {name}'.format(name=name) - - if '%s' in placeholder_to_number: - # There was an ordered param. If there was one, - # they must all be ordered, so there's no need to convert anything. - assert len(placeholder_to_number) == 1 - def convert(p): - return p - else: - def convert(d): - # TODO: This may not actually be needed, since we issue a regular - # cursor.execute(), it may be able to handle named? 
- params = [None] * len(placeholder_to_number) - for placeholder_name, ix in placeholder_to_number.items(): - params[ix - 1] = d[placeholder_name] - return params - - return intern(stmt), intern(execute), convert - - def finalize(self): - return intern(self.buf.getvalue().strip()), {v: k for k, v in self.placeholders.items()} - - def visit(self, node): - node.__compile_visit__(self) - - visit_clause = visit - - def emit(self, *contents): - for content in contents: - self.buf.write(content) - - def emit_w_padding_space(self, value): - ended_in_space = self.buf.getvalue().endswith(' ') - value = value.strip() - if not ended_in_space: - self.buf.write(' ') - self.emit(value, ' ') - - emit_keyword = emit_w_padding_space - - def emit_identifier(self, identifier): - last_char = self.buf.getvalue()[-1] - if last_char not in ('(', ' '): - self.emit(' ', identifier) - else: - self.emit(identifier) - - def visit_column_list(self, column_list): - clist = column_list.c if hasattr(column_list, 'c') else column_list - self.visit(clist) - - def visit_select_list(self, column_list): - clist = column_list.c if hasattr(column_list, 'c') else column_list - self.visit(clist.as_select_list()) - - def visit_csv(self, nodes): - self.visit(nodes[0]) - for node in nodes[1:]: - self.emit(', ') - self.visit(node) - - visit_select_expression = visit - - def visit_select_list_csv(self, nodes): - self.visit_select_expression(nodes[0]) - for node in nodes[1:]: - self.emit(', ') - self.visit_select_expression(node) - - def visit_column(self, column_node): - self.emit_identifier(column_node.name) - - def visit_from(self, from_): - self.emit_keyword('FROM') - self.visit(from_) - - def visit_grouped(self, clause): - self.emit('(') - self.visit(clause) - self.emit(')') - - def visit_op(self, op): - self.emit(' ' + op + ' ') - - def _next_placeholder_name(self, prefix='param'): - return '%s_%d' % (prefix, len(self.placeholders),) - - def _placeholder(self, key): - # Write things in `pyformat` style by default, assuming a - # dictionary of params; this is supported by most drivers. - if key == '%s': - return key - return '%%(%s)s' % (key,) - - def _placeholder_for_literal_param_value(self, value): - placeholder = self.placeholders.get(value) - if not placeholder: - placeholder_name = self._next_placeholder_name(prefix='literal') - placeholder = self._placeholder(placeholder_name) - self.placeholders[value] = placeholder_name - return placeholder - - def visit_literal_expression(self, value): - placeholder = self._placeholder_for_literal_param_value(value) - self.emit(placeholder) - - def visit_boolean_literal_expression(self, value): - # In the oracle dialect, this needs to be - # either "'Y'" or "'N'" - assert isinstance(value, bool) - self.emit(str(value).upper()) - - def visit_bind_param(self, bind_param): - self.placeholders[bind_param] = bind_param.key - self.emit(self._placeholder(bind_param.key)) - - def visit_ordered_bind_param(self, bind_param): - self.placeholders[bind_param] = '%s' - self.emit('%s') - -_Compiler = Compiler # BWC. Remove - -class _Expression(_Bindable, - _Resolvable): - """ - A SQL expression. 
- """ - -class _BindParam(_Expression): - - def __init__(self, key): - self.key = key - - def __compile_visit__(self, compiler): - compiler.visit_bind_param(self) - - -def bindparam(key): - return _BindParam(key) - -class _LiteralExpression(_Expression): - - def __init__(self, value): - self.value = value - - def __compile_visit__(self, compiler): - compiler.visit_literal_expression(self.value) - -class _BooleanLiteralExpression(_LiteralExpression): - - def __compile_visit__(self, compiler): - compiler.visit_boolean_literal_expression(self.value) - -class _OrderedBindParam(_Expression): - - name = '%s' - - def __compile_visit__(self, compiler): - compiler.visit_ordered_bind_param(self) - -def orderedbindparam(): - return _OrderedBindParam() - -def _as_node(c): - if isinstance(c, int): - return _LiteralNode(c) - if isinstance(c, str): - return _TextNode(c) - return c - -def _as_expression(stmt): - if hasattr(stmt, '__compile_visit__'): - return stmt - - if isinstance(stmt, bool): - # The values True and False are handled - # specially because their representation varies - # among databases (Oracle) - stmt = _BooleanLiteralExpression(stmt) - else: - stmt = _LiteralExpression(stmt) - return stmt - -class _BinaryExpression(_Expression): - """ - Expresses a comparison. - """ - - def __init__(self, op, lhs, rhs): - self.op = op - self.lhs = lhs # type: Column - # rhs is either a literal or a column; - # certain literals are handled specially. - rhs = _as_expression(rhs) - self.rhs = rhs - - def __str__(self): - return '%s %s %s' % ( - self.lhs, - self.op, - self.rhs - ) - - def __compile_visit__(self, compiler): - compiler.visit(self.lhs) - compiler.visit_op(self.op) - compiler.visit(self.rhs) - - def resolve_against(self, table): - lhs = self.lhs.resolve_against(table) - rhs = self.rhs.resolve_against(table) - new = copy(self) - new.rhs = rhs - new.lhs = lhs - return new - -class _EmptyExpression(_Expression): - """ - No comparison at all. - """ - - def __bool__(self): - return False - - __nonzero__ = __bool__ - - def __str__(self): - return '' - - def and_(self, expression): - return expression - - def __compile_visit__(self, compiler): - "Does nothing" - -class _EqualExpression(_BinaryExpression): - - def __init__(self, lhs, rhs): - _BinaryExpression.__init__(self, '=', lhs, rhs) - -class _NotEqualExpression(_BinaryExpression): - - def __init__(self, lhs, rhs): - _BinaryExpression.__init__(self, '<>', lhs, rhs) - - -class _GreaterExpression(_BinaryExpression): - - def __init__(self, lhs, rhs): - _BinaryExpression.__init__(self, '>', lhs, rhs) - -class _GreaterEqualExpression(_BinaryExpression): - - def __init__(self, lhs, rhs): - _BinaryExpression.__init__(self, '>=', lhs, rhs) - -class _LessEqualExpression(_BinaryExpression): - - def __init__(self, lhs, rhs): - _BinaryExpression.__init__(self, '<=', lhs, rhs) - - -class _Clause(_Bindable): - """ - A portion of a SQL statement. 
- """ - -class _And(_Expression): - - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs - - def __compile_visit__(self, compiler): - compiler.visit_grouped(_BinaryExpression('AND', self.lhs, self.rhs)) - - def resolve_against(self, table): - return type(self)(self.lhs.resolve_against(table), - self.rhs.resolve_against(table)) - -class _WhereClause(_Clause): - - def __init__(self, expression): - self.expression = expression - - def and_(self, expression): - expression = _And(self.expression, expression) - new = copy(self) - new.expression = expression - return new - - def __compile_visit__(self, compiler): - compiler.emit_keyword(' WHERE') - compiler.visit_grouped(self.expression) - -class _OrderBy(_Clause): - - def __init__(self, expression, dir): - self.expression = expression - self.dir = dir - - def __compile_visit__(self, compiler): - compiler.emit(' ORDER BY ') - compiler.visit(self.expression) - if self.dir: - compiler.emit(' ' + self.dir) - - -def _where(expression): - if expression: - return _WhereClause(expression) - -class _Query(_Bindable): - __name__ = None - prepare = False - - def __str__(self): - return str(self.compiled()) - - def __get__(self, inst, klass): - if inst is None: - return self - # We need to set this into the instance's dictionary. - # Otherwise we'll be rebinding and recompiling each time we're - # accessed which is not good. On Python 3.6+, there's the - # `__set_name__(klass, name)` called on the descriptor which - # does the job perfectly. In earlier versions, we're on our - # own. - # TODO: In test cases, we spend a lot of time binding and compiling. - # Can we find another layer of caching somewhere? - result = self.bind(inst).compiled() - if not self.__name__: - # Go through the class hierarchy, find out what we're called. - for base in klass.mro(): - for k, v in vars(base).items(): - if v is self: - self.__name__ = k - break - assert self.__name__ - - vars(inst)[self.__name__] = result - - return result - - def __set_name__(self, owner, name): - self.__name__ = name - - @CachedIn('_v_compiled') - def compiled(self): - return _CompiledQuery(self) - - def prepared(self): - """ - Note that it's good to prepare this query, if - supported by the driver. - """ - s = copy(self) - s.prepare = True - return s - -class _DeferredColumn(Column): - - def resolve_against(self, table): - return getattr(table.c, self.name) - -class _DeferredColumns(object): - - def __getattr__(self, name): - return _DeferredColumn(name) - -class _It(object): - """ - A proxy that select can resolve to tables in the current table. - """ - - c = _DeferredColumns() - - def bindparam(self, name): - return bindparam(name) - -it = _It() - -def _resolved_against(columns, table): - resolved = [ - _as_node(c).resolve_against(table) - for c - in columns - ] - return resolved - -class Select(_Query): - """ - A Select query. - - When instances of this class are stored in a class dictionary, - they function as non-data descriptors: The first time they are - accessed, they *bind* themselves to the instance and select the - appropriate SQL syntax and compile themselves into a string. 
- """ - - _distinct = _EmptyExpression() - _where = _EmptyExpression() - _order_by = _EmptyExpression() - _limit = None - _for_update = None - _nowait = None - - def __init__(self, table, *columns): - self.table = table - if columns: - self.column_list = _ColumnList(_resolved_against(columns, table)) - else: - self.column_list = table - - def where(self, expression): - expression = expression.resolve_against(self.table) - s = copy(self) - s._where = _where(expression) - return s - - def and_(self, expression): - expression = expression.resolve_against(self.table) - s = copy(self) - s._where = self._where.and_(expression) - return s - - def order_by(self, expression, dir=None): - expression = expression.resolve_against(self.table) - s = copy(self) - s._order_by = _OrderBy(expression, dir) - return s - - def limit(self, literal): - s = copy(self) - s._limit = literal - return s - - def for_update(self): - s = copy(self) - s._for_update = 'FOR UPDATE' - return s - - def nowait(self): - s = copy(self) - s._nowait = 'NOWAIT' - return s - - def distinct(self): - s = copy(self) - s._distinct = _TextNode('DISTINCT') - return s - - def __compile_visit__(self, compiler): - compiler.emit_keyword('SELECT') - compiler.visit(self._distinct) - compiler.visit_select_list(self.column_list) - compiler.visit_from(self.table) - compiler.visit_clause(self._where) - compiler.visit_clause(self._order_by) - if self._limit: - compiler.emit_keyword('LIMIT') - compiler.emit(str(self._limit)) - if self._for_update: - compiler.emit_keyword(self._for_update) - if self._nowait: - compiler.emit_keyword(self._nowait) - -class _Functions(object): - - def max(self, column): - return _Function('max', column) - -class _Function(_Expression): - - def __init__(self, name, expression): - self.name = name - self.expression = expression - - def __compile_visit__(self, compiler): - compiler.emit_identifier(self.name) - compiler.visit_grouped(self.expression) - -func = _Functions() - -class Insert(_Query): - - column_list = None - select = None - epilogue = '' - values = None - - def __init__(self, table, *columns): - self.table = table - if columns: - self.column_list = _Columns(_resolved_against(columns, table)) - # TODO: Probably want a different type, like a ValuesList - self.values = _Columns([orderedbindparam() for _ in columns]) - - def from_select(self, names, select): - i = copy(self) - i.column_list = _Columns(names) - i.select = select - return i - - def __compile_visit__(self, compiler): - compiler.emit_keyword('INSERT INTO') - compiler.visit(self.table) - compiler.visit_grouped(self.column_list) - if self.select: - compiler.visit(self.select) - else: - compiler.emit_keyword('VALUES') - compiler.visit_grouped(self.values) - compiler.emit(self.epilogue) - - def __add__(self, extension): - # This appends a textual epilogue. It's a temporary - # measure until we have more nodes and can model what - # we're trying to accomplish. 
- assert isinstance(extension, str) - i = copy(self) - i.epilogue += extension - return i +from .schema import Table +from .schema import TemporaryTable +from .schema import HistoryVariantTable +from .schema import Column +from .schema import ColumnResolvingProxy + +from .dialect import DefaultDialect +from .dialect import Compiler + +from .types import OID +from .types import TID +from .types import State +from .types import Boolean +from .types import BinaryString + +from .functions import func + +it = ColumnResolvingProxy() + +__all__ = [ + # Schema elements + 'Table', + 'Column', + "TemporaryTable", + "HistoryVariantTable", + + # Query helpers + 'it', + + # Dialect + "DefaultDialect", + "Compiler", + + # Types + "OID", + "TID", + "State", + "Boolean", + "BinaryString", + + # Functions + "func", + +] diff --git a/src/relstorage/adapters/sql/_util.py b/src/relstorage/adapters/sql/_util.py new file mode 100644 index 00000000..077943ee --- /dev/null +++ b/src/relstorage/adapters/sql/_util.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +""" +Utility functions and base classes. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from copy import copy as stdlib_copy + +from .interfaces import IBindParam + +def copy(obj): + new = stdlib_copy(obj) + volatile = [k for k in vars(new) if k.startswith('_v')] + for k in volatile: + delattr(new, k) + return new + +class Resolvable(object): + + __slots__ = () + + def resolve_against(self, table): + # pylint:disable=unused-argument + return self + + + +class Columns(object): + """ + Grab bag of columns. + """ + + def __init__(self, columns): + cs = [] + for c in columns: + setattr(self, c.name, c) + cs.append(c) + self._columns = tuple(cs) + + def __bool__(self): + return bool(self._columns) + + __nonzero__ = __bool__ + + def __getattr__(self, name): + # Here only so that pylint knows this class has a set of + # dynamic attributes. + raise AttributeError("Column list %s does not include %s" % ( + self._col_list(), + name + )) + + def __getitem__(self, ix): + return self._columns[ix] + + def _col_list(self): + return ','.join(str(c) for c in self._columns) + + def __compile_visit__(self, compiler): + compiler.visit_csv(self._columns) + + def has_bind_param(self): + return any( + IBindParam.providedBy(c) + for c in self._columns + ) + + def as_select_list(self): + from .select import _SelectColumns + return _SelectColumns(self._columns) diff --git a/src/relstorage/adapters/sql/ast.py b/src/relstorage/adapters/sql/ast.py new file mode 100644 index 00000000..b7a6180e --- /dev/null +++ b/src/relstorage/adapters/sql/ast.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +""" +Syntax elements. 
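A quick sketch of what the ``_v`` volatile-attribute convention in ``copy()`` above buys us: anything cached under an ``_v...`` name (such as the ``_v_compiled`` slot that ``Query.compiled()`` fills via ``CachedIn('_v_compiled')``) is dropped whenever a statement is copied, so the chained builder methods never hand out a stale compiled form. The ``Stmt`` class here is invented purely for illustration.

    from relstorage.adapters.sql._util import copy

    class Stmt(object):
        prepare = False

    s = Stmt()
    s._v_compiled = 'cached compiled form'

    s2 = copy(s)
    assert s2.prepare is False              # ordinary attributes survive
    assert not hasattr(s2, '_v_compiled')   # volatile caches are dropped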
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ._util import Resolvable + + +class LiteralNode(Resolvable): + + __slots__ = ( + 'raw', + 'name', + ) + + def __init__(self, raw): + self.raw = raw + self.name = 'anon_%x' % (id(self),) + + def __compile_visit__(self, compiler): + compiler.emit(str(self.raw)) + + def resolve_against(self, table): + return self + +class BooleanNode(LiteralNode): + + __slots__ = () + +class TextNode(LiteralNode): + __slots__ = () + +def as_node(c): + if isinstance(c, bool): + return BooleanNode(c) + if isinstance(c, int): + return LiteralNode(c) + if isinstance(c, str): + return TextNode(c) + return c + + + +def resolved_against(columns, table): + resolved = [ + as_node(c).resolve_against(table) + for c + in columns + ] + return resolved diff --git a/src/relstorage/adapters/sql/dialect.py b/src/relstorage/adapters/sql/dialect.py new file mode 100644 index 00000000..ed84efbf --- /dev/null +++ b/src/relstorage/adapters/sql/dialect.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- +""" +RDBMS-specific SQL. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from operator import attrgetter + +from zope.interface import implementer + +from relstorage._compat import NStringIO +from relstorage._compat import intern +from ..interfaces import IDBDialect + +from .types import OID +from .types import TID +from .types import BinaryString +from .types import State +from .types import Boolean + +from ._util import copy +from .interfaces import ITypedParams + +# pylint:disable=too-many-function-args + +@implementer(IDBDialect) +class DefaultDialect(object): + + keep_history = True + _context_repr = None + + datatype_map = { + OID: 'BIGINT', + TID: 'BIGINT', + BinaryString: 'BYTEA', + State: 'BYTEA', + Boolean: 'BOOLEAN', + } + + def bind(self, context): + # The context will reference us most likely + # (compiled statement in instance dictionary) + # so try to avoid reference cycles. + keep_history = context.keep_history + new = copy(self) + new.keep_history = keep_history + new._context_repr = repr(context) + return new + + def compiler_class(self): + return Compiler + + def compiler(self, root): + return self.compiler_class()(root) + + def datatypes_for_columns(self, column_list): + columns = list(column_list) + datatypes = [] + for column in columns: + datatype = self.datatype_map[type(column.type_)] + datatypes.append(datatype) + return datatypes + + def __eq__(self, other): + if isinstance(other, DefaultDialect): + return other.keep_history == self.keep_history + return NotImplemented # pragma: no cover + + def __repr__(self): + return "<%s at %x keep_history=%s context=%s>" % ( + type(self).__name__, + id(self), + self.keep_history, + self._context_repr + ) + + +class _MissingDialect(DefaultDialect): + def __bool__(self): + return False + + __nonzero__ = __bool__ + + +class Compiler(object): + + def __init__(self, root): + self.buf = NStringIO() + self.placeholders = {} + self.root = root + + def __repr__(self): + return "<%s %s %r>" % ( + type(self).__name__, + self.buf.getvalue(), + self.placeholders + ) + + def compile(self): + self.visit(self.root) + return self.finalize() + + def can_prepare(self): + # Obviously this needs to be more general. + # Some drivers, for example, can't deal with parameters + # in a prepared statement; we currently handle that by overriding + # this method. 
+ return self.root.prepare + + _prepared_stmt_counter = 0 + + @classmethod + def _next_prepared_stmt_name(cls, query): + # Even with the GIL, this isn't fully safe to do; two threads + # can still get the same value. We don't want to allocate a + # lock because we might be patched by gevent later. So that's + # where `query` comes in: we add the hash as a disambiguator. + # Of course, for there to be a duplicate prepared statement + # sent to the database, that would mean that we were somehow + # using the same cursor or connection in multiple threads at + # once (or perhaps we got more than one cursor from a + # connection? We should only have one.) + # + # TODO: Sidestep this problem by allocating this earlier; + # the SELECT or INSERT statement could pick it when it is created; + # that happens at the class level at import time, when we should be + # single-threaded. + # + # That may also help facilitate caching. + cls._prepared_stmt_counter += 1 + return 'rs_prep_stmt_%d_%d' % ( + cls._prepared_stmt_counter, + abs(hash(query)), + ) + + def _prepared_param(self, number): + return '$' + str(number) + + _PREPARED_CONJUNCTION = 'AS' + + def _quote_query_for_prepare(self, query): + return query + + def _find_datatypes_for_prepared_query(self): + # Deduce the datatypes based on the types of the columns + # we're sending as params. + result = () + param_provider = ITypedParams(self.root, None) + if param_provider is not None: + result = param_provider.datatypes_for_parameters() # pylint:disable=assignment-from-no-return + return result + + def prepare(self): + # This is correct for PostgreSQL. This needs moved to a dialect specific + # spot. + + datatypes = self._find_datatypes_for_prepared_query() + query = self.buf.getvalue() + name = self._next_prepared_stmt_name(query) + + if datatypes: + assert isinstance(datatypes, (list, tuple)) + datatypes = ', '.join(datatypes) + datatypes = ' (%s)' % (datatypes,) + else: + datatypes = '' + + q = query.strip() + + # PREPARE needs the query string to use $1, $2, $3, etc, + # as placeholders. + # In MySQL, it's a plain question mark. + placeholder_to_number = {} + counter = 0 + for placeholder_name in self.placeholders.values(): + counter += 1 + placeholder = self._placeholder(placeholder_name) + placeholder_to_number[placeholder_name] = counter + param = self._prepared_param(counter) + q = q.replace(placeholder, param, 1) + + q = self._quote_query_for_prepare(q) + + stmt = 'PREPARE {name}{datatypes} {conjunction} {query}'.format( + name=name, datatypes=datatypes, + query=q, + conjunction=self._PREPARED_CONJUNCTION, + ) + + + if placeholder_to_number: + execute = 'EXECUTE {name}({params})'.format( + name=name, + params=','.join(['%s'] * len(self.placeholders)), + ) + else: + # Neither MySQL nor PostgreSQL like a set of empty parens: () + execute = 'EXECUTE {name}'.format(name=name) + + if '%s' in placeholder_to_number: + # There was an ordered param. If there was one, + # they must all be ordered, so there's no need to convert anything. + assert len(placeholder_to_number) == 1 + def convert(p): + return p + else: + def convert(d): + # TODO: This may not actually be needed, since we issue a regular + # cursor.execute(), it may be able to handle named? 
+ params = [None] * len(placeholder_to_number) + for placeholder_name, ix in placeholder_to_number.items(): + params[ix - 1] = d[placeholder_name] + return params + + return intern(stmt), intern(execute), convert + + def finalize(self): + return intern(self.buf.getvalue().strip()), {v: k for k, v in self.placeholders.items()} + + def visit(self, node): + node.__compile_visit__(self) + + visit_clause = visit + + def emit(self, *contents): + for content in contents: + self.buf.write(content) + + def emit_w_padding_space(self, value): + ended_in_space = self.buf.getvalue().endswith(' ') + value = value.strip() + if not ended_in_space: + self.buf.write(' ') + self.emit(value, ' ') + + emit_keyword = emit_w_padding_space + + def emit_identifier(self, identifier): + last_char = self.buf.getvalue()[-1] + if last_char not in ('(', ' '): + self.emit(' ', identifier) + else: + self.emit(identifier) + + def visit_select_list(self, column_list): + clist = column_list.c if hasattr(column_list, 'c') else column_list + self.visit(clist.as_select_list()) + + def visit_csv(self, nodes): + self.visit(nodes[0]) + for node in nodes[1:]: + self.emit(', ') + self.visit(node) + + visit_select_expression = visit + + def visit_select_list_csv(self, nodes): + self.visit_select_expression(nodes[0]) + for node in nodes[1:]: + self.emit(', ') + self.visit_select_expression(node) + + def visit_column(self, column_node): + self.emit_identifier(column_node.name) + + def visit_from(self, from_): + self.emit_keyword('FROM') + self.visit(from_) + + def visit_grouped(self, clause): + self.emit('(') + self.visit(clause) + self.emit(')') + + def visit_op(self, op): + self.emit(' ' + op + ' ') + + def _next_placeholder_name(self, prefix='param'): + return '%s_%d' % (prefix, len(self.placeholders),) + + def _placeholder(self, key): + # Write things in `pyformat` style by default, assuming a + # dictionary of params; this is supported by most drivers. + if key == '%s': + return key + return '%%(%s)s' % (key,) + + def _placeholder_for_literal_param_value(self, value): + placeholder = self.placeholders.get(value) + if not placeholder: + placeholder_name = self._next_placeholder_name(prefix='literal') + placeholder = self._placeholder(placeholder_name) + self.placeholders[value] = placeholder_name + return placeholder + + def visit_literal_expression(self, value): + placeholder = self._placeholder_for_literal_param_value(value) + self.emit(placeholder) + + def visit_boolean_literal_expression(self, value): + # In the oracle dialect, this needs to be + # either "'Y'" or "'N'" + assert isinstance(value, bool) + self.emit(str(value).upper()) + + def visit_bind_param(self, bind_param): + self.placeholders[bind_param] = bind_param.key + self.emit(self._placeholder(bind_param.key)) + + def visit_ordered_bind_param(self, bind_param): + self.placeholders[bind_param] = '%s' + self.emit('%s') + + +class _DefaultContext(object): + + keep_history = True + + +class DialectAware(object): + + context = _DefaultContext() + dialect = _MissingDialect() + + _dialect_locations = ( + attrgetter('dialect'), + attrgetter('driver.dialect'), + attrgetter('poller.driver.dialect'), + attrgetter('connmanager.driver.dialect'), + attrgetter('adapter.driver.dialect') + ) + + def _find_dialect(self, context): + # Find the dialect to use for the context. If it specifies + # one, then use it. Otherwise go hunting for the database + # driver and use *it*. 
Preferably the driver is attached to + # the object we're looking at, but if not that, we'll look at + # some common attributes for adapter objects for it. + if isinstance(context, DefaultDialect): + return context + + for getter in self._dialect_locations: + try: + dialect = getter(context) + except AttributeError: + pass + else: + return dialect.bind(context) + __traceback_info__ = getattr(context, '__dict__', ()) # vars() doesn't work on e.g., None + raise TypeError("Unable to bind to %s; no dialect found" % (context,)) + + def bind(self, context, dialect=None): + if dialect is None: + dialect = self._find_dialect(context) + + assert dialect is not None + + new = copy(self) + if context is not None: + new.context = context + new.dialect = dialect + bound_replacements = { + k: v.bind(context, dialect) + for k, v + in vars(new).items() + if isinstance(v, DialectAware) + } + for k, v in bound_replacements.items(): + setattr(new, k, v) + return new diff --git a/src/relstorage/adapters/sql/expressions.py b/src/relstorage/adapters/sql/expressions.py new file mode 100644 index 00000000..3c4eafcd --- /dev/null +++ b/src/relstorage/adapters/sql/expressions.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +""" +Expressions in the AST. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from zope.interface import implementer + +from ._util import Resolvable +from ._util import copy + +from .dialect import DialectAware +from .interfaces import INamedBindParam +from .interfaces import IOrderedBindParam + +class Expression(DialectAware, + Resolvable): + """ + A SQL expression. + """ + + __slots__ = () + +@implementer(INamedBindParam) +class BindParam(Expression): + + __slots__ = ( + 'key', + ) + + def __init__(self, key): + self.key = key + + def __compile_visit__(self, compiler): + compiler.visit_bind_param(self) + +@implementer(INamedBindParam) +def bindparam(key): + return BindParam(key) + + +class LiteralExpression(Expression): + + __slots__ = ( + 'value', + ) + + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __compile_visit__(self, compiler): + compiler.visit_literal_expression(self.value) + +class BooleanLiteralExpression(LiteralExpression): + + __slots__ = () + + def __compile_visit__(self, compiler): + compiler.visit_boolean_literal_expression(self.value) + +@implementer(IOrderedBindParam) +class OrderedBindParam(Expression): + + __slots__ = () + + name = '%s' + + def __compile_visit__(self, compiler): + compiler.visit_ordered_bind_param(self) + +@implementer(IOrderedBindParam) +def orderedbindparam(): + return OrderedBindParam() + + +def as_expression(stmt): + if hasattr(stmt, '__compile_visit__'): + return stmt + + if isinstance(stmt, bool): + # The values True and False are handled + # specially because their representation varies + # among databases (Oracle) + stmt = BooleanLiteralExpression(stmt) + else: + stmt = LiteralExpression(stmt) + return stmt + +class BinaryExpression(Expression): + """ + Expresses a comparison. + """ + + __slots__ = ( + 'op', + 'lhs', + 'rhs', + ) + + def __init__(self, op, lhs, rhs): + self.op = op + self.lhs = lhs # type: Column + # rhs is either a literal or a column; + # certain literals are handled specially. 
+ rhs = as_expression(rhs) + self.rhs = rhs + + def __str__(self): + return '%s %s %s' % ( + self.lhs, + self.op, + self.rhs + ) + + def __compile_visit__(self, compiler): + compiler.visit(self.lhs) + compiler.visit_op(self.op) + compiler.visit(self.rhs) + + def resolve_against(self, table): + lhs = self.lhs.resolve_against(table) + rhs = self.rhs.resolve_against(table) + new = copy(self) + new.rhs = rhs + new.lhs = lhs + return new + +class EmptyExpression(Expression): + """ + No comparison at all. + """ + + __slots__ = () + + def __bool__(self): + return False + + __nonzero__ = __bool__ + + def __str__(self): + return '' + + def and_(self, expression): + return expression + + def __compile_visit__(self, compiler): + "Does nothing" + +class EqualExpression(BinaryExpression): + + __slots__ = () + + def __init__(self, lhs, rhs): + BinaryExpression.__init__(self, '=', lhs, rhs) + +class NotEqualExpression(BinaryExpression): + + __slots__ = () + + def __init__(self, lhs, rhs): + BinaryExpression.__init__(self, '<>', lhs, rhs) + + +class GreaterExpression(BinaryExpression): + + __slots__ = () + + def __init__(self, lhs, rhs): + BinaryExpression.__init__(self, '>', lhs, rhs) + +class GreaterEqualExpression(BinaryExpression): + + __slots__ = () + + def __init__(self, lhs, rhs): + BinaryExpression.__init__(self, '>=', lhs, rhs) + +class LessEqualExpression(BinaryExpression): + + __slots__ = () + + def __init__(self, lhs, rhs): + BinaryExpression.__init__(self, '<=', lhs, rhs) + +class And(Expression): + + __slots__ = ( + 'lhs', + 'rhs', + ) + + def __init__(self, lhs, rhs): + self.lhs = as_expression(lhs) + self.rhs = as_expression(rhs) + + def __compile_visit__(self, compiler): + compiler.visit_grouped(BinaryExpression('AND', self.lhs, self.rhs)) + + def resolve_against(self, table): + lhs = self.lhs.resolve_against(table) + rhs = self.rhs.resolve_against(table) + new = copy(self) + self.lhs = lhs + self.rhs = rhs + return new + + +class ParamMixin(object): + def bindparam(self, key): + return bindparam(key) + + def orderedbindparam(self): + return orderedbindparam() + +class ExpressionOperatorMixin(object): + def __eq__(self, other): + return EqualExpression(self, other) + + def __gt__(self, other): + return GreaterExpression(self, other) + + def __ge__(self, other): + return GreaterEqualExpression(self, other) + + def __ne__(self, other): + return NotEqualExpression(self, other) + + def __le__(self, other): + return LessEqualExpression(self, other) diff --git a/src/relstorage/adapters/sql/functions.py b/src/relstorage/adapters/sql/functions.py new file mode 100644 index 00000000..96333097 --- /dev/null +++ b/src/relstorage/adapters/sql/functions.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +Function expressions. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from .expressions import Expression + +class _Functions(object): + + def max(self, column): + return _Function('max', column) + +class _Function(Expression): + + def __init__(self, name, expression): + self.name = name + self.expression = expression + + def __compile_visit__(self, compiler): + compiler.emit_identifier(self.name) + compiler.visit_grouped(self.expression) + +func = _Functions() diff --git a/src/relstorage/adapters/sql/insert.py b/src/relstorage/adapters/sql/insert.py new file mode 100644 index 00000000..8542c3c7 --- /dev/null +++ b/src/relstorage/adapters/sql/insert.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +""" +The ``INSERT`` statement. 
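For orientation, here is roughly how the comparison operators above combine; the table and columns are invented for the example, and only the unbound string forms are shown. Comparing two columns yields an ``EqualExpression``; a plain literal on the right-hand side is wrapped by ``as_expression()`` into a ``LiteralExpression`` (or the Oracle-aware ``BooleanLiteralExpression`` for ``True``/``False``) and only becomes a placeholder when the statement is compiled.

    from relstorage.adapters.sql import Table, Column

    object_state = Table('object_state', Column('zoid'), Column('tid'))

    expr = object_state.c.zoid == object_state.c.tid
    print(type(expr).__name__)   # EqualExpression
    print(str(expr))             # zoid = tid

    expr = object_state.c.tid > 5
    print(str(expr))             # tid > 5; once compiled, the 5 is emitted
                                 # as a named placeholder (e.g. %(literal_0)s)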
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from zope.interface import implementer + +from .query import Query +from ._util import copy +from .query import ColumnList +from .ast import resolved_against + +from .interfaces import ITypedParams +from .interfaces import IOrderedBindParam + +@implementer(ITypedParams) +class Insert(Query): + + column_list = None + select = None + epilogue = '' + values = None + + def __init__(self, table, *columns): + super(Insert, self).__init__() + self.table = table + if columns: + self.column_list = ColumnList(resolved_against(columns, table)) + # TODO: Probably want a different type, like a ValuesList + self.values = ColumnList([self.orderedbindparam() for _ in columns]) + + def from_select(self, names, select): + i = copy(self) + i.column_list = ColumnList(names) + i.select = select + return i + + def __compile_visit__(self, compiler): + compiler.emit_keyword('INSERT INTO') + compiler.visit(self.table) + compiler.visit_grouped(self.column_list) + if self.select: + compiler.visit(self.select) + else: + compiler.emit_keyword('VALUES') + compiler.visit_grouped(self.values) + compiler.emit(self.epilogue) + + def __add__(self, extension): + # This appends a textual epilogue. It's a temporary + # measure until we have more nodes and can model what + # we're trying to accomplish. + assert isinstance(extension, str) + i = copy(self) + i.epilogue += extension + return i + + def datatypes_for_parameters(self): + dialect = self.dialect + if self.values and self.column_list: + # If we're sending in a list of values, those have to + # exactly match the columns, so we can easily get a list + # of datatypes. + column_list = self.column_list + datatypes = dialect.datatypes_for_columns(column_list) + elif self.select and self.select.column_list.has_bind_param(): + targets = self.column_list + sources = self.select.column_list + # TODO: This doesn't support bind params anywhere except the + # select list! + # TODO: This doesn't support named bind params. + columns_with_params = [ + target + for target, source in zip(targets, sources) + if IOrderedBindParam.providedBy(source) + ] + datatypes = dialect.datatypes_for_columns(columns_with_params) + return datatypes + + + +class Insertable(object): + + def insert(self, *columns): + return Insert(self, *columns) diff --git a/src/relstorage/adapters/sql/interfaces.py b/src/relstorage/adapters/sql/interfaces.py new file mode 100644 index 00000000..46126a7e --- /dev/null +++ b/src/relstorage/adapters/sql/interfaces.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +""" +Interfaces, mostly internal, for the sql module. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint:disable=inherit-non-class,no-method-argument + +from zope.interface import Interface + +class ITypedParams(Interface): + """ + Something that accepts parameters, and knows + what their types should be. + """ + + def datatypes_for_parameters(): + """ + Returns a sequence of datatypes. + + XXX: This only works for ordered params; make this work + for named parameters. Probably want to treat the two the same, + with ordered parameters using indexes as their name. + """ + + +class IBindParam(Interface): + """ + A parameter to a query. + """ + +class INamedBindParam(IBindParam): + """ + A named parameter. + """ + +class IOrderedBindParam(IBindParam): + """ + A anonymous parameter, identified only by order. 
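To make the ``ITypedParams`` piece concrete, a hedged sketch of an ``INSERT`` built with these classes follows. The table definition is invented for illustration, and the SQL text shown in the comments is approximate and dialect-dependent.

    from relstorage.adapters.sql import Table, Column, OID, TID, State, it

    object_state = Table(
        'object_state',
        Column('zoid', OID),
        Column('tid', TID),
        Column('state', State),
    )

    stmt = object_state.insert(it.c.zoid, it.c.tid, it.c.state)
    # Compiles to roughly:
    #   INSERT INTO object_state(zoid, tid, state) VALUES (%s, %s, %s)

    # Each value is an ordered bind parameter, so the datatypes needed for a
    # PREPARE come straight from the column types via the dialect's
    # datatype_map:
    print(stmt.datatypes_for_parameters())
    # ['BIGINT', 'BIGINT', 'BYTEA'] with the default dialect

    # Asking for a prepared statement makes the default (PostgreSQL-style)
    # compiler emit something like
    #   PREPARE rs_prep_stmt_<n>_<hash> (BIGINT, BIGINT, BYTEA)
    #       AS INSERT INTO object_state(zoid, tid, state) VALUES ($1, $2, $3)
    # and later run EXECUTE rs_prep_stmt_<n>_<hash>(%s,%s,%s) on the cursor.
    stmt = stmt.prepared()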
+ """ diff --git a/src/relstorage/adapters/sql/query.py b/src/relstorage/adapters/sql/query.py new file mode 100644 index 00000000..abd45344 --- /dev/null +++ b/src/relstorage/adapters/sql/query.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +""" +Compiled queries ready for execution. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from weakref import WeakKeyDictionary + +from relstorage._util import CachedIn + +from ._util import copy +from ._util import Columns +from .dialect import DialectAware + +from .expressions import ParamMixin + + +class Clause(DialectAware): + """ + A portion of a SQL statement. + """ + + + +class ColumnList(Columns): + """ + List of columns used in a query. + """ + + # This class exists for semantics, it currently doesn't + # do anything different than the super. + + +class Query(ParamMixin, + Clause): + __name__ = None + prepare = False + + def __str__(self): + return str(self.compiled()) + + def __get__(self, inst, klass): + if inst is None: + return self + # We need to set this into the instance's dictionary. + # Otherwise we'll be rebinding and recompiling each time we're + # accessed which is not good (in fact it's a step backwards + # from query_property()). On Python 3.6+, there's the + # `__set_name__(klass, name)` called on the descriptor which + # does the job perfectly. In earlier versions, we're on our + # own. + # + # TODO: In test cases, we spend a lot of time binding and compiling. + # Can we find another layer of caching somewhere? + result = self.bind(inst).compiled() + if not self.__name__: + # Go through the class hierarchy, find out what we're called. + for base in klass.mro(): + for k, v in vars(base).items(): + if v is self: + self.__name__ = k + break + assert self.__name__ + + vars(inst)[self.__name__] = result + + return result + + def __set_name__(self, owner, name): + self.__name__ = name + + @CachedIn('_v_compiled') + def compiled(self): + return CompiledQuery(self) + + def prepared(self): + """ + Note that it's good to prepare this query, if + supported by the driver. + """ + s = copy(self) + s.prepare = True + return s + + +class CompiledQuery(object): + """ + Represents a completed query. + """ + + stmt = None + params = None + _raw_stmt = None + _prepare_stmt = None + _prepare_converter = None + + def __init__(self, root): + self.root = root + # We do not keep a reference to the context; + # it's likely to be an instance object that's + # going to have us stored in its dictionary. + dialect = root.dialect + + compiler = dialect.compiler(root) + self.stmt, self.params = compiler.compile() + self._raw_stmt = self.stmt # for debugging + if compiler.can_prepare(): + self._prepare_stmt, self.stmt, self._prepare_converter = compiler.prepare() + + def __repr__(self): + if self._prepare_stmt: + return "%s (%s)" % ( + self.stmt, + self._prepare_stmt + ) + return self.stmt + + def __str__(self): + return self.stmt + + _cursor_cache = WeakKeyDictionary() + + def _stmt_cache_for_cursor(self, cursor): + """Returns a dictionary.""" + # If we can't store it directly on the cursor, as happens for + # types implemented in C, we use a weakkey dictionary. 
+ try: + cursor_prep_stmts = cursor._rs_prepared_statements + except AttributeError: + try: + cursor_prep_stmts = cursor._rs_prepared_statements = {} + except AttributeError: + cursor_prep_stmts = self._cursor_cache.get(cursor) + if cursor_prep_stmts is None: + cursor_prep_stmts = self._cursor_cache[cursor] = {} + return cursor_prep_stmts + + def execute(self, cursor, params=None): + # (Any, dict) -> None + # TODO: Include literals from self.params. + # TODO: Syntax transformation if they don't support names. + # TODO: Validate given params match expected ones, nothing missing? + stmt = self.stmt + if self._prepare_stmt: + # Prepare on demand. + + # In all databases, prepared statements + # persist past COMMIT/ROLLBACK (though in PostgreSQL + # preparing them takes locks that only go away at + # COMMIT/ROLLBACK). But they don't persist past a session + # restart (new connection) (obviously). + # + # Thus we keep a cache of statements we have prepared for + # this particular connection/cursor. + # + cursor_prep_stmts = self._stmt_cache_for_cursor(cursor) + try: + stmt = cursor_prep_stmts[self._prepare_stmt] + except KeyError: + stmt = cursor_prep_stmts[self._prepare_stmt] = self.stmt + __traceback_info__ = self._prepare_stmt, self, self.root.dialect.compiler(self.root) + cursor.execute(self._prepare_stmt) + params = self._prepare_converter(params) + + __traceback_info__ = stmt, params + if params: + cursor.execute(stmt, params) + elif self.params: + # XXX: This isn't really good. + # If there are both literals in the SQL and params, + # we don't handle that. + cursor.execute(stmt, self.params) + else: + cursor.execute(stmt) diff --git a/src/relstorage/adapters/sql/schema.py b/src/relstorage/adapters/sql/schema.py new file mode 100644 index 00000000..00993851 --- /dev/null +++ b/src/relstorage/adapters/sql/schema.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +""" +Concepts that exist in SQL, such as columns and tables. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from ._util import Resolvable +from ._util import Columns + +from .types import Type +from .types import Unknown + +from .expressions import ExpressionOperatorMixin +from .expressions import ParamMixin + +from .dialect import DialectAware + +from .select import Selectable +from .insert import Insertable + +class Column(ExpressionOperatorMixin, + Resolvable): + """ + Defines a column in a table. + """ + + def __init__(self, name, type_=Unknown, primary_key=False, nullable=True): + self.name = name + self.type_ = type_ if isinstance(type_, Type) else type_() + self.primary_key = primary_key + self.nullable = False if primary_key else nullable + + def is_type(self, type_): + return isinstance(self.type_, type_) + + def __str__(self): + return self.name + + + def __compile_visit__(self, compiler): + compiler.visit_column(self) + + +class _DeferredColumn(Column): + + def resolve_against(self, table): + return getattr(table.c, self.name) + +class _DeferredColumns(object): + + def __getattr__(self, name): + return _DeferredColumn(name) + +class ColumnResolvingProxy(ParamMixin): + """ + A proxy that select can resolve to tables in the current table. + """ + + c = _DeferredColumns() + + +class SchemaItem(object): + """ + A permanent item in a schema (aka the data dictionary). + """ + + +class Table(Selectable, + Insertable, + ParamMixin, + SchemaItem): + """ + A table relation. 
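Illustrative only: this is roughly how a table is expected to be declared with ``Table`` and ``Column``. The column set here is just an example; the schema definitions elsewhere in RelStorage remain the authoritative ones.

    from relstorage.adapters.sql import Table, Column, OID, TID, State

    object_state = Table(
        'object_state',
        Column('zoid', OID, primary_key=True),
        Column('tid', TID, nullable=False),
        Column('state', State),
        Column('state_size'),
    )

    str(object_state)                  # 'object_state'
    object_state.c.zoid                # the Column, via the Columns grab bag
    object_state.c.zoid.is_type(OID)   # True; primary_key also forces
                                       # nullable to False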
+ """ + + def __init__(self, name, *columns): + self.name = name + self.columns = columns + self.c = Columns(columns) + + def __str__(self): + return self.name + + def __compile_visit__(self, compiler): + compiler.emit_identifier(self.name) + + def natural_join(self, other_table): + return NaturalJoinedTable(self, other_table) + + + +class TemporaryTable(Table): + """ + A temporary table. + """ + + +class _CompositeTableMixin(Selectable, + ParamMixin): + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + common_columns = [] + for col in lhs.columns: + if hasattr(rhs.c, col.name): + common_columns.append(col) + self.columns = common_columns + self.c = Columns(common_columns) + + +class NaturalJoinedTable(DialectAware, + _CompositeTableMixin): + + def __init__(self, lhs, rhs): + super(NaturalJoinedTable, self).__init__(lhs, rhs) + assert self.c + self._join_columns = self.c + + # TODO: Check for data type mismatches, etc? + self.columns = list(self.lhs.columns) + for col in self.rhs.columns: + if not hasattr(self.lhs.c, col.name): + self.columns.append(col) + self.columns = tuple(self.columns) + self.c = Columns(self.columns) + + def __compile_visit__(self, compiler): + compiler.visit(self.lhs) + compiler.emit_keyword('JOIN') + compiler.visit(self.rhs) + # careful with USING clause in a join: Oracle doesn't allow such + # columns to have a prefix. + compiler.emit_keyword('USING') + compiler.visit_grouped(self._join_columns) + + +class HistoryVariantTable(DialectAware, + _CompositeTableMixin): + """ + A table that can be one of two tables, depending on whether + the instance is keeping history or not. + """ + + @property + def history_preserving(self): + return self.lhs + + @property + def history_free(self): + return self.rhs + + def __compile_visit__(self, compiler): + keep_history = self.context.keep_history + node = self.history_preserving if keep_history else self.history_free + return compiler.visit(node) diff --git a/src/relstorage/adapters/sql/select.py b/src/relstorage/adapters/sql/select.py new file mode 100644 index 00000000..558d590c --- /dev/null +++ b/src/relstorage/adapters/sql/select.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +Elements of select queries. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .query import Query +from .query import Clause +from .query import ColumnList +from ._util import copy + +from .ast import TextNode +from .ast import resolved_against + +from .expressions import And +from .expressions import EmptyExpression + +class WhereClause(Clause): + + def __init__(self, expression): + self.expression = expression + + def and_(self, expression): + expression = And(self.expression, expression) + new = copy(self) + new.expression = expression + return new + + def __compile_visit__(self, compiler): + compiler.emit_keyword(' WHERE') + compiler.visit_grouped(self.expression) + +class OrderBy(Clause): + + def __init__(self, expression, dir): + self.expression = expression + self.dir = dir + + def __compile_visit__(self, compiler): + compiler.emit(' ORDER BY ') + compiler.visit(self.expression) + if self.dir: + compiler.emit(' ' + self.dir) + + +def _where(expression): + if expression: + return WhereClause(expression) + + +class _SelectColumns(ColumnList): + + def __compile_visit__(self, compiler): + compiler.visit_select_list_csv(self._columns) + + def as_select_list(self): + return self + + +class Select(Query): + """ + A Select query. 
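The query-building surface this gives us looks like the following sketch. The expected strings in the comments are the ones asserted in the new ``tests/test_sql.py``, and the table mirrors the one defined there; each builder method returns a modified copy, leaving the original statement untouched.

    from relstorage.adapters.sql import Table, Column, it, func

    object_state = Table(
        'object_state',
        Column('zoid'), Column('tid'), Column('state'), Column('state_size'),
    )

    stmt = object_state.select().where(object_state.c.zoid == object_state.c.tid)
    str(stmt)
    # 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid)'

    str(stmt.limit(1))
    # '... FROM object_state WHERE (zoid = tid) LIMIT 1'

    str(stmt.for_update().nowait())
    # '... FROM object_state WHERE (zoid = tid) FOR UPDATE NOWAIT'

    str(object_state.select(func.max(it.c.tid)))
    # 'SELECT max(tid) FROM object_state'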
+ + When instances of this class are stored in a class dictionary, + they function as non-data descriptors: The first time they are + accessed, they *bind* themselves to the instance and select the + appropriate SQL syntax and compile themselves into a string. + """ + + _distinct = EmptyExpression() + _where = EmptyExpression() + _order_by = EmptyExpression() + _limit = None + _for_update = None + _nowait = None + + def __init__(self, table, *columns): + self.table = table + if columns: + self.column_list = _SelectColumns(resolved_against(columns, table)) + else: + self.column_list = table + + def where(self, expression): + expression = expression.resolve_against(self.table) + s = copy(self) + s._where = _where(expression) + return s + + def and_(self, expression): + expression = expression.resolve_against(self.table) + s = copy(self) + s._where = self._where.and_(expression) + return s + + def order_by(self, expression, dir=None): + expression = expression.resolve_against(self.table) + s = copy(self) + s._order_by = OrderBy(expression, dir) + return s + + def limit(self, literal): + s = copy(self) + s._limit = literal + return s + + def for_update(self): + s = copy(self) + s._for_update = 'FOR UPDATE' + return s + + def nowait(self): + s = copy(self) + s._nowait = 'NOWAIT' + return s + + def distinct(self): + s = copy(self) + s._distinct = TextNode('DISTINCT') + return s + + def __compile_visit__(self, compiler): + compiler.emit_keyword('SELECT') + compiler.visit(self._distinct) + compiler.visit_select_list(self.column_list) + compiler.visit_from(self.table) + compiler.visit_clause(self._where) + compiler.visit_clause(self._order_by) + if self._limit: + compiler.emit_keyword('LIMIT') + compiler.emit(str(self._limit)) + if self._for_update: + compiler.emit_keyword(self._for_update) + if self._nowait: + compiler.emit_keyword(self._nowait) + + +class Selectable(object): + """ + Mixin for something that can form the root of a selet query. + """ + + def select(self, *args, **kwargs): + return Select(self, *args, **kwargs) diff --git a/src/relstorage/adapters/sql/tests/test_ast.py b/src/relstorage/adapters/sql/tests/test_ast.py new file mode 100644 index 00000000..f78735e9 --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_ast.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +Tests for AST. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .. 
import ast + + +class TestFuncs(TestCase): + + def test_as_node_boolean(self): + node = ast.as_node(True) + self.assertIsInstance(node, ast.BooleanNode) + self.assertIs(node.raw, True) diff --git a/src/relstorage/adapters/sql/tests/test_dialect.py b/src/relstorage/adapters/sql/tests/test_dialect.py new file mode 100644 index 00000000..4197accd --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_dialect.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +Tests for dialects. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .. import dialect + +class TestMissingDialect(TestCase): + + def test_boolean(self): + d = dialect._MissingDialect() + self.assertFalse(d) + + +class TestCompiler(TestCase): + + def test_prepare_no_datatypes(self): + + class C(dialect.Compiler): + def _next_prepared_stmt_name(self, query): + return 'my_stmt' + + def _find_datatypes_for_prepared_query(self): + return () + + compiler = C(None) + + stmt, execute, convert = compiler.prepare() + + self.assertEqual( + stmt, + 'PREPARE my_stmt AS ' + ) + + self.assertEqual( + execute, + 'EXECUTE my_stmt' + ) + + # We get the default dictionary converter even if we + # don't need it. + + self.assertEqual( + [], + convert({'a': 42}) + ) + + def test_prepare_named_datatypes(self): + + compiler = dialect.Compiler(None) + compiler.placeholders[object()] = 'name' + + _s, _x, convert = compiler.prepare() + + self.assertEqual( + [64], + convert({'name': 64}) + ) + + +class TestDialectAware(TestCase): + + def test_bind_none(self): + + aware = dialect.DialectAware() + + with self.assertRaisesRegex(TypeError, 'no dialect found'): + aware.bind(None) + + def test_bind_dialect(self): + class Dialect(dialect.DefaultDialect): + def bind(self, context): + raise AssertionError("Not supposed to re-bind") + + d = Dialect() + + aware = dialect.DialectAware() + new_aware = aware.bind(d) + self.assertIsNot(aware, new_aware) + self.assertIs(new_aware.dialect, d) diff --git a/src/relstorage/adapters/sql/tests/test_expressions.py b/src/relstorage/adapters/sql/tests/test_expressions.py new file mode 100644 index 00000000..37b3c7e8 --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_expressions.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +Tests for expressions. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .. import expressions + +class TestBinaryExpression(TestCase): + + def test_str(self): + + exp = expressions.BinaryExpression('=', 'lhs', 'rhs') + self.assertEqual( + str(exp), + 'lhs = rhs' + ) + + +class TestEmptyExpression(TestCase): + + exp = expressions.EmptyExpression() + + def test_boolean(self): + self.assertFalse(self.exp) + + def test_str(self): + self.assertEqual(str(self.exp), '') + + def test_and(self): + self.assertIs(self.exp.and_(self), self) + + +class TestAnd(TestCase): + + def test_resolve(self): + # This will wrap them in literal nodes, which + # do nothing when resolved. + + exp = expressions.And('a', 'b') + + resolved = exp.resolve_against(None) + + self.assertIsInstance(resolved, expressions.And) + self.assertIs(resolved.lhs, exp.lhs) + self.assertIs(resolved.rhs, exp.rhs) diff --git a/src/relstorage/adapters/sql/tests/test_query.py b/src/relstorage/adapters/sql/tests/test_query.py new file mode 100644 index 00000000..df85bafc --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_query.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from ..query import Query as _BaseQuery +from ..query import CompiledQuery + +class MockDialect(object): + + def bind(self, context): # pylint:disable=unused-argument + return self + +class Query(_BaseQuery): + + def compiled(self): + return self + +class TestQuery(TestCase): + + def test_name_discovery(self): + # If a __name__ isn't assigned when a query is a + # class property and used as a non-data-descriptor, + # it finds it. + + class C(object): + dialect = MockDialect() + + q1 = Query() + q2 = Query() + q_over = Query() + + class D(C): + + q3 = Query() + q_over = Query() + + inst = D() + + # Undo the effects of Python 3.6's __set_name__. 
+ D.q1.__name__ = None + D.q2.__name__ = None + D.q3.__name__ = None + C.q_over.__name__ = None + D.q_over.__name__ = None + + # get them to trigger them to search their name + getattr(inst, 'q1') + getattr(inst, 'q2') + getattr(inst, 'q3') + getattr(inst, 'q_over') + + + self.assertEqual(C.q1.__name__, 'q1') + self.assertEqual(C.q2.__name__, 'q2') + self.assertEqual(D.q3.__name__, 'q3') + self.assertIsNone(C.q_over.__name__) + self.assertEqual(D.q_over.__name__, 'q_over') + + +class TestCompiledQuery(TestCase): + + def test_stmt_cache_on_bad_cursor(self): + + unique_execute_stmt = [] + + class MockStatement(object): + class dialect(object): + class compiler(object): + def __init__(self, _): + "Does nothing" + def compile(self): + return 'stmt', () + def can_prepare(self): + # We have to prepare if we want to try the cache + return True + def prepare(self): + o = object() + unique_execute_stmt.append(o) + return "prepare", o, lambda params: params + + executed = [] + + class Cursor(object): + __slots__ = ('__weakref__',) + + def execute(self, stmt): + executed.append(stmt) + + cursor = Cursor() + + query = CompiledQuery(MockStatement()) + query.execute(cursor) + + self.assertLength(unique_execute_stmt, 1) + self.assertLength(executed, 2) + self.assertEqual(executed, [ + "prepare", + unique_execute_stmt[0], + ]) + + query.execute(cursor) + self.assertLength(unique_execute_stmt, 1) + self.assertLength(executed, 3) + self.assertEqual(executed, [ + "prepare", + unique_execute_stmt[0], + unique_execute_stmt[0], + ]) diff --git a/src/relstorage/adapters/sql/tests/test_schema.py b/src/relstorage/adapters/sql/tests/test_schema.py new file mode 100644 index 00000000..113ab5c1 --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_schema.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from ..schema import Table + + +class TestTable(TestCase): + + def test_str(self): + + self.assertEqual(str(Table("table")), "table") diff --git a/src/relstorage/adapters/sql/tests/test_sql.py b/src/relstorage/adapters/sql/tests/test_sql.py index d318b277..d44dd984 100644 --- a/src/relstorage/adapters/sql/tests/test_sql.py +++ b/src/relstorage/adapters/sql/tests/test_sql.py @@ -26,13 +26,16 @@ from .. import Table from .. import HistoryVariantTable from .. import Column -from .. import bindparam +from .. import it from .. import DefaultDialect from .. import OID from .. import TID from .. import State from .. import Boolean from .. import BinaryString +from .. 
import func + +from ..expressions import bindparam current_object = Table( 'current_object', @@ -81,6 +84,43 @@ def test_simple_eq_select(self): 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid)' ) + def test_simple_eq_limit(self): + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid).limit(1) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid) LIMIT 1' + ) + + def test_simple_eq_for_update(self): + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid).for_update() + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid) FOR UPDATE' + ) + + stmt = stmt.nowait() + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid) FOR UPDATE ' + 'NOWAIT' + ) + + def test_max_select(self): + table = object_state + + stmt = table.select(func.max(it.c.tid)) + + self.assertEqual( + str(stmt), + 'SELECT max(tid) FROM object_state' + ) + def test_distinct(self): table = object_state stmt = table.select(table.c.zoid).where(table.c.tid == table.bindparam('tid')).distinct() @@ -161,12 +201,13 @@ class H(object): ) def test_bind(self): - select = objects.select(objects.c.tid, objects.c.zoid).where( + from operator import attrgetter + query = objects.select(objects.c.tid, objects.c.zoid).where( objects.c.tid > bindparam('tid') ) # Unbound we assume history self.assertEqual( - str(select), + str(query), 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' ) @@ -176,24 +217,39 @@ class Context(object): context = Context() dialect = context.dialect - select = select.bind(context) + query = query.bind(context) + + class Root(object): + select = query + + for item_name in ( + 'select', + 'select.table', + 'select._where', + 'select._where.expression', + ): + __traceback_info__ = item_name + item = attrgetter(item_name)(Root) + # The exact context is passed down the tree. + self.assertIs(item.context, context) + # The dialect is first bound, so it's *not* the same + # as the one we can reference (though it is equal)... + self.assertEqual(item.dialect, dialect) + # ...but it *is* the same throughout the tree + self.assertIs(query.dialect, item.dialect) - self.assertEqual(select.context, dialect) - self.assertEqual(select.table.context, dialect) - self.assertEqual(select._where.context, dialect) - self.assertEqual(select._where.expression.context, dialect) # We take up its history setting self.assertEqual( - str(select), + str(query), 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' ) # Bound to history-free we use history free context.keep_history = False - select = select.bind(context) + query = query.bind(context) self.assertEqual( - str(select), + str(query), 'SELECT tid, zoid FROM object_state WHERE (tid > %(tid)s)' ) @@ -265,7 +321,6 @@ def test_prepared_insert_select_with_param(self): ) def test_it(self): - from .. import it stmt = object_state.select( it.c.zoid, it.c.state @@ -298,7 +353,6 @@ def test_it(self): stmt.order_by(col_ref == object_state.c.state) def test_boolean_literal(self): - from .. import it stmt = transaction.select( transaction.c.tid ).where( @@ -325,7 +379,6 @@ def test_literal_in_select(self): ) def test_boolean_literal_it_joined_table(self): - from .. 
import it stmt = transaction.natural_join( object_state ).select( diff --git a/src/relstorage/adapters/sql/types.py b/src/relstorage/adapters/sql/types.py new file mode 100644 index 00000000..670ed153 --- /dev/null +++ b/src/relstorage/adapters/sql/types.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +SQL data types. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + + +class Type(object): + """ + A database type. + """ + +class Unknown(Type): + "Unspecified." + +class Integer64(Type): + """ + A 64-bit integer. + """ + +class OID(Integer64): + """ + Type of an OID. + """ + +class TID(Integer64): + """ + Type of a TID. + """ + +class BinaryString(Type): + """ + Arbitrary sized binary string. + """ + +class State(Type): + """ + Used for storing object state. + """ + +class Boolean(Type): + """ + A two-value column. + """ diff --git a/src/relstorage/tests/__init__.py b/src/relstorage/tests/__init__.py index 9492d04f..ec535197 100644 --- a/src/relstorage/tests/__init__.py +++ b/src/relstorage/tests/__init__.py @@ -106,10 +106,13 @@ def tearDown(self): super(TestCase, self).tearDown() def assertIsEmpty(self, container): - self.assertEqual(len(container), 0) + self.assertLength(container, 0) assertEmpty = assertIsEmpty + def assertLength(self, container, length): + self.assertEqual(len(container), length, container) + class StorageCreatingMixin(ABC): keep_history = None # Override
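Tying it together, a short end-to-end sketch: the ``Context`` class below is a stand-in for an adapter or driver object (real ones expose ``dialect`` and ``keep_history`` the same way), and the expected strings come from ``test_bind`` in the new ``tests/test_sql.py``.

    from relstorage.adapters.sql import (
        Table, Column, HistoryVariantTable, DefaultDialect,
    )

    current_object = Table('current_object', Column('zoid'), Column('tid'))
    object_state = Table('object_state', Column('zoid'), Column('tid'))
    objects = HistoryVariantTable(current_object, object_state)

    query = objects.select(objects.c.tid, objects.c.zoid).where(
        objects.c.tid > objects.bindparam('tid')
    )

    # Unbound, the history-preserving variant is assumed:
    str(query)
    # 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)'

    class Context(object):
        dialect = DefaultDialect()
        keep_history = False

    bound = query.bind(Context())
    str(bound)
    # 'SELECT tid, zoid FROM object_state WHERE (tid > %(tid)s)'

    # bound.compiled().execute(cursor, {'tid': 12345}) then runs it on a
    # driver that accepts pyformat named parameters, preparing the statement
    # first when .prepared() was requested and the dialect supports it.

Because ``HistoryVariantTable`` and the rest of the ``DialectAware`` tree only pick their concrete form at bind time, the same statement object declared once at class level can serve both schema variants and whichever driver dialect the adapter supplies.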