diff --git a/src/relstorage/_compat.py b/src/relstorage/_compat.py index 3ad91d99..ee4dd1d9 100644 --- a/src/relstorage/_compat.py +++ b/src/relstorage/_compat.py @@ -55,6 +55,8 @@ def list_values(d): OID_OBJECT_MAP_TYPE = dict OID_SET_TYPE = set +MAX_TID = BTrees.family64.maxint + def iteroiditems(d): # Could be either a BTree, which always has 'iteritems', # or a plain dict, which may or may not have iteritems. @@ -65,9 +67,11 @@ def iteroiditems(d): if PY3: string_types = (str,) unicode = str + from io import StringIO as NStringIO else: string_types = (basestring,) unicode = unicode + from io import BytesIO as NStringIO try: from abc import ABC diff --git a/src/relstorage/_util.py b/src/relstorage/_util.py index 1d224a21..2a737f63 100644 --- a/src/relstorage/_util.py +++ b/src/relstorage/_util.py @@ -143,6 +143,45 @@ def __get__(self, inst, class_): inst.__dict__[name] = value return value +class CachedIn(object): + """Cached method with given cache attribute.""" + + def __init__(self, attribute_name, factory=dict): + self.attribute_name = attribute_name + self.factory = factory + + def __call__(self, func): + + @functools.wraps(func) + def decorated(instance): + cache = self.cache(instance) + key = () # We don't support arguments right now, so only one key. 
+ try: + v = cache[key] + except KeyError: + v = cache[key] = func(instance) + return v + + decorated.invalidate = self.invalidate + return decorated + + def invalidate(self, instance): + cache = self.cache(instance) + key = () + try: + del cache[key] + except KeyError: + pass + + def cache(self, instance): + try: + cache = getattr(instance, self.attribute_name) + except AttributeError: + cache = self.factory() + setattr(instance, self.attribute_name, cache) + return cache + + def to_utf8(data): if data is None or isinstance(data, bytes): return data diff --git a/src/relstorage/adapters/dbiter.py b/src/relstorage/adapters/dbiter.py index 9bb882f2..f2fb4da5 100644 --- a/src/relstorage/adapters/dbiter.py +++ b/src/relstorage/adapters/dbiter.py @@ -16,29 +16,38 @@ from zope.interface import implementer -from .interfaces import IDatabaseIterator +from relstorage._compat import MAX_TID +from .interfaces import IDatabaseIterator +from .schema import Schema +from .sql import it class DatabaseIterator(object): - """Abstract base class for database iteration. + """ + Abstract base class for database iteration. """ - def __init__(self, database_driver, runner): - self.runner = runner + def __init__(self, database_driver): + """ + :param database_driver: Necessary to bind queries correctly. + """ self.driver = database_driver + _iter_objects_query = Schema.object_state.select( + it.c.zoid, + it.c.state + ).where( + it.c.tid == it.bindparam('tid') + ).order_by( + it.c.zoid + ) + def iter_objects(self, cursor, tid): """Iterate over object states in a transaction. - Yields (oid, prev_tid, state) for each object state. + Yields ``(oid, state)`` for each object in the transaction. 
""" - stmt = """ - SELECT zoid, state - FROM object_state - WHERE tid = %(tid)s - ORDER BY zoid - """ - self.runner.run_script_stmt(cursor, stmt, {'tid': tid}) + self._iter_objects_query.execute(cursor, {'tid': tid}) for oid, state in cursor: state = self.driver.binary_column_as_state_type(state) yield oid, state @@ -47,6 +56,8 @@ def iter_objects(self, cursor, tid): @implementer(IDatabaseIterator) class HistoryPreservingDatabaseIterator(DatabaseIterator): + keep_history = True + def _transaction_iterator(self, cursor): """ Iterate over a list of transactions returned from the database. @@ -71,22 +82,38 @@ def _transaction_iterator(self, cursor): yield (tid, username, description, ext) + tuple(row[4:]) + _iter_transactions_query = Schema.transaction.select( + it.c.tid, it.c.username, it.c.description, it.c.extension + ).where( + it.c.packed == False # pylint:disable=singleton-comparison + ).and_( + it.c.tid != 0 + ).order_by( + it.c.tid, 'DESC' + ) + def iter_transactions(self, cursor): """Iterate over the transaction log, newest first. Skips packed transactions. Yields (tid, username, description, extension) for each transaction. """ - stmt = """ - SELECT tid, username, description, extension - FROM transaction - WHERE packed = %(FALSE)s - AND tid != 0 - ORDER BY tid DESC - """ - self.runner.run_script_stmt(cursor, stmt) + self._iter_transactions_query.execute(cursor) return self._transaction_iterator(cursor) + _iter_transactions_range_query = Schema.transaction.select( + it.c.tid, + it.c.username, + it.c.description, + it.c.extension, + it.c.packed, + ).where( + it.c.tid >= it.bindparam('min_tid') + ).and_( + it.c.tid <= it.bindparam('max_tid') + ).order_by( + it.c.tid + ) def iter_transactions_range(self, cursor, start=None, stop=None): """Iterate over the transactions in the given range, oldest first. 
@@ -95,21 +122,31 @@ def iter_transactions_range(self, cursor, start=None, stop=None): Yields (tid, username, description, extension, packed) for each transaction. """ - stmt = """ - SELECT tid, username, description, extension, - CASE WHEN packed = %(TRUE)s THEN 1 ELSE 0 END - FROM transaction - WHERE tid >= 0 - """ - if start is not None: - stmt += " AND tid >= %(min_tid)s" - if stop is not None: - stmt += " AND tid <= %(max_tid)s" - stmt += " ORDER BY tid" - self.runner.run_script_stmt(cursor, stmt, - {'min_tid': start, 'max_tid': stop}) + params = { + 'min_tid': start if start else 0, + 'max_tid': stop if stop else MAX_TID + } + self._iter_transactions_range_query.execute(cursor, params) return self._transaction_iterator(cursor) + _object_exists_query = Schema.current_object.select( + 1 + ).where( + it.c.zoid == it.bindparam('oid') + ) + + _object_history_query = Schema.transaction.natural_join( + Schema.object_state + ).select( + it.c.tid, it.c.username, it.c.description, it.c.extension, + Schema.object_state.c.state_size + ).where( + it.c.zoid == it.bindparam("oid") + ).and_( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + it.c.tid, "DESC" + ) def iter_object_history(self, cursor, oid): """Iterate over an object's history. @@ -118,59 +155,60 @@ def iter_object_history(self, cursor, oid): Yields (tid, username, description, extension, pickle_size) for each modification. 
""" - stmt = """ - SELECT 1 FROM current_object WHERE zoid = %(oid)s - """ - self.runner.run_script_stmt(cursor, stmt, {'oid': oid}) + params = {'oid': oid} + self._object_exists_query.execute(cursor, params) if not cursor.fetchall(): raise KeyError(oid) - stmt = """ - SELECT tid, username, description, extension, state_size - FROM transaction - JOIN object_state USING (tid) - WHERE zoid = %(oid)s - AND packed = %(FALSE)s - ORDER BY tid DESC - """ - self.runner.run_script_stmt(cursor, stmt, {'oid': oid}) + self._object_history_query.execute(cursor, params) return self._transaction_iterator(cursor) @implementer(IDatabaseIterator) class HistoryFreeDatabaseIterator(DatabaseIterator): + keep_history = False + def iter_transactions(self, cursor): """Iterate over the transaction log, newest first. Skips packed transactions. - Yields (tid, username, description, extension) for each transaction. + Yields ``(tid, username, description, extension)`` for each transaction. This always returns an empty iterable. """ # pylint:disable=unused-argument return [] + _iter_transactions_range_query = Schema.object_state.select( + it.c.tid, + ).where( + it.c.tid >= it.bindparam('min_tid') + ).and_( + it.c.tid <= it.bindparam('max_tid') + ).order_by( + it.c.tid + ).distinct() + def iter_transactions_range(self, cursor, start=None, stop=None): """Iterate over the transactions in the given range, oldest first. Includes packed transactions. - Yields (tid, username, description, extension, packed) + Yields ``(tid, username, description, extension, packed)`` for each transaction. 
""" - stmt = """ - SELECT DISTINCT tid - FROM object_state - WHERE tid > 0 - """ - if start is not None: - stmt += " AND tid >= %(min_tid)s" - if stop is not None: - stmt += " AND tid <= %(max_tid)s" - stmt += " ORDER BY tid" - self.runner.run_script_stmt(cursor, stmt, - {'min_tid': start, 'max_tid': stop}) - return ((tid, '', '', '', True) for (tid,) in cursor) + params = { + 'min_tid': start if start else 0, + 'max_tid': stop if stop else MAX_TID + } + self._iter_transactions_range_query.execute(cursor, params) + return ((tid, b'', b'', b'', True) for (tid,) in cursor) + + _iter_object_history_query = Schema.object_state.select( + it.c.tid, it.c.state_size + ).where( + it.c.zoid == it.bindparam('oid') + ) def iter_object_history(self, cursor, oid): """ @@ -180,12 +218,7 @@ def iter_object_history(self, cursor, oid): Yields a single row, ``(tid, username, description, extension, pickle_size)`` """ - stmt = """ - SELECT tid, state_size - FROM object_state - WHERE zoid = %(oid)s - """ - self.runner.run_script_stmt(cursor, stmt, {'oid': oid}) + self._iter_object_history_query.execute(cursor, {'oid': oid}) rows = cursor.fetchall() if not rows: raise KeyError(oid) diff --git a/src/relstorage/adapters/interfaces.py b/src/relstorage/adapters/interfaces.py index 2b8cc7cf..543c93b6 100644 --- a/src/relstorage/adapters/interfaces.py +++ b/src/relstorage/adapters/interfaces.py @@ -50,6 +50,14 @@ def __str__(): """Return a short description of the adapter""" +class IDBDialect(Interface): + """ + Handles converting from our internal "standard" SQL queries to + something database specific. + """ + + # TODO: Fill this in. 
+ class IDBDriver(Interface): """ An abstraction over the information needed for RelStorage to work @@ -71,6 +79,8 @@ class IDBDriver(Interface): Binary = Attribute("A callable.") + dialect = Attribute("The IDBDialect for this driver.") + def binary_column_as_state_type(db_column_data): """ Turn *db_column_data* into something that's a valid pickle diff --git a/src/relstorage/adapters/locker.py b/src/relstorage/adapters/locker.py index 336d15fd..acc34ed3 100644 --- a/src/relstorage/adapters/locker.py +++ b/src/relstorage/adapters/locker.py @@ -30,6 +30,7 @@ from ._util import query_property as _query_property from ._util import DatabaseHelpersMixin +from .schema import Schema from .interfaces import UnableToAcquireCommitLockError logger = __import__('logging').getLogger(__name__) @@ -198,39 +199,31 @@ def lock_current_objects(self, cursor, current_oids): consume(rows) - _commit_lock_queries = ( - # MySQL allows aggregates in the top level to use FOR UPDATE, - # but PostgreSQL does not, so we have to use the second form. - # - # 'SELECT MAX(tid) FROM transaction FOR UPDATE', - # 'SELECT tid FROM transaction WHERE tid = (SELECT MAX(tid) FROM transaction) FOR UPDATE', - - # Note that using transaction in history-preserving databases - # can still lead to deadlock in older versions of MySQL (test - # checkPackWhileWriting), and the above lock statement can - # lead to duplicate transaction ids being inserted on older - # versions (5.7.12, PyMySQL: - # https://ci.appveyor.com/project/jamadden/relstorage/builds/25748619/job/cyio3w54uqi026lr#L923). - # So both HF and HP use an artificial lock row. - # - # TODO: Figure out exactly the best way to lock just the rows - # in the transaction table we care about that works - # everywhere, or a better way to choose the next TID. - # gap/intention locks might be a clue. 
- - 'SELECT tid FROM commit_row_lock FOR UPDATE', - 'SELECT tid FROM commit_row_lock FOR UPDATE' - ) - - _commit_lock_query = _query_property('_commit_lock') - - _commit_lock_nowait_queries = ( - _commit_lock_queries[0] + ' NOWAIT', - _commit_lock_queries[1] + ' NOWAIT', + # MySQL allows aggregates in the top level to use FOR UPDATE, + # but PostgreSQL does not, so we have to use the second form. + # + # 'SELECT MAX(tid) FROM transaction FOR UPDATE', + # 'SELECT tid FROM transaction WHERE tid = (SELECT MAX(tid) FROM transaction) FOR UPDATE', + + # Note that using transaction in history-preserving databases + # can still lead to deadlock in older versions of MySQL (test + # checkPackWhileWriting), and the above lock statement can + # lead to duplicate transaction ids being inserted on older + # versions (5.7.12, PyMySQL: + # https://ci.appveyor.com/project/jamadden/relstorage/builds/25748619/job/cyio3w54uqi026lr#L923). + # So both HF and HP use an artificial lock row. + # + # TODO: Figure out exactly the best way to lock just the rows + # in the transaction table we care about that works + # everywhere, or a better way to choose the next TID. + # gap/intention locks might be a clue. 
+ _commit_lock_query = Schema.commit_row_lock.select( + Schema.commit_row_lock.c.tid + ).for_update( + ).prepared( ) - _commit_lock_nowait_query = _query_property('_commit_lock_nowait') - + _commit_lock_nowait_query = _commit_lock_query.nowait() @metricmethod def hold_commit_lock(self, cursor, ensure_current=False, nowait=False): @@ -241,9 +234,9 @@ def hold_commit_lock(self, cursor, ensure_current=False, nowait=False): lock_stmt = self._commit_lock_nowait_query else: self._set_row_lock_nowait(cursor) - __traceback_info__ = lock_stmt + try: - cursor.execute(lock_stmt) + lock_stmt.execute(cursor) rows = cursor.fetchall() if not rows or not rows[0]: raise UnableToAcquireCommitLockError("No row returned from commit_row_lock") diff --git a/src/relstorage/adapters/mover.py b/src/relstorage/adapters/mover.py index 85b603a0..9d25d125 100644 --- a/src/relstorage/adapters/mover.py +++ b/src/relstorage/adapters/mover.py @@ -28,6 +28,10 @@ from .._compat import ABC from .batch import RowBatcher from .interfaces import IObjectMover +from .schema import Schema + +objects = Schema.all_current_object_state +object_state = Schema.object_state metricmethod_sampled = Metric(method=True, rate=0.1) @@ -54,20 +58,11 @@ def _compute_md5sum(self, data): return None return md5(data).hexdigest() - _load_current_queries = ( - """ - SELECT state, tid - FROM current_object - JOIN object_state USING(zoid, tid) - WHERE zoid = %s - """, - """ - SELECT state, tid - FROM object_state - WHERE zoid = %s - """) - - _load_current_query = _query_property('_load_current') + _load_current_query = objects.select( + objects.c.state, objects.c.tid + ).where( + objects.c.zoid == objects.orderedbindparam() + ).prepared() @metricmethod_sampled def load_current(self, cursor, oid): @@ -76,8 +71,7 @@ def load_current(self, cursor, oid): oid is an integer. Returns (None, None) if object does not exist. 
""" stmt = self._load_current_query - - cursor.execute(stmt, (oid,)) + stmt.execute(cursor, (oid,)) # Note that we cannot rely on cursor.rowcount being # a valid indicator. The DB-API doesn't require it, and # some implementations, like MySQL Connector/Python are @@ -110,12 +104,13 @@ def load_currents(self, cursor, oids): oid, state, tid = row yield oid, binary_column_as_state_type(state), tid - _load_revision_query = """ - SELECT state - FROM object_state - WHERE zoid = %s - AND tid = %s - """ + _load_revision_query = object_state.select( + object_state.c.state + ).where( + object_state.c.zoid == object_state.orderedbindparam() + ).and_( + object_state.c.tid == object_state.orderedbindparam() + ).prepared() @metricmethod_sampled def load_revision(self, cursor, oid, tid): @@ -124,25 +119,24 @@ def load_revision(self, cursor, oid, tid): Returns None if no such state exists. """ stmt = self._load_revision_query - cursor.execute(stmt, (oid, tid)) + stmt.execute(cursor, (oid, tid)) row = cursor.fetchone() if row: (state,) = row return self.driver.binary_column_as_state_type(state) return None - _exists_queries = ( - "SELECT 1 FROM current_object WHERE zoid = %s", - "SELECT 1 FROM object_state WHERE zoid = %s" + _exists_query = Schema.all_current_object.select( + Schema.all_current_object.c.zoid + ).where( + Schema.all_current_object.c.zoid == Schema.all_current_object.orderedbindparam() ) - _exists_query = _query_property('_exists') - @metricmethod_sampled def exists(self, cursor, oid): """Returns a true value if the given object exists.""" stmt = self._exists_query - cursor.execute(stmt, (oid,)) + stmt.execute(cursor, (oid,)) row = cursor.fetchone() return row @@ -326,24 +320,15 @@ def restore(self, cursor, batcher, oid, tid, data): # careful with USING clause in a join: Oracle doesn't allow such # columns to have a prefix. 
- _detect_conflict_queries = ( - """ - SELECT zoid, current_object.tid, temp_store.prev_tid - FROM temp_store - JOIN current_object USING (zoid) - WHERE temp_store.prev_tid != current_object.tid - ORDER BY zoid - """, - """ - SELECT zoid, object_state.tid, temp_store.prev_tid - FROM temp_store - JOIN object_state USING (zoid) - WHERE temp_store.prev_tid != object_state.tid - ORDER BY zoid - """ - ) - - _detect_conflict_query = _query_property('_detect_conflict') + _detect_conflict_query = Schema.temp_store.natural_join( + Schema.all_current_object + ).select( + Schema.temp_store.c.zoid, Schema.all_current_object.c.tid, Schema.temp_store.c.prev_tid + ).where( + Schema.temp_store.c.prev_tid != Schema.all_current_object.c.tid + ).order_by( + Schema.temp_store.c.zoid + ).prepared() @metricmethod_sampled def detect_conflict(self, cursor): @@ -351,7 +336,7 @@ def detect_conflict(self, cursor): # passed to tryToResolveConflict, saving extra queries. # OTOH, using extra memory. stmt = self._detect_conflict_query - cursor.execute(stmt) + stmt.execute(cursor) rows = cursor.fetchall() return rows @@ -389,12 +374,21 @@ def replace_temp(self, cursor, oid, prev_tid, data): WHERE zoid IN (SELECT zoid FROM temp_store) """ - _move_from_temp_hf_insert_query = """ - INSERT INTO object_state (zoid, tid, state_size, state) - SELECT zoid, %s, COALESCE(LENGTH(state), 0), state - FROM temp_store - ORDER BY zoid - """ + _move_from_temp_hf_insert_query = Schema.object_state.insert( + ).from_select( + (Schema.object_state.c.zoid, + Schema.object_state.c.tid, + Schema.object_state.c.state_size, + Schema.object_state.c.state), + Schema.temp_store.select( + Schema.temp_store.c.zoid, + Schema.temp_store.orderedbindparam(), + 'COALESCE(LENGTH(state), 0)', + Schema.temp_store.c.state + ).order_by( + Schema.temp_store.c.zoid + ) + ).prepared() _move_from_temp_copy_blob_query = """ INSERT INTO blob_chunk (zoid, tid, chunk_num, chunk) @@ -435,7 +429,8 @@ def _move_from_temp_object_state(self, cursor, 
tid): cursor.execute(stmt) stmt = self._move_from_temp_hf_insert_query - cursor.execute(stmt, (tid,)) + __traceback_info__ = stmt + stmt.execute(cursor, (tid,)) @metricmethod_sampled diff --git a/src/relstorage/adapters/mysql/adapter.py b/src/relstorage/adapters/mysql/adapter.py index a4fc247b..2f799893 100644 --- a/src/relstorage/adapters/mysql/adapter.py +++ b/src/relstorage/adapters/mysql/adapter.py @@ -58,7 +58,7 @@ from relstorage.options import Options from .._abstract_drivers import _select_driver -from .._util import query_property + from ..dbiter import HistoryFreeDatabaseIterator from ..dbiter import HistoryPreservingDatabaseIterator from ..interfaces import IRelStorageAdapter @@ -69,7 +69,7 @@ from .connmanager import MySQLdbConnectionManager from .locker import MySQLLocker from .mover import MySQLObjectMover -from .mover import to_prepared_queries + from .oidallocator import MySQLOIDAllocator from .packundo import MySQLHistoryFreePackUndo from .packundo import MySQLHistoryPreservingPackUndo @@ -122,21 +122,20 @@ def __init__(self, options=None, **params): self.connmanager.add_on_store_opened(self.mover.on_store_opened) self.connmanager.add_on_load_opened(self.mover.on_load_opened) self.oidallocator = MySQLOIDAllocator(driver) - self.txncontrol = MySQLTransactionControl( - connmanager=self.connmanager, - keep_history=self.keep_history, - Binary=driver.Binary, - ) self.poller = Poller( - poll_query='EXECUTE get_latest_tid', + self.driver, keep_history=self.keep_history, runner=self.runner, revert_when_stale=options.revert_when_stale, ) - self.connmanager.add_on_load_opened(self._prepare_get_latest_tid) - self.connmanager.add_on_store_opened(self._prepare_get_latest_tid) + self.txncontrol = MySQLTransactionControl( + connmanager=self.connmanager, + poller=self.poller, + keep_history=self.keep_history, + Binary=driver.Binary, + ) if self.keep_history: self.packundo = MySQLHistoryPreservingPackUndo( @@ -148,7 +147,6 @@ def __init__(self, options=None, 
**params): ) self.dbiter = HistoryPreservingDatabaseIterator( driver, - runner=self.runner, ) else: self.packundo = MySQLHistoryFreePackUndo( @@ -160,7 +158,6 @@ def __init__(self, options=None, **params): ) self.dbiter = HistoryFreeDatabaseIterator( driver, - runner=self.runner, ) self.stats = MySQLStats( @@ -168,23 +165,6 @@ def __init__(self, options=None, **params): keep_history=self.keep_history ) - _get_latest_tid_queries = ( - "SELECT MAX(tid) FROM transaction", - "SELECT MAX(tid) FROM object_state", - ) - - _prepare_get_latest_tid_queries = to_prepared_queries( - 'get_latest_tid', - _get_latest_tid_queries) - - _prepare_get_latest_tid_query = query_property('_prepare_get_latest_tid') - - def _prepare_get_latest_tid(self, cursor, restart=False): - if restart: - return - stmt = self._prepare_get_latest_tid_query - cursor.execute(stmt) - def new_instance(self): return type(self)(options=self.options, **self._params) diff --git a/src/relstorage/adapters/mysql/drivers/__init__.py b/src/relstorage/adapters/mysql/drivers/__init__.py index 7d414245..b83d4c01 100644 --- a/src/relstorage/adapters/mysql/drivers/__init__.py +++ b/src/relstorage/adapters/mysql/drivers/__init__.py @@ -20,9 +20,32 @@ from ..._abstract_drivers import AbstractModuleDriver from ..._abstract_drivers import implement_db_driver_options +from ...sql import Compiler +from ...sql import DefaultDialect database_type = 'mysql' +class MySQLCompiler(Compiler): + + def can_prepare(self): + # If there are params, we can't prepare unless we're using + # the binary protocol; otherwise we have to SET user variables + # with extra round trips, which is worse. + return not self.placeholders and super(MySQLCompiler, self).can_prepare() + + _PREPARED_CONJUNCTION = 'FROM' + + def _prepared_param(self, number): + return '?' 
+ + def _quote_query_for_prepare(self, query): + return '"{query}"'.format(query=query) + +class MySQLDialect(DefaultDialect): + + def compiler_class(self): + return MySQLCompiler + class AbstractMySQLDriver(AbstractModuleDriver): # Don't try to decode pickle states as UTF-8 (or whatever the @@ -75,6 +98,9 @@ def callproc_multi_result(self, cursor, proc, args=()): return multi_results + dialect = MySQLDialect() + + implement_db_driver_options( __name__, 'mysqlconnector', 'mysqldb', 'pymysql', diff --git a/src/relstorage/adapters/mysql/locker.py b/src/relstorage/adapters/mysql/locker.py index 7b251de5..503d5b34 100644 --- a/src/relstorage/adapters/mysql/locker.py +++ b/src/relstorage/adapters/mysql/locker.py @@ -97,19 +97,8 @@ class MySQLLocker(AbstractLocker): def __init__(self, options, driver, batcher_factory): super(MySQLLocker, self).__init__(options, driver, batcher_factory) - # TODO: Back to needing a proper prepare registry. - lock_stmt_raw = self._commit_lock_query - lock_stmt_nowait_raw = self._commit_lock_nowait_query - - self._prepare_lock_stmt = 'PREPARE hold_commit_lock FROM "%s"' % ( - lock_stmt_raw) - self._prepare_lock_stmt_nowait = 'PREPARE hold_commit_lock_nowait FROM "%s"' % ( - lock_stmt_nowait_raw) self._supports_row_lock_nowait = None - self._commit_lock_query = 'EXECUTE hold_commit_lock' - self._commit_lock_nowait_query = 'EXECUTE hold_commit_lock_nowait' - # No good preparing this, mysql can't take parameters in EXECUTE, # they have to be user variables, which defeats most of the point # (Although in this case, because it's a static value, maybe not; @@ -136,10 +125,6 @@ def on_store_opened(self, cursor, restart=False): __traceback_info__ = ver, major self._supports_row_lock_nowait = (major >= 8) - cursor.execute(self._prepare_lock_stmt) - if self._supports_row_lock_nowait: - cursor.execute(self._prepare_lock_stmt_nowait) - def _on_store_opened_set_row_lock_timeout(self, cursor, restart=False): if restart: return diff --git 
a/src/relstorage/adapters/mysql/mover.py b/src/relstorage/adapters/mysql/mover.py index a9e533ad..38c64c5b 100644 --- a/src/relstorage/adapters/mysql/mover.py +++ b/src/relstorage/adapters/mysql/mover.py @@ -22,37 +22,13 @@ from relstorage.adapters.interfaces import IObjectMover -from .._util import query_property from ..mover import AbstractObjectMover from ..mover import metricmethod_sampled -def to_prepared_queries(name, queries, extension=''): - - return [ - 'PREPARE ' + name + ' FROM "' + x.replace('%s', '?') + extension + '"' - for x in queries - ] - @implementer(IObjectMover) class MySQLObjectMover(AbstractObjectMover): - _prepare_detect_conflict_queries = to_prepared_queries( - 'detect_conflicts', - AbstractObjectMover._detect_conflict_queries, - # Now that we explicitly lock the rows before we begin, - # no sense applying a locking clause here too. - ) - - _prepare_detect_conflict_query = query_property('_prepare_detect_conflict') - - _detect_conflict_query = 'EXECUTE detect_conflicts' - - on_load_opened_statement_names = () - - on_store_opened_statement_names = on_load_opened_statement_names - on_store_opened_statement_names += ('_prepare_detect_conflict_query',) - @metricmethod_sampled def on_store_opened(self, cursor, restart=False): """Create the temporary table for storing objects""" diff --git a/src/relstorage/adapters/mysql/txncontrol.py b/src/relstorage/adapters/mysql/txncontrol.py index d25c8ced..d7a01b1c 100644 --- a/src/relstorage/adapters/mysql/txncontrol.py +++ b/src/relstorage/adapters/mysql/txncontrol.py @@ -19,7 +19,4 @@ class MySQLTransactionControl(GenericTransactionControl): - - # See adapter.py for where this is prepared. - # Either history preserving or not, it's the same. 
- _get_tid_query = 'EXECUTE get_latest_tid' + pass diff --git a/src/relstorage/adapters/oracle/adapter.py b/src/relstorage/adapters/oracle/adapter.py index 88945968..3012e509 100644 --- a/src/relstorage/adapters/oracle/adapter.py +++ b/src/relstorage/adapters/oracle/adapter.py @@ -110,23 +110,24 @@ def __init__(self, user, password, dsn, commit_lock_id=0, self.oidallocator = OracleOIDAllocator( connmanager=self.connmanager, ) + + + self.poller = Poller( + self.driver, + keep_history=self.keep_history, + runner=self.runner, + revert_when_stale=options.revert_when_stale, + ) + self.txncontrol = OracleTransactionControl( connmanager=self.connmanager, + poller=self.poller, keep_history=self.keep_history, Binary=driver.Binary, twophase=twophase, ) - if self.keep_history: - poll_query = "SELECT MAX(tid) FROM transaction" - else: - poll_query = "SELECT MAX(tid) FROM object_state" - self.poller = Poller( - poll_query=poll_query, - keep_history=self.keep_history, - runner=self.runner, - revert_when_stale=options.revert_when_stale, - ) + if self.keep_history: self.packundo = OracleHistoryPreservingPackUndo( @@ -138,7 +139,6 @@ def __init__(self, user, password, dsn, commit_lock_id=0, ) self.dbiter = HistoryPreservingDatabaseIterator( driver, - runner=self.runner, ) else: self.packundo = OracleHistoryFreePackUndo( @@ -150,7 +150,6 @@ def __init__(self, user, password, dsn, commit_lock_id=0, ) self.dbiter = HistoryFreeDatabaseIterator( driver, - runner=self.runner, ) self.stats = OracleStats( diff --git a/src/relstorage/adapters/oracle/dialect.py b/src/relstorage/adapters/oracle/dialect.py new file mode 100644 index 00000000..dde5dc34 --- /dev/null +++ b/src/relstorage/adapters/oracle/dialect.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +The Oracle dialect + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ..sql import DefaultDialect +from ..sql import Compiler +from ..sql import Boolean +from 
..sql import Column + +class OracleCompiler(Compiler): + + def visit_boolean_literal_expression(self, value): + sql = "'Y'" if value else "'N'" + self.emit(sql) + + def can_prepare(self): + # We haven't investigated preparing statements manually + # with cx_Oracle. There's a chance that `cx_Oracle.Connection.stmtcachesize` + # will accomplish all we need. + return False + + def _placeholder(self, key): + # XXX: What's that block in the parent about key == '%s'? + # Only during prepare() I think. + return ':' + key + + def visit_ordered_bind_param(self, bind_param): + ph = self.placeholders[bind_param] = ':%d' % (len(self.placeholders) + 1,) + self.emit(ph) + + def visit_select_expression(self, column_node): + if isinstance(column_node, Column) and column_node.is_type(Boolean): + # Fancy CASE statement to get 1 or 0 into Python + self.emit_keyword('CASE WHEN') + self.emit_identifier(column_node.name) + self.emit(" = 'Y' THEN 1 ELSE 0 END") + else: + super(OracleCompiler, self).visit_select_expression(column_node) + + +class OracleDialect(DefaultDialect): + + def compiler_class(self): + return OracleCompiler diff --git a/src/relstorage/adapters/oracle/drivers.py b/src/relstorage/adapters/oracle/drivers.py index 4e34d868..cf28c6b0 100644 --- a/src/relstorage/adapters/oracle/drivers.py +++ b/src/relstorage/adapters/oracle/drivers.py @@ -26,6 +26,7 @@ from .._abstract_drivers import AbstractModuleDriver from .._abstract_drivers import implement_db_driver_options from ..interfaces import IDBDriver +from .dialect import OracleDialect database_type = 'oracle' @@ -37,6 +38,7 @@ class cx_OracleDriver(AbstractModuleDriver): __name__ = 'cx_Oracle' MODULE_NAME = __name__ + dialect = OracleDialect() def __init__(self): super(cx_OracleDriver, self).__init__() diff --git a/src/relstorage/adapters/oracle/mover.py b/src/relstorage/adapters/oracle/mover.py index d1b54ac2..7fff81d8 100644 --- a/src/relstorage/adapters/oracle/mover.py +++ b/src/relstorage/adapters/oracle/mover.py @@ 
-24,13 +24,6 @@ from ..interfaces import IObjectMover from ..mover import AbstractObjectMover from ..mover import metricmethod_sampled -from .scriptrunner import format_to_named - - -def _to_oracle_ordered(query_tuple): - # Replace %s with :1, :2, etc - assert len(query_tuple) == 2 - return format_to_named(query_tuple[0]), format_to_named(query_tuple[1]) @implementer(IObjectMover) @@ -39,15 +32,6 @@ class OracleObjectMover(AbstractObjectMover): # This is assigned to by the adapter. inputsizes = None - _move_from_temp_hp_insert_query = format_to_named( - AbstractObjectMover._move_from_temp_hp_insert_query) - _move_from_temp_hf_insert_query = format_to_named( - AbstractObjectMover._move_from_temp_hf_insert_query) - _move_from_temp_copy_blob_query = format_to_named( - AbstractObjectMover._move_from_temp_copy_blob_query) - - _load_current_queries = _to_oracle_ordered(AbstractObjectMover._load_current_queries) - @metricmethod_sampled def load_current(self, cursor, oid): stmt = self._load_current_query @@ -55,8 +39,6 @@ def load_current(self, cursor, oid): cursor, stmt, (oid,), default=(None, None)) - _load_revision_query = format_to_named(AbstractObjectMover._load_revision_query) - @metricmethod_sampled def load_revision(self, cursor, oid, tid): stmt = self._load_revision_query @@ -65,8 +47,6 @@ def load_revision(self, cursor, oid, tid): return state - _exists_queries = _to_oracle_ordered(AbstractObjectMover._exists_queries) - @metricmethod_sampled def exists(self, cursor, oid): stmt = self._exists_query @@ -246,9 +226,7 @@ def replace_temp(self, cursor, oid, prev_tid, data): - _update_current_insert_query = format_to_named(AbstractObjectMover._update_current_insert_query) - _update_current_update_query = format_to_named(AbstractObjectMover._update_current_update_query) - _update_current_update_query = _update_current_update_query.replace('ORDER BY zoid', '') + # XXX: For _update_current_update_query we used to remove 'ORDER BY zoid'. Still needed? 
@metricmethod_sampled def download_blob(self, cursor, oid, tid, filename): diff --git a/src/relstorage/adapters/oracle/scriptrunner.py b/src/relstorage/adapters/oracle/scriptrunner.py index 08076a6e..ee0ee5b8 100644 --- a/src/relstorage/adapters/oracle/scriptrunner.py +++ b/src/relstorage/adapters/oracle/scriptrunner.py @@ -15,9 +15,6 @@ from __future__ import absolute_import import logging -import re - -from relstorage._compat import intern from relstorage._compat import iteritems from ..scriptrunner import ScriptRunner @@ -26,11 +23,17 @@ _stmt_cache = {} -def format_to_named(stmt): +def _format_to_named(stmt): """ Convert '%s' pyformat strings to :n numbered strings. Intended only for static strings. + + This is legacy. Replace strings that use this with SQL statements + constructed from the schema. """ + import re + from relstorage._compat import intern + try: return _stmt_cache[stmt] except KeyError: @@ -90,7 +93,7 @@ def run_many(self, cursor, stmt, items): stmt should use '%s' parameter format. """ - cursor.executemany(format_to_named(stmt), items) + cursor.executemany(_format_to_named(stmt), items) class TrackingMap(object): diff --git a/src/relstorage/adapters/oracle/tests/__init__.py b/src/relstorage/adapters/oracle/tests/__init__.py new file mode 100644 index 00000000..68336058 --- /dev/null +++ b/src/relstorage/adapters/oracle/tests/__init__.py @@ -0,0 +1 @@ +# Oracle test package. diff --git a/src/relstorage/adapters/oracle/tests/test_dialect.py b/src/relstorage/adapters/oracle/tests/test_dialect.py new file mode 100644 index 00000000..8fe43559 --- /dev/null +++ b/src/relstorage/adapters/oracle/tests/test_dialect.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +""" +Tests for the Oracle dialect. 
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from relstorage.tests import TestCase +from relstorage.tests import MockCursor + +from ...schema import Schema +from ...sql import it + +from ..dialect import OracleDialect + +class Context(object): + keep_history = True + dialect = OracleDialect() + +class Driver(object): + dialect = OracleDialect() + +class TestOracleDialect(TestCase): + + def test_boolean_parameter(self): + transaction = Schema.transaction + object_state = Schema.object_state + + stmt = transaction.natural_join( + object_state + ).select( + it.c.tid, it.c.username, it.c.description, it.c.extension, + object_state.c.state_size + ).where( + it.c.zoid == it.bindparam("oid") + ).and_( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + it.c.tid, "DESC" + ) + + # By default we get the generic syntax + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, state_size ' + 'FROM transaction ' + 'JOIN object_state ' + 'USING (tid) ' + "WHERE ((zoid = %(oid)s AND packed = FALSE)) " + 'ORDER BY tid DESC' + ) + + # But once we bind to the dialect we get the expected value + stmt = stmt.bind(Context) + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, state_size ' + 'FROM transaction ' + 'JOIN object_state ' + 'USING (tid) ' + "WHERE ((zoid = :oid AND packed = 'N')) " + 'ORDER BY tid DESC' + ) + + def test_boolean_result(self): + """ + SELECT tid, username, description, extension, + + FROM transaction + WHERE tid >= 0 + """ + + transaction = Schema.transaction + stmt = transaction.select( + transaction.c.tid, + transaction.c.username, + transaction.c.description, + transaction.c.extension, + transaction.c.packed + ).where( + transaction.c.tid >= 0 + ) + + stmt = stmt.bind(Context) + + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, ' + "CASE WHEN packed = 'Y' THEN 1 ELSE 0 END 
" + 'FROM transaction WHERE (tid >= :literal_0)' + ) + +class TestMoverQueries(TestCase): + + def _makeOne(self): + from ..mover import OracleObjectMover + from relstorage.tests import MockOptions + + return OracleObjectMover(Driver(), MockOptions()) + + def test_move_named_query(self): + + inst = self._makeOne() + + unbound = type(inst)._move_from_temp_hf_insert_query + + self.assertRegex( + str(unbound), + r'EXECUTE rs_prep_stmt_.*' + ) + + # Bound, the py-format style escapes are replaced with + # :name params; these are simple numbers. + self.assertEqual( + str(inst._move_from_temp_hf_insert_query), + 'INSERT INTO object_state(zoid, tid, state_size, state) ' + 'SELECT zoid, :1, COALESCE(LENGTH(state), 0), state ' + 'FROM temp_store ORDER BY zoid' + ) + +class TestDatabaseIteratorQueries(TestCase): + + def _makeOne(self): + from relstorage.adapters.dbiter import HistoryPreservingDatabaseIterator + return HistoryPreservingDatabaseIterator(Driver()) + + def test_query(self): + inst = self._makeOne() + + unbound = type(inst)._iter_objects_query + + self.assertEqual( + str(unbound), + 'SELECT zoid, state FROM object_state WHERE (tid = %(tid)s) ORDER BY zoid' + ) + + # Bound we get pyformat %(name)s replaced with :name + self.assertEqual( + str(inst._iter_objects_query), + 'SELECT zoid, state FROM object_state WHERE (tid = :tid) ORDER BY zoid' + ) + + cursor = MockCursor() + list(inst.iter_objects(cursor, 1)) # Iterator, must flatten + + self.assertEqual(1, len(cursor.executed)) + stmt, params = cursor.executed[0] + + self.assertEqual( + stmt, + 'SELECT zoid, state FROM object_state WHERE (tid = :tid) ORDER BY zoid' + ) + + self.assertEqual( + params, + {'tid': 1} + ) + + def test_iter_transactions(self): + inst = self._makeOne() + cursor = MockCursor() + + inst.iter_transactions(cursor) + + self.assertEqual(1, len(cursor.executed)) + stmt, params = cursor.executed[0] + + self.assertEqual( + stmt, + 'SELECT tid, username, description, extension ' + 'FROM transaction 
' + "WHERE ((packed = 'N' AND tid <> :literal_0)) " + 'ORDER BY tid DESC' + ) + + self.assertEqual( + params, + {'literal_0': 0} + ) diff --git a/src/relstorage/adapters/oracle/txncontrol.py b/src/relstorage/adapters/oracle/txncontrol.py index 4607a180..8095f3f7 100644 --- a/src/relstorage/adapters/oracle/txncontrol.py +++ b/src/relstorage/adapters/oracle/txncontrol.py @@ -25,8 +25,8 @@ class OracleTransactionControl(GenericTransactionControl): - def __init__(self, connmanager, keep_history, Binary, twophase): - GenericTransactionControl.__init__(self, connmanager, keep_history, Binary) + def __init__(self, connmanager, poller, keep_history, Binary, twophase): + GenericTransactionControl.__init__(self, connmanager, poller, keep_history, Binary) self.twophase = twophase def commit_phase1(self, conn, cursor, tid): diff --git a/src/relstorage/adapters/poller.py b/src/relstorage/adapters/poller.py index dc31c381..e59bf88c 100644 --- a/src/relstorage/adapters/poller.py +++ b/src/relstorage/adapters/poller.py @@ -17,12 +17,13 @@ from zope.interface import implementer -from ._util import formatted_query_property from .interfaces import IPoller from .interfaces import StaleConnectionError -log = logging.getLogger(__name__) +from .schema import Schema +from .sql import func +log = logging.getLogger(__name__) @implementer(IPoller) class Poller(object): @@ -31,44 +32,30 @@ class Poller(object): # The zoid is the primary key on both ``current_object`` (history # preserving) and ``object_state`` (history free), so these # queries are guaranteed to only produce an OID once. 
- _list_changes_range_queries = ( - """ - SELECT zoid, tid - FROM current_object - WHERE tid > %(min_tid)s - AND tid <= %(max_tid)s - """, - """ - SELECT zoid, tid - FROM object_state - WHERE tid > %(min_tid)s - AND tid <= %(max_tid)s - """ - ) - - _list_changes_range_query = formatted_query_property('_list_changes_range') - - _poll_inv_queries = ( - """ - SELECT zoid, tid - FROM current_object - WHERE tid > %(tid)s - """, - """ - SELECT zoid, tid - FROM object_state - WHERE tid > %(tid)s - """ - ) - - _poll_inv_query = formatted_query_property('_poll_inv') - - _poll_inv_exc_query = formatted_query_property('_poll_inv', - extension=' AND tid != %(self_tid)s') - - - def __init__(self, poll_query, keep_history, runner, revert_when_stale): - self.poll_query = poll_query + _list_changes_range_query = Schema.all_current_object.select( + Schema.all_current_object.c.zoid, Schema.all_current_object.c.tid + ).where( + Schema.all_current_object.c.tid > Schema.all_current_object.bindparam('min_tid') + ).and_( + Schema.all_current_object.c.tid <= Schema.all_current_object.bindparam('max_tid') + ).prepared() + + _poll_inv_query = Schema.all_current_object.select( + Schema.all_current_object.c.zoid, Schema.all_current_object.c.tid + ).where( + Schema.all_current_object.c.tid > Schema.all_current_object.bindparam('tid') + ).prepared() + + _poll_inv_exc_query = _poll_inv_query.and_( + Schema.all_current_object.c.tid != Schema.all_current_object.bindparam('self_tid') + ).prepared() + + poll_query = Schema.all_transaction.select( + func.max(Schema.all_transaction.c.tid) + ).prepared() + + def __init__(self, driver, keep_history, runner, revert_when_stale): + self.driver = driver self.keep_history = keep_history self.runner = runner self.revert_when_stale = revert_when_stale @@ -92,7 +79,7 @@ def poll_invalidations(self, conn, cursor, prev_polled_tid, ignore_tid): """ # pylint:disable=unused-argument # find out the tid of the most recent transaction. 
- cursor.execute(self.poll_query) + self.poll_query.execute(cursor) rows = cursor.fetchall() if not rows or not rows[0][0]: # No data, must be fresh database, without even @@ -148,8 +135,8 @@ def poll_invalidations(self, conn, cursor, prev_polled_tid, ignore_tid): # all the unreachable objects will be garbage collected # anyway. # - # Thus we became convinced it was safe to remove the check in history-preserving - # databases. + # Thus we became convinced it was safe to remove the check in + # history-preserving databases. # Get the list of changed OIDs and return it. stmt = self._poll_inv_query @@ -158,7 +145,7 @@ def poll_invalidations(self, conn, cursor, prev_polled_tid, ignore_tid): stmt = self._poll_inv_exc_query params['self_tid'] = ignore_tid - cursor.execute(stmt, params) + stmt.execute(cursor, params) changes = cursor.fetchall() return changes, new_polled_tid @@ -167,5 +154,5 @@ def list_changes(self, cursor, after_tid, last_tid): See ``IPoller``. """ params = {'min_tid': after_tid, 'max_tid': last_tid} - cursor.execute(self._list_changes_range_query, params) + self._list_changes_range_query.execute(cursor, params) return cursor.fetchall() diff --git a/src/relstorage/adapters/postgresql/adapter.py b/src/relstorage/adapters/postgresql/adapter.py index 337576f9..d2e1edc1 100644 --- a/src/relstorage/adapters/postgresql/adapter.py +++ b/src/relstorage/adapters/postgresql/adapter.py @@ -22,32 +22,44 @@ from ...options import Options from .._abstract_drivers import _select_driver -from .._util import query_property + from ..dbiter import HistoryFreeDatabaseIterator from ..dbiter import HistoryPreservingDatabaseIterator from ..interfaces import IRelStorageAdapter from ..packundo import HistoryFreePackUndo from ..packundo import HistoryPreservingPackUndo from ..poller import Poller +from ..schema import Schema from ..scriptrunner import ScriptRunner from . 
import drivers from .batch import PostgreSQLRowBatcher from .connmanager import Psycopg2ConnectionManager from .locker import PostgreSQLLocker -from .mover import PG8000ObjectMover from .mover import PostgreSQLObjectMover -from .mover import to_prepared_queries + from .oidallocator import PostgreSQLOIDAllocator from .schema import PostgreSQLSchemaInstaller from .stats import PostgreSQLStats from .txncontrol import PostgreSQLTransactionControl -from .txncontrol import PG8000TransactionControl + log = logging.getLogger(__name__) def select_driver(options=None): return _select_driver(options or Options(), drivers) +# TODO: Move to own file +class PGPoller(Poller): + + poll_query = Schema.all_transaction.select( + Schema.all_transaction.c.tid + ).order_by( + Schema.all_transaction.c.tid, dir='DESC' + ).limit( + 1 + ).prepared() + + @implementer(IRelStorageAdapter) class PostgreSQLAdapter(object): """PostgreSQL adapter for RelStorage.""" @@ -83,13 +95,7 @@ def __init__(self, dsn='', options=None): locker=self.locker, ) - mover_type = PostgreSQLObjectMover - txn_type = PostgreSQLTransactionControl - if driver.__name__ == 'pg8000': - mover_type = PG8000ObjectMover - txn_type = PG8000TransactionControl - - self.mover = mover_type( + self.mover = PostgreSQLObjectMover( driver, options=options, runner=self.runner, @@ -97,19 +103,21 @@ def __init__(self, dsn='', options=None): batcher_factory=PostgreSQLRowBatcher, ) self.oidallocator = PostgreSQLOIDAllocator() - self.txncontrol = txn_type( - connmanager=self.connmanager, - keep_history=self.keep_history, - driver=driver, - ) - self.poller = Poller( - poll_query="EXECUTE get_latest_tid", + self.poller = PGPoller( + self.driver, keep_history=self.keep_history, runner=self.runner, revert_when_stale=options.revert_when_stale, ) + self.txncontrol = PostgreSQLTransactionControl( + connmanager=self.connmanager, + poller=self.poller, + keep_history=self.keep_history, + Binary=driver.Binary, + ) + if self.keep_history: self.packundo = 
HistoryPreservingPackUndo( driver, @@ -120,7 +128,6 @@ def __init__(self, dsn='', options=None): ) self.dbiter = HistoryPreservingDatabaseIterator( driver, - runner=self.runner, ) else: self.packundo = HistoryFreePackUndo( @@ -134,7 +141,6 @@ def __init__(self, dsn='', options=None): self.packundo._lock_for_share = 'FOR KEY SHARE OF object_state' self.dbiter = HistoryFreeDatabaseIterator( driver, - runner=self.runner, ) self.stats = PostgreSQLStats( @@ -144,49 +150,11 @@ def __init__(self, dsn='', options=None): self.connmanager.add_on_store_opened(self.mover.on_store_opened) self.connmanager.add_on_load_opened(self.mover.on_load_opened) - self.connmanager.add_on_load_opened(self.__prepare_statements) self.connmanager.add_on_store_opened(self.__prepare_store_statements) - _get_latest_tid_queries = ( - """ - SELECT tid - FROM transaction - ORDER BY tid DESC - LIMIT 1 - """, - """ - SELECT tid - FROM object_state - ORDER BY tid DESC - LIMIT 1 - """ - ) - - _prepare_get_latest_tid_queries = to_prepared_queries( - 'get_latest_tid', - _get_latest_tid_queries) - - _prepare_get_latest_tid_query = query_property('_prepare_get_latest_tid') - - def __prepare_statements(self, cursor, restart=False): - if restart: - return - - # TODO: Generalize all of this better. There should be a - # registry of things to prepare, or we should wrap cursors to - # detect and prepare when needed. Preparation and switching to - # EXECUTE should be automatic for drivers that don't already do that. - - # A meta-class or base class __new__ could handle proper - # history/free query selection without this mass of tuples and - # manual properties and property names. 
- stmt = self._prepare_get_latest_tid_query - cursor.execute(stmt) - def __prepare_store_statements(self, cursor, restart=False): if not restart: - self.__prepare_statements(cursor, restart) try: stmt = self.txncontrol._prepare_add_transaction_query except (Unsupported, AttributeError): diff --git a/src/relstorage/adapters/postgresql/drivers/__init__.py b/src/relstorage/adapters/postgresql/drivers/__init__.py index 04c2e12b..1b506045 100644 --- a/src/relstorage/adapters/postgresql/drivers/__init__.py +++ b/src/relstorage/adapters/postgresql/drivers/__init__.py @@ -22,6 +22,16 @@ from __future__ import print_function from ..._abstract_drivers import implement_db_driver_options +from ..._abstract_drivers import AbstractModuleDriver +from ...sql import DefaultDialect + +class PostgreSQLDialect(DefaultDialect): + """ + The defaults are setup for PostgreSQL. + """ + +class AbstractPostgreSQLDriver(AbstractModuleDriver): + dialect = PostgreSQLDialect() database_type = 'postgresql' diff --git a/src/relstorage/adapters/postgresql/drivers/pg8000.py b/src/relstorage/adapters/postgresql/drivers/pg8000.py index 57c9bcf2..6a6d6a91 100644 --- a/src/relstorage/adapters/postgresql/drivers/pg8000.py +++ b/src/relstorage/adapters/postgresql/drivers/pg8000.py @@ -23,8 +23,11 @@ from zope.interface import implementer -from ..._abstract_drivers import AbstractModuleDriver from ...interfaces import IDBDriver +from ...sql import Compiler + +from . import AbstractPostgreSQLDriver +from . import PostgreSQLDialect __all__ = [ 'PG8000Driver', @@ -115,8 +118,24 @@ class _tuple_deque(deque): def append(self, row): # pylint:disable=arguments-differ deque.append(self, tuple(row)) +class PG8000Compiler(Compiler): + + def can_prepare(self): + # Important: pg8000 1.10 - 1.13, at least, can't handle prepared + # statements that take parameters but it doesn't need to because it + # prepares every statement anyway. So you must have a backup that you use + # for that driver. 
+ # https://github.com/mfenniak/pg8000/issues/132 + return False + +class PG8000Dialect(PostgreSQLDialect): + + def compiler_class(self): + return PG8000Compiler + + @implementer(IDBDriver) -class PG8000Driver(AbstractModuleDriver): +class PG8000Driver(AbstractPostgreSQLDriver): __name__ = 'pg8000' MODULE_NAME = __name__ PRIORITY = 3 @@ -125,6 +144,8 @@ class PG8000Driver(AbstractModuleDriver): _GEVENT_CAPABLE = True _GEVENT_NEEDS_SOCKET_PATCH = True + dialect = PG8000Dialect() + def __init__(self): super(PG8000Driver, self).__init__() # XXX: global side-effect! @@ -229,3 +250,5 @@ def connect_with_isolation(self, dsn, cursor.execute('SET SESSION CHARACTERISTICS AS ' + transaction_stmt) conn.commit() return conn, cursor + + sql_compiler_class = PG8000Compiler diff --git a/src/relstorage/adapters/postgresql/drivers/psycopg2.py b/src/relstorage/adapters/postgresql/drivers/psycopg2.py index 30e1dca2..b7776fd0 100644 --- a/src/relstorage/adapters/postgresql/drivers/psycopg2.py +++ b/src/relstorage/adapters/postgresql/drivers/psycopg2.py @@ -22,8 +22,8 @@ from zope.interface import implementer from relstorage._compat import PY3 -from ..._abstract_drivers import AbstractModuleDriver from ...interfaces import IDBDriver +from . 
import AbstractPostgreSQLDriver __all__ = [ @@ -32,7 +32,7 @@ @implementer(IDBDriver) -class Psycopg2Driver(AbstractModuleDriver): +class Psycopg2Driver(AbstractPostgreSQLDriver): __name__ = 'psycopg2' MODULE_NAME = __name__ diff --git a/src/relstorage/adapters/postgresql/mover.py b/src/relstorage/adapters/postgresql/mover.py index 275477dd..b3caf2ca 100644 --- a/src/relstorage/adapters/postgresql/mover.py +++ b/src/relstorage/adapters/postgresql/mover.py @@ -20,76 +20,17 @@ import struct from zope.interface import implementer -from ZODB.POSException import Unsupported -from .._util import query_property + from ..interfaces import IObjectMover from ..mover import AbstractObjectMover from ..mover import metricmethod_sampled -# Important: pg8000 1.10 - 1.13, at least, can't handle prepared -# statements that take parameters but it doesn't need to because it -# prepares every statement anyway. So you must have a backup that you use -# for that driver. -# https://github.com/mfenniak/pg8000/issues/132 - - -def to_prepared_queries(name, queries, datatypes=()): - # Give correct datatypes for the queries, wherever possible. - # The number of parameters should be the same or more than the - # number of datatypes. - # datatypes is a sequence of strings. - - # Maybe instead of having the adapter have to know about all the - # statements that need prepared, we could keep a registry? 
- if datatypes: - assert isinstance(datatypes, (list, tuple)) - datatypes = ', '.join(datatypes) - datatypes = ' (%s)' % (datatypes,) - else: - datatypes = '' - - result = [] - for q in queries: - if not isinstance(q, str): - # Unsupported marker - result.append(q) - continue - - q = q.strip() - param_count = q.count('%s') - rep_count = 0 - while rep_count < param_count: - rep_count += 1 - q = q.replace('%s', '$' + str(rep_count), 1) - stmt = 'PREPARE {name}{datatypes} AS {query}'.format( - name=name, datatypes=datatypes, query=q - ) - result.append(stmt) - return result - @implementer(IObjectMover) class PostgreSQLObjectMover(AbstractObjectMover): - _prepare_load_current_queries = to_prepared_queries( - 'load_current', - AbstractObjectMover._load_current_queries, - ['BIGINT']) - - _prepare_load_current_query = query_property('_prepare_load_current') - - _load_current_query = 'EXECUTE load_current(%s)' - - _prepare_detect_conflict_queries = to_prepared_queries( - 'detect_conflicts', - AbstractObjectMover._detect_conflict_queries) - - _prepare_detect_conflict_query = query_property('_prepare_detect_conflict') - - _detect_conflict_query = 'EXECUTE detect_conflicts' - - _move_from_temp_hf_insert_query_raw = AbstractObjectMover._move_from_temp_hf_insert_query + """ + _move_from_temp_hf_insert_query = AbstractObjectMover._move_from_temp_hf_insert_query + """ ON CONFLICT (zoid) DO UPDATE SET state_size = COALESCE(LENGTH(excluded.state), 0), @@ -107,36 +48,10 @@ class PostgreSQLObjectMover(AbstractObjectMover): """ _update_current_update_query = None - _move_from_temp_hf_insert_raw_queries = ( - Unsupported("States accumulate in history-preserving mode"), - _move_from_temp_hf_insert_query_raw, - ) - - _prepare_move_from_temp_hf_insert_queries = to_prepared_queries( - 'move_from_temp', - _move_from_temp_hf_insert_raw_queries, - ('BIGINT',) - ) - - _prepare_move_from_temp_hf_insert_query = query_property( - '_prepare_move_from_temp_hf_insert') - - 
_move_from_temp_hf_insert_queries = ( - Unsupported("States accumulate in history-preserving mode"), - 'EXECUTE move_from_temp(%s)' - ) - - _move_from_temp_hf_insert_query = query_property('_move_from_temp_hf_insert') # We upsert, no need _move_from_temp_hf_delete_query = '' - on_load_opened_statement_names = ('_prepare_load_current_query',) - on_store_opened_statement_names = on_load_opened_statement_names + ( - '_prepare_detect_conflict_query', - '_prepare_move_from_temp_hf_insert_query', - ) - @metricmethod_sampled def on_store_opened(self, cursor, restart=False): @@ -175,6 +90,9 @@ def on_store_opened(self, cursor, restart=False): """, ] + # XXX: we're not preparing statements anymore until just + # before we want to use them, so this is no longer needed. + # For some reason, preparing the INSERT statement also # wants to acquire a lock. If we're committing in another # transaction, this can block indefinitely (if that other @@ -193,7 +111,7 @@ def on_store_opened(self, cursor, restart=False): # this to 100, but under high concurrency (10 processes) # that turned out to be laughably optimistic. We might # actually need to go as high as the commit lock timeout. - cursor.execute('SET lock_timeout = 10000') + # cursor.execute('SET lock_timeout = 10000') for stmt in ddl_stmts: cursor.execute(stmt) @@ -334,19 +252,6 @@ def store_temps(self, cursor, state_oid_tid_iter): cursor.copy_expert(buf.COPY_COMMAND, buf) -class PG8000ObjectMover(PostgreSQLObjectMover): - # Delete the statements that need paramaters. 
- on_load_opened_statement_names = () - on_store_opened_statement_names = ('_prepare_detect_conflict_query',) - - _load_current_query = AbstractObjectMover._load_current_query - - _move_from_temp_hf_insert_queries = ( - Unsupported("States accumulate in history-preserving mode"), - PostgreSQLObjectMover._move_from_temp_hf_insert_query_raw - ) - - class TempStoreCopyBuffer(io.BufferedIOBase): """ A binary file-like object for putting data into diff --git a/src/relstorage/adapters/postgresql/tests/test_mover.py b/src/relstorage/adapters/postgresql/tests/test_mover.py index be930cc0..d2455f11 100644 --- a/src/relstorage/adapters/postgresql/tests/test_mover.py +++ b/src/relstorage/adapters/postgresql/tests/test_mover.py @@ -16,18 +16,18 @@ from __future__ import division from __future__ import print_function -from ZODB.POSException import Unsupported +import unittest from relstorage.tests import TestCase from relstorage.tests import MockOptions +from relstorage.tests import MockDriver from .. 
import mover -class MockDriver(object): - pass +@unittest.skip("Needs moved to test__sql") class TestFunctions(TestCase): - + # pylint:disable=no-member def _prepare1(self, query, name='prepped', datatypes=()): return mover.to_prepared_queries(name, [query], datatypes)[0] @@ -93,23 +93,8 @@ def _makeOne(self, **options): def test_prep_statements_hf(self): inst = self._makeOne(keep_history=False) - self.assertEqual( - inst._move_from_temp_hf_insert_query, - self._expected_move_from_temp_hf_insert_query + self.assertTrue( + str(inst._move_from_temp_hf_insert_query).startswith( + 'EXECUTE rs_prep_stmt' + ) ) - - def test_prep_statements_hp(self): - inst = self._makeOne(keep_history=True) - with self.assertRaises(Unsupported): - getattr(inst, '_move_from_temp_hf_insert_query') - - -class TestPG8000ObjectMover(TestPostgreSQLObjectMover): - - def setUp(self): - super(TestPG8000ObjectMover, self).setUp() - raw = mover.PostgreSQLObjectMover._move_from_temp_hf_insert_query_raw - self._expected_move_from_temp_hf_insert_query = raw - - def _getClass(self): - return mover.PG8000ObjectMover diff --git a/src/relstorage/adapters/postgresql/tests/test_txncontrol.py b/src/relstorage/adapters/postgresql/tests/test_txncontrol.py index f556efd5..0d32b565 100644 --- a/src/relstorage/adapters/postgresql/tests/test_txncontrol.py +++ b/src/relstorage/adapters/postgresql/tests/test_txncontrol.py @@ -18,22 +18,9 @@ from relstorage.adapters.tests import test_txncontrol -class MockDriver(object): - Binary = bytes class TestPostgreSQLTransactionControl(test_txncontrol.TestTransactionControl): - def setUp(self): - super(TestPostgreSQLTransactionControl, self).setUp() - - driver = MockDriver() - driver.Binary = super(TestPostgreSQLTransactionControl, self).Binary - self.Binary = driver - - def tearDown(self): - del self.Binary - super(TestPostgreSQLTransactionControl, self).tearDown() - def _getClass(self): from ..txncontrol import PostgreSQLTransactionControl return PostgreSQLTransactionControl 
diff --git a/src/relstorage/adapters/postgresql/txncontrol.py b/src/relstorage/adapters/postgresql/txncontrol.py index ed9a51b4..93211060 100644 --- a/src/relstorage/adapters/postgresql/txncontrol.py +++ b/src/relstorage/adapters/postgresql/txncontrol.py @@ -15,43 +15,11 @@ from __future__ import absolute_import -from ZODB.POSException import Unsupported from ..txncontrol import GenericTransactionControl -from .._util import query_property -from .mover import to_prepared_queries class _PostgreSQLTransactionControl(GenericTransactionControl): - - # See adapter.py for where this is prepared. - # Either history preserving or not, it's the same. - _get_tid_query = 'EXECUTE get_latest_tid' - - - def __init__(self, connmanager, keep_history, driver): - super(_PostgreSQLTransactionControl, self).__init__( - connmanager, - keep_history, - driver.Binary - ) + pass class PostgreSQLTransactionControl(_PostgreSQLTransactionControl): - - # (tid, packed, username, description, extension) - _add_transaction_query = 'EXECUTE add_transaction(%s, %s, %s, %s, %s)' - - _prepare_add_transaction_queries = to_prepared_queries( - 'add_transaction', - [ - GenericTransactionControl._add_transaction_query, - Unsupported("No transactions in HF mode"), - ], - ('BIGINT', 'BOOLEAN', 'BYTEA', 'BYTEA', 'BYTEA') - ) - - _prepare_add_transaction_query = query_property('_prepare_add_transaction') - - -class PG8000TransactionControl(_PostgreSQLTransactionControl): - # We can't handle the parameterized prepared statements. 
pass diff --git a/src/relstorage/adapters/schema.py b/src/relstorage/adapters/schema.py index 5c1a89f2..3fca0f06 100644 --- a/src/relstorage/adapters/schema.py +++ b/src/relstorage/adapters/schema.py @@ -25,12 +25,77 @@ from ._util import query_property from ._util import noop_when_history_free +from .sql import Table +from .sql import TemporaryTable +from .sql import Column +from .sql import HistoryVariantTable +from .sql import OID +from .sql import TID +from .sql import State +from .sql import Boolean +from .sql import BinaryString + + log = logging.getLogger("relstorage") tmpl_property = partial(query_property, property_suffix='_TMPLS', lazy_suffix='_TMPL') +class Schema(object): + current_object = Table( + 'current_object', + Column('zoid', OID), + Column('tid', TID) + ) + + object_state = Table( + 'object_state', + Column('zoid', OID), + Column('tid', TID), + Column('state', State), + Column('state_size'), + ) + + # Does the right thing whether history free or preserving + all_current_object = HistoryVariantTable( + current_object, + object_state, + ) + + # Does the right thing whether history free or preserving + all_current_object_state = HistoryVariantTable( + current_object.natural_join(object_state), + object_state + ) + + temp_store = TemporaryTable( + 'temp_store', + Column('zoid', OID), + Column('prev_tid', TID), + Column('md5'), + Column('state', State) + ) + + transaction = Table( + 'transaction', + Column('tid', TID), + Column('packed', Boolean), + Column('username', BinaryString), + Column('description', BinaryString), + Column('extension', BinaryString), + ) + + commit_row_lock = Table( + 'commit_row_lock', + Column('tid'), + ) + + all_transaction = HistoryVariantTable( + transaction, + object_state, + ) + class AbstractSchemaInstaller(DatabaseHelpersMixin, ABC): diff --git a/src/relstorage/adapters/sql/__init__.py b/src/relstorage/adapters/sql/__init__.py new file mode 100644 index 00000000..f2d94540 --- /dev/null +++ 
b/src/relstorage/adapters/sql/__init__.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +A small abstraction layer for SQL queries. + +Features: + + - Simple, readable syntax for writing queries. + + - Automatic switching between history-free and history-preserving + schemas. + + - Support for automatically preparing statements (depending on the + driver; some allow parameters, some do not, for example) + + - Always use bind parameters + + - Take care of minor database syntax issues. + +This is inspired by the SQLAlchemy Core, but we don't use it because +we don't want to take a dependency that can conflict with applications +using RelStorage. 
class Columns(object):
    """
    Grab bag of columns.

    Every column is exposed as an instance attribute named after
    ``column.name``; the columns are also kept, in the order given,
    for positional access and for compilation as a comma-separated
    list.
    """

    def __init__(self, columns):
        """
        :param columns: Iterable of objects that have a ``name`` attribute.
        """
        cs = []
        for c in columns:
            setattr(self, c.name, c)
            cs.append(c)
        self._columns = tuple(cs)

    def __bool__(self):
        return bool(self._columns)

    __nonzero__ = __bool__

    def __getattr__(self, name):
        # Here only so that pylint knows this class has a set of
        # dynamic attributes. It is only invoked when normal lookup
        # fails, so it always raises.
        raise AttributeError("Column list %s does not include %s" % (
            self._col_list(),
            name
        ))

    def __getitem__(self, ix):
        return self._columns[ix]

    def _col_list(self):
        # Read ``_columns`` straight from the instance dictionary
        # instead of ``self._columns``: this method is called from
        # ``__getattr__``, and on an instance that has not been
        # initialised yet (for example the blank object that
        # ``copy.copy`` creates before restoring state) attribute
        # access would re-enter ``__getattr__`` and recurse forever.
        columns = vars(self).get('_columns', ())
        return ','.join(str(c) for c in columns)

    def __compile_visit__(self, compiler):
        compiler.visit_csv(self._columns)

    def has_bind_param(self):
        """Return whether any column provides ``IBindParam``."""
        return any(
            IBindParam.providedBy(c)
            for c in self._columns
        )

    def as_select_list(self):
        """Return the same columns wrapped as a ``_SelectColumns``."""
        # Imported lazily, at call time rather than module import time.
        from .select import _SelectColumns
        return _SelectColumns(self._columns)
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ._util import Resolvable + + +class LiteralNode(Resolvable): + + __slots__ = ( + 'raw', + 'name', + ) + + def __init__(self, raw): + self.raw = raw + self.name = 'anon_%x' % (id(self),) + + def __compile_visit__(self, compiler): + compiler.emit(str(self.raw)) + + def resolve_against(self, table): + return self + +class BooleanNode(LiteralNode): + + __slots__ = () + +class TextNode(LiteralNode): + __slots__ = () + +def as_node(c): + if isinstance(c, bool): + return BooleanNode(c) + if isinstance(c, int): + return LiteralNode(c) + if isinstance(c, str): + return TextNode(c) + return c + + + +def resolved_against(columns, table): + resolved = [ + as_node(c).resolve_against(table) + for c + in columns + ] + return resolved diff --git a/src/relstorage/adapters/sql/dialect.py b/src/relstorage/adapters/sql/dialect.py new file mode 100644 index 00000000..ed84efbf --- /dev/null +++ b/src/relstorage/adapters/sql/dialect.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- +""" +RDBMS-specific SQL. 
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from operator import attrgetter + +from zope.interface import implementer + +from relstorage._compat import NStringIO +from relstorage._compat import intern +from ..interfaces import IDBDialect + +from .types import OID +from .types import TID +from .types import BinaryString +from .types import State +from .types import Boolean + +from ._util import copy +from .interfaces import ITypedParams + +# pylint:disable=too-many-function-args + +@implementer(IDBDialect) +class DefaultDialect(object): + + keep_history = True + _context_repr = None + + datatype_map = { + OID: 'BIGINT', + TID: 'BIGINT', + BinaryString: 'BYTEA', + State: 'BYTEA', + Boolean: 'BOOLEAN', + } + + def bind(self, context): + # The context will reference us most likely + # (compiled statement in instance dictionary) + # so try to avoid reference cycles. + keep_history = context.keep_history + new = copy(self) + new.keep_history = keep_history + new._context_repr = repr(context) + return new + + def compiler_class(self): + return Compiler + + def compiler(self, root): + return self.compiler_class()(root) + + def datatypes_for_columns(self, column_list): + columns = list(column_list) + datatypes = [] + for column in columns: + datatype = self.datatype_map[type(column.type_)] + datatypes.append(datatype) + return datatypes + + def __eq__(self, other): + if isinstance(other, DefaultDialect): + return other.keep_history == self.keep_history + return NotImplemented # pragma: no cover + + def __repr__(self): + return "<%s at %x keep_history=%s context=%s>" % ( + type(self).__name__, + id(self), + self.keep_history, + self._context_repr + ) + + +class _MissingDialect(DefaultDialect): + def __bool__(self): + return False + + __nonzero__ = __bool__ + + +class Compiler(object): + + def __init__(self, root): + self.buf = NStringIO() + self.placeholders = {} + self.root = root + + def 
__repr__(self): + return "<%s %s %r>" % ( + type(self).__name__, + self.buf.getvalue(), + self.placeholders + ) + + def compile(self): + self.visit(self.root) + return self.finalize() + + def can_prepare(self): + # Obviously this needs to be more general. + # Some drivers, for example, can't deal with parameters + # in a prepared statement; we currently handle that by overriding + # this method. + return self.root.prepare + + _prepared_stmt_counter = 0 + + @classmethod + def _next_prepared_stmt_name(cls, query): + # Even with the GIL, this isn't fully safe to do; two threads + # can still get the same value. We don't want to allocate a + # lock because we might be patched by gevent later. So that's + # where `query` comes in: we add the hash as a disambiguator. + # Of course, for there to be a duplicate prepared statement + # sent to the database, that would mean that we were somehow + # using the same cursor or connection in multiple threads at + # once (or perhaps we got more than one cursor from a + # connection? We should only have one.) + # + # TODO: Sidestep this problem by allocating this earlier; + # the SELECT or INSERT statement could pick it when it is created; + # that happens at the class level at import time, when we should be + # single-threaded. + # + # That may also help facilitate caching. + cls._prepared_stmt_counter += 1 + return 'rs_prep_stmt_%d_%d' % ( + cls._prepared_stmt_counter, + abs(hash(query)), + ) + + def _prepared_param(self, number): + return '$' + str(number) + + _PREPARED_CONJUNCTION = 'AS' + + def _quote_query_for_prepare(self, query): + return query + + def _find_datatypes_for_prepared_query(self): + # Deduce the datatypes based on the types of the columns + # we're sending as params. 
+ result = () + param_provider = ITypedParams(self.root, None) + if param_provider is not None: + result = param_provider.datatypes_for_parameters() # pylint:disable=assignment-from-no-return + return result + + def prepare(self): + # This is correct for PostgreSQL. This needs moved to a dialect specific + # spot. + + datatypes = self._find_datatypes_for_prepared_query() + query = self.buf.getvalue() + name = self._next_prepared_stmt_name(query) + + if datatypes: + assert isinstance(datatypes, (list, tuple)) + datatypes = ', '.join(datatypes) + datatypes = ' (%s)' % (datatypes,) + else: + datatypes = '' + + q = query.strip() + + # PREPARE needs the query string to use $1, $2, $3, etc, + # as placeholders. + # In MySQL, it's a plain question mark. + placeholder_to_number = {} + counter = 0 + for placeholder_name in self.placeholders.values(): + counter += 1 + placeholder = self._placeholder(placeholder_name) + placeholder_to_number[placeholder_name] = counter + param = self._prepared_param(counter) + q = q.replace(placeholder, param, 1) + + q = self._quote_query_for_prepare(q) + + stmt = 'PREPARE {name}{datatypes} {conjunction} {query}'.format( + name=name, datatypes=datatypes, + query=q, + conjunction=self._PREPARED_CONJUNCTION, + ) + + + if placeholder_to_number: + execute = 'EXECUTE {name}({params})'.format( + name=name, + params=','.join(['%s'] * len(self.placeholders)), + ) + else: + # Neither MySQL nor PostgreSQL like a set of empty parens: () + execute = 'EXECUTE {name}'.format(name=name) + + if '%s' in placeholder_to_number: + # There was an ordered param. If there was one, + # they must all be ordered, so there's no need to convert anything. + assert len(placeholder_to_number) == 1 + def convert(p): + return p + else: + def convert(d): + # TODO: This may not actually be needed, since we issue a regular + # cursor.execute(), it may be able to handle named? 
def emit_identifier(self, identifier):
    """
    Write *identifier* to the buffer, separated from what precedes it.

    A single space is written first unless the buffer is empty or
    already ends with an opening parenthesis or a space.

    :param identifier: The text to emit.
    """
    written = self.buf.getvalue()
    # An empty buffer has no last character to inspect; indexing it
    # with ``[-1]`` would raise IndexError, so test emptiness first.
    if written and not written.endswith(('(', ' ')):
        self.emit(' ', identifier)
    else:
        self.emit(identifier)
def _placeholder_for_literal_param_value(self, value):
    """
    Return the formatted placeholder to emit for the literal *value*.

    The first time a value is seen, a new ``literal_N`` name is
    allocated and recorded in ``self.placeholders``; later requests
    for the same value reuse that name.

    :param value: The literal; used as a key in ``self.placeholders``.
    :return: The placeholder text produced by ``self._placeholder``.
    """
    # ``self.placeholders`` maps value -> placeholder *name*. The name
    # must always go through ``_placeholder`` before being returned;
    # handing back the stored name would emit e.g. ``literal_0``
    # instead of ``%(literal_0)s`` on a repeated literal.
    placeholder_name = self.placeholders.get(value)
    if not placeholder_name:
        placeholder_name = self._next_placeholder_name(prefix='literal')
        self.placeholders[value] = placeholder_name
    return self._placeholder(placeholder_name)
+ if isinstance(context, DefaultDialect): + return context + + for getter in self._dialect_locations: + try: + dialect = getter(context) + except AttributeError: + pass + else: + return dialect.bind(context) + __traceback_info__ = getattr(context, '__dict__', ()) # vars() doesn't work on e.g., None + raise TypeError("Unable to bind to %s; no dialect found" % (context,)) + + def bind(self, context, dialect=None): + if dialect is None: + dialect = self._find_dialect(context) + + assert dialect is not None + + new = copy(self) + if context is not None: + new.context = context + new.dialect = dialect + bound_replacements = { + k: v.bind(context, dialect) + for k, v + in vars(new).items() + if isinstance(v, DialectAware) + } + for k, v in bound_replacements.items(): + setattr(new, k, v) + return new diff --git a/src/relstorage/adapters/sql/expressions.py b/src/relstorage/adapters/sql/expressions.py new file mode 100644 index 00000000..3c4eafcd --- /dev/null +++ b/src/relstorage/adapters/sql/expressions.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +""" +Expressions in the AST. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from zope.interface import implementer + +from ._util import Resolvable +from ._util import copy + +from .dialect import DialectAware +from .interfaces import INamedBindParam +from .interfaces import IOrderedBindParam + +class Expression(DialectAware, + Resolvable): + """ + A SQL expression. 
+ """ + + __slots__ = () + +@implementer(INamedBindParam) +class BindParam(Expression): + + __slots__ = ( + 'key', + ) + + def __init__(self, key): + self.key = key + + def __compile_visit__(self, compiler): + compiler.visit_bind_param(self) + +@implementer(INamedBindParam) +def bindparam(key): + return BindParam(key) + + +class LiteralExpression(Expression): + + __slots__ = ( + 'value', + ) + + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __compile_visit__(self, compiler): + compiler.visit_literal_expression(self.value) + +class BooleanLiteralExpression(LiteralExpression): + + __slots__ = () + + def __compile_visit__(self, compiler): + compiler.visit_boolean_literal_expression(self.value) + +@implementer(IOrderedBindParam) +class OrderedBindParam(Expression): + + __slots__ = () + + name = '%s' + + def __compile_visit__(self, compiler): + compiler.visit_ordered_bind_param(self) + +@implementer(IOrderedBindParam) +def orderedbindparam(): + return OrderedBindParam() + + +def as_expression(stmt): + if hasattr(stmt, '__compile_visit__'): + return stmt + + if isinstance(stmt, bool): + # The values True and False are handled + # specially because their representation varies + # among databases (Oracle) + stmt = BooleanLiteralExpression(stmt) + else: + stmt = LiteralExpression(stmt) + return stmt + +class BinaryExpression(Expression): + """ + Expresses a comparison. + """ + + __slots__ = ( + 'op', + 'lhs', + 'rhs', + ) + + def __init__(self, op, lhs, rhs): + self.op = op + self.lhs = lhs # type: Column + # rhs is either a literal or a column; + # certain literals are handled specially. 
def resolve_against(self, table):
    """
    Return a copy of this conjunction with both sides resolved.

    Each operand's own ``resolve_against`` is called with *table*;
    the results are stored on the copy and ``self`` is left untouched.

    :param table: The table passed through to both operands.
    :return: A new object; never ``self``.
    """
    resolved = copy(self)
    # Assign to the copy, not to ``self``: writing to ``self`` would
    # mutate the (possibly shared) original and return an unresolved
    # copy.
    resolved.lhs = self.lhs.resolve_against(table)
    resolved.rhs = self.rhs.resolve_against(table)
    return resolved
bindparam(self, key): + return bindparam(key) + + def orderedbindparam(self): + return orderedbindparam() + +class ExpressionOperatorMixin(object): + def __eq__(self, other): + return EqualExpression(self, other) + + def __gt__(self, other): + return GreaterExpression(self, other) + + def __ge__(self, other): + return GreaterEqualExpression(self, other) + + def __ne__(self, other): + return NotEqualExpression(self, other) + + def __le__(self, other): + return LessEqualExpression(self, other) diff --git a/src/relstorage/adapters/sql/functions.py b/src/relstorage/adapters/sql/functions.py new file mode 100644 index 00000000..96333097 --- /dev/null +++ b/src/relstorage/adapters/sql/functions.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +Function expressions. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from .expressions import Expression + +class _Functions(object): + + def max(self, column): + return _Function('max', column) + +class _Function(Expression): + + def __init__(self, name, expression): + self.name = name + self.expression = expression + + def __compile_visit__(self, compiler): + compiler.emit_identifier(self.name) + compiler.visit_grouped(self.expression) + +func = _Functions() diff --git a/src/relstorage/adapters/sql/insert.py b/src/relstorage/adapters/sql/insert.py new file mode 100644 index 00000000..8542c3c7 --- /dev/null +++ b/src/relstorage/adapters/sql/insert.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +""" +The ``INSERT`` statement. 
def datatypes_for_parameters(self):
    """
    Return the datatypes of the parameters this statement accepts.

    :return: Whatever ``dialect.datatypes_for_columns`` returns for
        the parameterised columns, or an empty tuple when there is
        neither a values list with columns nor a select whose column
        list has bind parameters.
    """
    dialect = self.dialect
    # Default for statements that match neither case below; without
    # it the final ``return`` would raise UnboundLocalError.
    datatypes = ()
    if self.values and self.column_list:
        # If we're sending in a list of values, those have to
        # exactly match the columns, so we can easily get a list
        # of datatypes.
        column_list = self.column_list
        datatypes = dialect.datatypes_for_columns(column_list)
    elif self.select and self.select.column_list.has_bind_param():
        targets = self.column_list
        sources = self.select.column_list
        # TODO: This doesn't support bind params anywhere except the
        # select list!
        # TODO: This doesn't support named bind params.
        columns_with_params = [
            target
            for target, source in zip(targets, sources)
            if IOrderedBindParam.providedBy(source)
        ]
        datatypes = dialect.datatypes_for_columns(columns_with_params)
    return datatypes
+ """ diff --git a/src/relstorage/adapters/sql/query.py b/src/relstorage/adapters/sql/query.py new file mode 100644 index 00000000..abd45344 --- /dev/null +++ b/src/relstorage/adapters/sql/query.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +""" +Compiled queries ready for execution. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from weakref import WeakKeyDictionary + +from relstorage._util import CachedIn + +from ._util import copy +from ._util import Columns +from .dialect import DialectAware + +from .expressions import ParamMixin + + +class Clause(DialectAware): + """ + A portion of a SQL statement. + """ + + + +class ColumnList(Columns): + """ + List of columns used in a query. + """ + + # This class exists for semantics, it currently doesn't + # do anything different than the super. + + +class Query(ParamMixin, + Clause): + __name__ = None + prepare = False + + def __str__(self): + return str(self.compiled()) + + def __get__(self, inst, klass): + if inst is None: + return self + # We need to set this into the instance's dictionary. + # Otherwise we'll be rebinding and recompiling each time we're + # accessed which is not good (in fact it's a step backwards + # from query_property()). On Python 3.6+, there's the + # `__set_name__(klass, name)` called on the descriptor which + # does the job perfectly. In earlier versions, we're on our + # own. + # + # TODO: In test cases, we spend a lot of time binding and compiling. + # Can we find another layer of caching somewhere? + result = self.bind(inst).compiled() + if not self.__name__: + # Go through the class hierarchy, find out what we're called. 
+ for base in klass.mro(): + for k, v in vars(base).items(): + if v is self: + self.__name__ = k + break + assert self.__name__ + + vars(inst)[self.__name__] = result + + return result + + def __set_name__(self, owner, name): + self.__name__ = name + + @CachedIn('_v_compiled') + def compiled(self): + return CompiledQuery(self) + + def prepared(self): + """ + Note that it's good to prepare this query, if + supported by the driver. + """ + s = copy(self) + s.prepare = True + return s + + +class CompiledQuery(object): + """ + Represents a completed query. + """ + + stmt = None + params = None + _raw_stmt = None + _prepare_stmt = None + _prepare_converter = None + + def __init__(self, root): + self.root = root + # We do not keep a reference to the context; + # it's likely to be an instance object that's + # going to have us stored in its dictionary. + dialect = root.dialect + + compiler = dialect.compiler(root) + self.stmt, self.params = compiler.compile() + self._raw_stmt = self.stmt # for debugging + if compiler.can_prepare(): + self._prepare_stmt, self.stmt, self._prepare_converter = compiler.prepare() + + def __repr__(self): + if self._prepare_stmt: + return "%s (%s)" % ( + self.stmt, + self._prepare_stmt + ) + return self.stmt + + def __str__(self): + return self.stmt + + _cursor_cache = WeakKeyDictionary() + + def _stmt_cache_for_cursor(self, cursor): + """Returns a dictionary.""" + # If we can't store it directly on the cursor, as happens for + # types implemented in C, we use a weakkey dictionary. + try: + cursor_prep_stmts = cursor._rs_prepared_statements + except AttributeError: + try: + cursor_prep_stmts = cursor._rs_prepared_statements = {} + except AttributeError: + cursor_prep_stmts = self._cursor_cache.get(cursor) + if cursor_prep_stmts is None: + cursor_prep_stmts = self._cursor_cache[cursor] = {} + return cursor_prep_stmts + + def execute(self, cursor, params=None): + # (Any, dict) -> None + # TODO: Include literals from self.params. 
+ # TODO: Syntax transformation if they don't support names. + # TODO: Validate given params match expected ones, nothing missing? + stmt = self.stmt + if self._prepare_stmt: + # Prepare on demand. + + # In all databases, prepared statements + # persist past COMMIT/ROLLBACK (though in PostgreSQL + # preparing them takes locks that only go away at + # COMMIT/ROLLBACK). But they don't persist past a session + # restart (new connection) (obviously). + # + # Thus we keep a cache of statements we have prepared for + # this particular connection/cursor. + # + cursor_prep_stmts = self._stmt_cache_for_cursor(cursor) + try: + stmt = cursor_prep_stmts[self._prepare_stmt] + except KeyError: + stmt = cursor_prep_stmts[self._prepare_stmt] = self.stmt + __traceback_info__ = self._prepare_stmt, self, self.root.dialect.compiler(self.root) + cursor.execute(self._prepare_stmt) + params = self._prepare_converter(params) + + __traceback_info__ = stmt, params + if params: + cursor.execute(stmt, params) + elif self.params: + # XXX: This isn't really good. + # If there are both literals in the SQL and params, + # we don't handle that. + cursor.execute(stmt, self.params) + else: + cursor.execute(stmt) diff --git a/src/relstorage/adapters/sql/schema.py b/src/relstorage/adapters/sql/schema.py new file mode 100644 index 00000000..00993851 --- /dev/null +++ b/src/relstorage/adapters/sql/schema.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +""" +Concepts that exist in SQL, such as columns and tables. 
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from ._util import Resolvable +from ._util import Columns + +from .types import Type +from .types import Unknown + +from .expressions import ExpressionOperatorMixin +from .expressions import ParamMixin + +from .dialect import DialectAware + +from .select import Selectable +from .insert import Insertable + +class Column(ExpressionOperatorMixin, + Resolvable): + """ + Defines a column in a table. + """ + + def __init__(self, name, type_=Unknown, primary_key=False, nullable=True): + self.name = name + self.type_ = type_ if isinstance(type_, Type) else type_() + self.primary_key = primary_key + self.nullable = False if primary_key else nullable + + def is_type(self, type_): + return isinstance(self.type_, type_) + + def __str__(self): + return self.name + + + def __compile_visit__(self, compiler): + compiler.visit_column(self) + + +class _DeferredColumn(Column): + + def resolve_against(self, table): + return getattr(table.c, self.name) + +class _DeferredColumns(object): + + def __getattr__(self, name): + return _DeferredColumn(name) + +class ColumnResolvingProxy(ParamMixin): + """ + A proxy that select can resolve to tables in the current table. + """ + + c = _DeferredColumns() + + +class SchemaItem(object): + """ + A permanent item in a schema (aka the data dictionary). + """ + + +class Table(Selectable, + Insertable, + ParamMixin, + SchemaItem): + """ + A table relation. + """ + + def __init__(self, name, *columns): + self.name = name + self.columns = columns + self.c = Columns(columns) + + def __str__(self): + return self.name + + def __compile_visit__(self, compiler): + compiler.emit_identifier(self.name) + + def natural_join(self, other_table): + return NaturalJoinedTable(self, other_table) + + + +class TemporaryTable(Table): + """ + A temporary table. 
+ """ + + +class _CompositeTableMixin(Selectable, + ParamMixin): + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + common_columns = [] + for col in lhs.columns: + if hasattr(rhs.c, col.name): + common_columns.append(col) + self.columns = common_columns + self.c = Columns(common_columns) + + +class NaturalJoinedTable(DialectAware, + _CompositeTableMixin): + + def __init__(self, lhs, rhs): + super(NaturalJoinedTable, self).__init__(lhs, rhs) + assert self.c + self._join_columns = self.c + + # TODO: Check for data type mismatches, etc? + self.columns = list(self.lhs.columns) + for col in self.rhs.columns: + if not hasattr(self.lhs.c, col.name): + self.columns.append(col) + self.columns = tuple(self.columns) + self.c = Columns(self.columns) + + def __compile_visit__(self, compiler): + compiler.visit(self.lhs) + compiler.emit_keyword('JOIN') + compiler.visit(self.rhs) + # careful with USING clause in a join: Oracle doesn't allow such + # columns to have a prefix. + compiler.emit_keyword('USING') + compiler.visit_grouped(self._join_columns) + + +class HistoryVariantTable(DialectAware, + _CompositeTableMixin): + """ + A table that can be one of two tables, depending on whether + the instance is keeping history or not. + """ + + @property + def history_preserving(self): + return self.lhs + + @property + def history_free(self): + return self.rhs + + def __compile_visit__(self, compiler): + keep_history = self.context.keep_history + node = self.history_preserving if keep_history else self.history_free + return compiler.visit(node) diff --git a/src/relstorage/adapters/sql/select.py b/src/relstorage/adapters/sql/select.py new file mode 100644 index 00000000..558d590c --- /dev/null +++ b/src/relstorage/adapters/sql/select.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +Elements of select queries. 
+ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .query import Query +from .query import Clause +from .query import ColumnList +from ._util import copy + +from .ast import TextNode +from .ast import resolved_against + +from .expressions import And +from .expressions import EmptyExpression + +class WhereClause(Clause): + + def __init__(self, expression): + self.expression = expression + + def and_(self, expression): + expression = And(self.expression, expression) + new = copy(self) + new.expression = expression + return new + + def __compile_visit__(self, compiler): + compiler.emit_keyword(' WHERE') + compiler.visit_grouped(self.expression) + +class OrderBy(Clause): + + def __init__(self, expression, dir): + self.expression = expression + self.dir = dir + + def __compile_visit__(self, compiler): + compiler.emit(' ORDER BY ') + compiler.visit(self.expression) + if self.dir: + compiler.emit(' ' + self.dir) + + +def _where(expression): + if expression: + return WhereClause(expression) + + +class _SelectColumns(ColumnList): + + def __compile_visit__(self, compiler): + compiler.visit_select_list_csv(self._columns) + + def as_select_list(self): + return self + + +class Select(Query): + """ + A Select query. + + When instances of this class are stored in a class dictionary, + they function as non-data descriptors: The first time they are + accessed, they *bind* themselves to the instance and select the + appropriate SQL syntax and compile themselves into a string. 
+ """ + + _distinct = EmptyExpression() + _where = EmptyExpression() + _order_by = EmptyExpression() + _limit = None + _for_update = None + _nowait = None + + def __init__(self, table, *columns): + self.table = table + if columns: + self.column_list = _SelectColumns(resolved_against(columns, table)) + else: + self.column_list = table + + def where(self, expression): + expression = expression.resolve_against(self.table) + s = copy(self) + s._where = _where(expression) + return s + + def and_(self, expression): + expression = expression.resolve_against(self.table) + s = copy(self) + s._where = self._where.and_(expression) + return s + + def order_by(self, expression, dir=None): + expression = expression.resolve_against(self.table) + s = copy(self) + s._order_by = OrderBy(expression, dir) + return s + + def limit(self, literal): + s = copy(self) + s._limit = literal + return s + + def for_update(self): + s = copy(self) + s._for_update = 'FOR UPDATE' + return s + + def nowait(self): + s = copy(self) + s._nowait = 'NOWAIT' + return s + + def distinct(self): + s = copy(self) + s._distinct = TextNode('DISTINCT') + return s + + def __compile_visit__(self, compiler): + compiler.emit_keyword('SELECT') + compiler.visit(self._distinct) + compiler.visit_select_list(self.column_list) + compiler.visit_from(self.table) + compiler.visit_clause(self._where) + compiler.visit_clause(self._order_by) + if self._limit: + compiler.emit_keyword('LIMIT') + compiler.emit(str(self._limit)) + if self._for_update: + compiler.emit_keyword(self._for_update) + if self._nowait: + compiler.emit_keyword(self._nowait) + + +class Selectable(object): + """ + Mixin for something that can form the root of a select query. 
+ """ + + def select(self, *args, **kwargs): + return Select(self, *args, **kwargs) diff --git a/src/relstorage/adapters/sql/tests/__init__.py b/src/relstorage/adapters/sql/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/relstorage/adapters/sql/tests/test_ast.py b/src/relstorage/adapters/sql/tests/test_ast.py new file mode 100644 index 00000000..f78735e9 --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_ast.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +Tests for AST. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .. import ast + + +class TestFuncs(TestCase): + + def test_as_node_boolean(self): + node = ast.as_node(True) + self.assertIsInstance(node, ast.BooleanNode) + self.assertIs(node.raw, True) diff --git a/src/relstorage/adapters/sql/tests/test_dialect.py b/src/relstorage/adapters/sql/tests/test_dialect.py new file mode 100644 index 00000000..4197accd --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_dialect.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. 
+# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +Tests for dialects. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .. import dialect + +class TestMissingDialect(TestCase): + + def test_boolean(self): + d = dialect._MissingDialect() + self.assertFalse(d) + + +class TestCompiler(TestCase): + + def test_prepare_no_datatypes(self): + + class C(dialect.Compiler): + def _next_prepared_stmt_name(self, query): + return 'my_stmt' + + def _find_datatypes_for_prepared_query(self): + return () + + compiler = C(None) + + stmt, execute, convert = compiler.prepare() + + self.assertEqual( + stmt, + 'PREPARE my_stmt AS ' + ) + + self.assertEqual( + execute, + 'EXECUTE my_stmt' + ) + + # We get the default dictionary converter even if we + # don't need it. 
+ + self.assertEqual( + [], + convert({'a': 42}) + ) + + def test_prepare_named_datatypes(self): + + compiler = dialect.Compiler(None) + compiler.placeholders[object()] = 'name' + + _s, _x, convert = compiler.prepare() + + self.assertEqual( + [64], + convert({'name': 64}) + ) + + +class TestDialectAware(TestCase): + + def test_bind_none(self): + + aware = dialect.DialectAware() + + with self.assertRaisesRegex(TypeError, 'no dialect found'): + aware.bind(None) + + def test_bind_dialect(self): + class Dialect(dialect.DefaultDialect): + def bind(self, context): + raise AssertionError("Not supposed to re-bind") + + d = Dialect() + + aware = dialect.DialectAware() + new_aware = aware.bind(d) + self.assertIsNot(aware, new_aware) + self.assertIs(new_aware.dialect, d) diff --git a/src/relstorage/adapters/sql/tests/test_expressions.py b/src/relstorage/adapters/sql/tests/test_expressions.py new file mode 100644 index 00000000..37b3c7e8 --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_expressions.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +Tests for expressions. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .. 
import expressions + +class TestBinaryExpression(TestCase): + + def test_str(self): + + exp = expressions.BinaryExpression('=', 'lhs', 'rhs') + self.assertEqual( + str(exp), + 'lhs = rhs' + ) + + +class TestEmptyExpression(TestCase): + + exp = expressions.EmptyExpression() + + def test_boolean(self): + self.assertFalse(self.exp) + + def test_str(self): + self.assertEqual(str(self.exp), '') + + def test_and(self): + self.assertIs(self.exp.and_(self), self) + + +class TestAnd(TestCase): + + def test_resolve(self): + # This will wrap them in literal nodes, which + # do nothing when resolved. + + exp = expressions.And('a', 'b') + + resolved = exp.resolve_against(None) + + self.assertIsInstance(resolved, expressions.And) + self.assertIs(resolved.lhs, exp.lhs) + self.assertIs(resolved.rhs, exp.rhs) diff --git a/src/relstorage/adapters/sql/tests/test_query.py b/src/relstorage/adapters/sql/tests/test_query.py new file mode 100644 index 00000000..df85bafc --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_query.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. 
+# +############################################################################## + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from ..query import Query as _BaseQuery +from ..query import CompiledQuery + +class MockDialect(object): + + def bind(self, context): # pylint:disable=unused-argument + return self + +class Query(_BaseQuery): + + def compiled(self): + return self + +class TestQuery(TestCase): + + def test_name_discovery(self): + # If a __name__ isn't assigned when a query is a + # class property and used as a non-data-descriptor, + # it finds it. + + class C(object): + dialect = MockDialect() + + q1 = Query() + q2 = Query() + q_over = Query() + + class D(C): + + q3 = Query() + q_over = Query() + + inst = D() + + # Undo the effects of Python 3.6's __set_name__. + D.q1.__name__ = None + D.q2.__name__ = None + D.q3.__name__ = None + C.q_over.__name__ = None + D.q_over.__name__ = None + + # get them to trigger them to search their name + getattr(inst, 'q1') + getattr(inst, 'q2') + getattr(inst, 'q3') + getattr(inst, 'q_over') + + + self.assertEqual(C.q1.__name__, 'q1') + self.assertEqual(C.q2.__name__, 'q2') + self.assertEqual(D.q3.__name__, 'q3') + self.assertIsNone(C.q_over.__name__) + self.assertEqual(D.q_over.__name__, 'q_over') + + +class TestCompiledQuery(TestCase): + + def test_stmt_cache_on_bad_cursor(self): + + unique_execute_stmt = [] + + class MockStatement(object): + class dialect(object): + class compiler(object): + def __init__(self, _): + "Does nothing" + def compile(self): + return 'stmt', () + def can_prepare(self): + # We have to prepare if we want to try the cache + return True + def prepare(self): + o = object() + unique_execute_stmt.append(o) + return "prepare", o, lambda params: params + + executed = [] + + class Cursor(object): + __slots__ = ('__weakref__',) + + def execute(self, stmt): + executed.append(stmt) + + cursor = 
Cursor() + + query = CompiledQuery(MockStatement()) + query.execute(cursor) + + self.assertLength(unique_execute_stmt, 1) + self.assertLength(executed, 2) + self.assertEqual(executed, [ + "prepare", + unique_execute_stmt[0], + ]) + + query.execute(cursor) + self.assertLength(unique_execute_stmt, 1) + self.assertLength(executed, 3) + self.assertEqual(executed, [ + "prepare", + unique_execute_stmt[0], + unique_execute_stmt[0], + ]) diff --git a/src/relstorage/adapters/sql/tests/test_schema.py b/src/relstorage/adapters/sql/tests/test_schema.py new file mode 100644 index 00000000..113ab5c1 --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_schema.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. 
+# +############################################################################## + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from ..schema import Table + + +class TestTable(TestCase): + + def test_str(self): + + self.assertEqual(str(Table("table")), "table") diff --git a/src/relstorage/adapters/sql/tests/test_sql.py b/src/relstorage/adapters/sql/tests/test_sql.py new file mode 100644 index 00000000..d44dd984 --- /dev/null +++ b/src/relstorage/adapters/sql/tests/test_sql.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- +############################################################################## +# +# Copyright (c) 2019 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +""" +Tests for the SQL abstraction layer. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from relstorage.tests import TestCase + +from .. import Table +from .. import HistoryVariantTable +from .. import Column +from .. import it +from .. import DefaultDialect +from .. import OID +from .. import TID +from .. import State +from .. import Boolean +from .. import BinaryString +from .. 
import func + +from ..expressions import bindparam + +current_object = Table( + 'current_object', + Column('zoid', OID), + Column('tid', TID) +) + +object_state = Table( + 'object_state', + Column('zoid', OID), + Column('tid', TID), + Column('state', State), + Column('state_size'), +) + +hp_object_and_state = current_object.natural_join(object_state) + +objects = HistoryVariantTable( + current_object, + object_state, +) + +object_and_state = HistoryVariantTable( + hp_object_and_state, + object_state +) + +transaction = Table( + 'transaction', + Column('tid', TID), + Column('packed', Boolean), + Column('username', BinaryString), + Column('description', BinaryString), + Column('extension', BinaryString), +) + +class TestTableSelect(TestCase): + + def test_simple_eq_select(self): + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid)' + ) + + def test_simple_eq_limit(self): + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid).limit(1) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid) LIMIT 1' + ) + + def test_simple_eq_for_update(self): + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid).for_update() + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid) FOR UPDATE' + ) + + stmt = stmt.nowait() + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid) FOR UPDATE ' + 'NOWAIT' + ) + + def test_max_select(self): + table = object_state + + stmt = table.select(func.max(it.c.tid)) + + self.assertEqual( + str(stmt), + 'SELECT max(tid) FROM object_state' + ) + + def test_distinct(self): + table = object_state + stmt = table.select(table.c.zoid).where(table.c.tid == table.bindparam('tid')).distinct() + 
self.assertEqual( + str(stmt), + 'SELECT DISTINCT zoid FROM object_state WHERE (tid = %(tid)s)' + ) + + def test_simple_eq_select_and(self): + + table = object_state + + stmt = table.select().where(table.c.zoid == table.c.tid) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = tid)' + ) + + stmt = stmt.and_(table.c.zoid > 5) + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size ' + 'FROM object_state WHERE ((zoid = tid AND zoid > %(literal_0)s))' + ) + + def test_simple_eq_select_literal(self): + table = object_state + + # This is a useless query + stmt = table.select().where(table.c.zoid == 7) + + self.assertEqual( + str(stmt), + 'SELECT zoid, tid, state, state_size FROM object_state WHERE (zoid = %(literal_0)s)' + ) + + self.assertEqual( + stmt.compiled().params, + {'literal_0': 7}) + + def test_column_query_variant_table(self): + stmt = objects.select(objects.c.tid, objects.c.zoid).where( + objects.c.tid > bindparam('tid') + ) + + self.assertEqual( + str(stmt), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + + def test_natural_join(self): + stmt = object_and_state.select( + object_and_state.c.zoid, object_and_state.c.state + ).where( + object_and_state.c.zoid == object_and_state.bindparam('oid') + ) + + self.assertEqual( + str(stmt), + 'SELECT zoid, state ' + 'FROM current_object ' + 'JOIN object_state ' + 'USING (zoid, tid) WHERE (zoid = %(oid)s)' + ) + + class H(object): + keep_history = False + dialect = DefaultDialect() + + stmt = stmt.bind(H()) + + self.assertEqual( + str(stmt), + 'SELECT zoid, state ' + 'FROM object_state ' + 'WHERE (zoid = %(oid)s)' + ) + + def test_bind(self): + from operator import attrgetter + query = objects.select(objects.c.tid, objects.c.zoid).where( + objects.c.tid > bindparam('tid') + ) + # Unbound we assume history + self.assertEqual( + str(query), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + + class 
Context(object): + dialect = DefaultDialect() + keep_history = True + + context = Context() + dialect = context.dialect + query = query.bind(context) + + class Root(object): + select = query + + for item_name in ( + 'select', + 'select.table', + 'select._where', + 'select._where.expression', + ): + __traceback_info__ = item_name + item = attrgetter(item_name)(Root) + # The exact context is passed down the tree. + self.assertIs(item.context, context) + # The dialect is first bound, so it's *not* the same + # as the one we can reference (though it is equal)... + self.assertEqual(item.dialect, dialect) + # ...but it *is* the same throughout the tree + self.assertIs(query.dialect, item.dialect) + + # We take up its history setting + self.assertEqual( + str(query), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + + # Bound to history-free we use history free + context.keep_history = False + query = query.bind(context) + + self.assertEqual( + str(query), + 'SELECT tid, zoid FROM object_state WHERE (tid > %(tid)s)' + ) + + def test_bind_descriptor(self): + class Context(object): + keep_history = True + dialect = DefaultDialect() + select = objects.select(objects.c.tid, objects.c.zoid).where( + objects.c.tid > bindparam('tid') + ) + + # Unbound we assume history + self.assertEqual( + str(Context.select), + 'SELECT tid, zoid FROM current_object WHERE (tid > %(tid)s)' + ) + + context = Context() + context.keep_history = False + self.assertEqual( + str(context.select), + 'SELECT tid, zoid FROM object_state WHERE (tid > %(tid)s)' + ) + + def test_prepared_insert_values(self): + stmt = current_object.insert( + current_object.c.zoid + ) + + self.assertEqual( + str(stmt), + 'INSERT INTO current_object(zoid) VALUES (%s)' + ) + + stmt = stmt.prepared() + self.assertTrue( + str(stmt).startswith('EXECUTE rs_prep_stmt') + ) + + stmt = stmt.compiled() + self.assertRegex( + stmt._prepare_stmt, + r"PREPARE rs_prep_stmt_[0-9]*_[0-9]* \(BIGINT\) AS.*" + ) + + def 
test_prepared_insert_select_with_param(self): + stmt = current_object.insert().from_select( + (current_object.c.zoid, + current_object.c.tid), + object_state.select( + object_state.c.zoid, + object_state.orderedbindparam() + ) + ) + self.assertEqual( + str(stmt), + 'INSERT INTO current_object(zoid, tid) SELECT zoid, %s FROM object_state' + ) + + stmt = stmt.prepared() + self.assertTrue( + str(stmt).startswith('EXECUTE rs_prep_stmt') + ) + + stmt = stmt.compiled() + self.assertRegex( + stmt._prepare_stmt, + r"PREPARE rs_prep_stmt_[0-9]*_[0-9]* \(BIGINT\) AS.*" + ) + + def test_it(self): + stmt = object_state.select( + it.c.zoid, + it.c.state + ).where( + it.c.tid == it.bindparam('tid') + ).order_by( + it.c.zoid + ) + + self.assertEqual( + str(stmt), + 'SELECT zoid, state FROM object_state WHERE (tid = %(tid)s) ORDER BY zoid' + ) + + # Now something that won't resolve. + col_ref = it.c.dne + + # In the column list + with self.assertRaisesRegex(AttributeError, 'does not include dne'): + object_state.select(col_ref) + + stmt = object_state.select(it.c.zoid) + + # In the where clause + with self.assertRaisesRegex(AttributeError, 'does not include dne'): + stmt.where(col_ref == object_state.c.state) + + # In order by + with self.assertRaisesRegex(AttributeError, 'does not include dne'): + stmt.order_by(col_ref == object_state.c.state) + + def test_boolean_literal(self): + stmt = transaction.select( + transaction.c.tid + ).where( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + transaction.c.tid, 'DESC' + ) + + self.assertEqual( + str(stmt), + 'SELECT tid FROM transaction WHERE (packed = FALSE) ORDER BY tid DESC' + ) + + def test_literal_in_select(self): + stmt = current_object.select( + 1 + ).where( + current_object.c.zoid == current_object.bindparam('oid') + ) + + self.assertEqual( + str(stmt), + 'SELECT 1 FROM current_object WHERE (zoid = %(oid)s)' + ) + + def test_boolean_literal_it_joined_table(self): + stmt = transaction.natural_join( + 
object_state + ).select( + it.c.tid, it.c.username, it.c.description, it.c.extension, + object_state.c.state_size + ).where( + it.c.zoid == it.bindparam("oid") + ).and_( + it.c.packed == False # pylint:disable=singleton-comparison + ).order_by( + it.c.tid, "DESC" + ) + + self.assertEqual( + str(stmt), + 'SELECT tid, username, description, extension, state_size ' + 'FROM transaction ' + 'JOIN object_state ' + 'USING (tid) ' + 'WHERE ((zoid = %(oid)s AND packed = FALSE)) ' + 'ORDER BY tid DESC' + ) diff --git a/src/relstorage/adapters/sql/types.py b/src/relstorage/adapters/sql/types.py new file mode 100644 index 00000000..670ed153 --- /dev/null +++ b/src/relstorage/adapters/sql/types.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +SQL data types. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + + +class Type(object): + """ + A database type. + """ + +class Unknown(Type): + "Unspecified." + +class Integer64(Type): + """ + A 64-bit integer. + """ + +class OID(Integer64): + """ + Type of an OID. + """ + +class TID(Integer64): + """ + Type of a TID. + """ + +class BinaryString(Type): + """ + Arbitrary sized binary string. + """ + +class State(Type): + """ + Used for storing object state. + """ + +class Boolean(Type): + """ + A two-value column. 
+ """ diff --git a/src/relstorage/adapters/tests/test_txncontrol.py b/src/relstorage/adapters/tests/test_txncontrol.py index 97b67841..1e5d5d66 100644 --- a/src/relstorage/adapters/tests/test_txncontrol.py +++ b/src/relstorage/adapters/tests/test_txncontrol.py @@ -18,12 +18,9 @@ from relstorage.tests import TestCase from relstorage.tests import MockConnection +from relstorage.tests import MockConnectionManager from relstorage.tests import MockCursor - -class MockConnmanager(object): - - def rollback(self, conn, _cursor): - conn.rollback() +from relstorage.tests import MockPoller class TestTransactionControl(TestCase): @@ -37,29 +34,8 @@ def Binary(self, arg): return arg def _makeOne(self, keep_history=True, binary=None): - return self._getClass()(MockConnmanager(), keep_history, binary or self.Binary) - - def _get_hf_tid_query(self): - return self._getClass()._get_tid_queries[1] - - def _get_hp_tid_query(self): - return self._getClass()._get_tid_queries[0] - - def _check_get_tid_query(self, keep_history, expected_query): - inst = self._makeOne(keep_history) - cur = MockCursor() - cur.results = [(1,)] - - inst.get_tid(cur) - - self.assertEqual(cur.executed.pop(), - (expected_query, None)) - - def test_get_tid_hf(self): - self._check_get_tid_query(False, self._get_hf_tid_query()) - - def test_get_tid_hp(self): - self._check_get_tid_query(True, self._get_hp_tid_query()) + return self._getClass()(MockConnectionManager(), MockPoller(), + keep_history, binary or self.Binary) def test_get_tid_empty_db(self): inst = self._makeOne() @@ -71,12 +47,11 @@ def test_get_tid_empty_db(self): def test_add_transaction_hp(self): inst = self._makeOne() cur = MockCursor() - + __traceback_info__ = inst.__dict__ inst.add_transaction(cur, 1, u'user', u'desc', u'ext') - self.assertEqual( cur.executed.pop(), - (inst._add_transaction_query, + (str(inst._add_transaction_query), (1, False, b'user', b'desc', b'ext')) ) @@ -84,7 +59,7 @@ def test_add_transaction_hp(self): self.assertEqual( 
cur.executed.pop(), - (inst._add_transaction_query, + (str(inst._add_transaction_query), (1, True, b'user', b'desc', b'ext')) ) diff --git a/src/relstorage/adapters/txncontrol.py b/src/relstorage/adapters/txncontrol.py index 7952a21d..ec1067eb 100644 --- a/src/relstorage/adapters/txncontrol.py +++ b/src/relstorage/adapters/txncontrol.py @@ -21,7 +21,8 @@ from .._compat import ABC from ._util import noop_when_history_free -from ._util import query_property + +from .schema import Schema from .interfaces import ITransactionControl @@ -73,19 +74,14 @@ class GenericTransactionControl(AbstractTransactionControl): and history-preserving storages that share a common syntax. """ - _get_tid_queries = ( - "SELECT MAX(tid) FROM transaction", - "SELECT MAX(tid) FROM object_state" - ) - _get_tid_query = query_property('_get_tid') - - def __init__(self, connmanager, keep_history, Binary): + def __init__(self, connmanager, poller, keep_history, Binary): super(GenericTransactionControl, self).__init__(connmanager) + self.poller = poller self.keep_history = keep_history self.Binary = Binary def get_tid(self, cursor): - cursor.execute(self._get_tid_query) + self.poller.poll_query.execute(cursor) row = cursor.fetchall() if not row: # nothing has been stored yet @@ -94,15 +90,22 @@ def get_tid(self, cursor): tid = row[0][0] return tid if tid is not None else 0 - _add_transaction_query = """ - INSERT INTO transaction (tid, packed, username, description, extension) - VALUES (%s, %s, %s, %s, %s) - """ + _add_transaction_query = Schema.transaction.insert( + Schema.transaction.c.tid, + Schema.transaction.c.packed, + Schema.transaction.c.username, + Schema.transaction.c.description, + Schema.transaction.c.extension + ).prepared() @noop_when_history_free def add_transaction(self, cursor, tid, username, description, extension, packed=False): binary = self.Binary - cursor.execute(self._add_transaction_query, ( - tid, packed, binary(username), - binary(description), binary(extension))) + 
self._add_transaction_query.execute( + cursor, + ( + tid, packed, binary(username), + binary(description), binary(extension) + ) + ) diff --git a/src/relstorage/cache/interfaces.py b/src/relstorage/cache/interfaces.py index 8cf3a44c..68b0f184 100644 --- a/src/relstorage/cache/interfaces.py +++ b/src/relstorage/cache/interfaces.py @@ -18,17 +18,17 @@ from zope.interface import Attribute from zope.interface import Interface -import BTrees + from transaction.interfaces import TransientError from ZODB.POSException import StorageError +# Export +from relstorage._compat import MAX_TID # pylint:disable=unused-import + # pylint: disable=inherit-non-class,no-method-argument,no-self-argument # pylint:disable=unexpected-special-method-signature # pylint:disable=signature-differs -# An LLBTree uses much less memory than a dict, and is still plenty fast on CPython; -# it's just as big and slower on PyPy, though. -MAX_TID = BTrees.family64.maxint class IStateCache(Interface): """ diff --git a/src/relstorage/storage/__init__.py b/src/relstorage/storage/__init__.py index 0559fa21..fb6b1d16 100644 --- a/src/relstorage/storage/__init__.py +++ b/src/relstorage/storage/__init__.py @@ -470,7 +470,9 @@ def afterCompletion(self): # abort. # The next time we use the load connection, it will need to poll - # and will call our _on_load_activated. + # and will call our __on_first_use. + # Typically our next call from the ZODB Connection will be from its + # `newTransaction` method, a forced `sync` followed by `poll_invalidations`. # TODO: Why doesn't this use connmanager.restart_load()? 
# They both rollback; the difference is that restart_load checks for replicas, diff --git a/src/relstorage/tests/__init__.py b/src/relstorage/tests/__init__.py index ee954788..ec535197 100644 --- a/src/relstorage/tests/__init__.py +++ b/src/relstorage/tests/__init__.py @@ -11,6 +11,7 @@ from relstorage._compat import ABC from relstorage.options import Options +from relstorage.adapters.sql import DefaultDialect try: from unittest import mock @@ -29,12 +30,17 @@ class TestCase(unittest.TestCase): cleanups. """ # Avoid deprecation warnings; 2.7 doesn't have - # assertRaisesRegex + # assertRaisesRegex or assertRegex assertRaisesRegex = getattr( unittest.TestCase, 'assertRaisesRegex', None ) or getattr(unittest.TestCase, 'assertRaisesRegexp') + assertRegex = getattr( + unittest.TestCase, + 'assertRegex', + None + ) or getattr(unittest.TestCase, 'assertRegexpMatches') def setUp(self): super(TestCase, self).setUp() @@ -100,10 +106,13 @@ def tearDown(self): super(TestCase, self).tearDown() def assertIsEmpty(self, container): - self.assertEqual(len(container), 0) + self.assertLength(container, 0) assertEmpty = assertIsEmpty + def assertLength(self, container, length): + self.assertEqual(len(container), length, container) + class StorageCreatingMixin(ABC): keep_history = None # Override @@ -228,6 +237,10 @@ def fetchall(self): def close(self): self.closed = True + def __iter__(self): + for row in self.results: + yield row + class MockOptions(Options): cache_module_name = '' # disable cache_servers = '' @@ -251,10 +264,12 @@ class MockConnectionManager(object): disconnected_exceptions = () - def rollback(self, conn, cursor): - "Does nothing" + def rollback(self, conn, cursor): # pylint:disable=unused-argument + if hasattr(conn, 'rollback'): + conn.rollback() def rollback_and_close(self, conn, cursor): + self.rollback(conn, cursor) if conn: conn.close() if cursor: @@ -277,9 +292,30 @@ class MockPackUndo(object): class MockOIDAllocator(object): pass +class MockQuery(object): + 
+ def __init__(self, raw): + self.raw = raw + + def execute(self, cursor, params=None): + cursor.execute(self.raw, params) + +class MockPoller(object): + + poll_query = MockQuery('SELECT MAX(tid) FROM object_state') + + def __init__(self, driver=None): + self.driver = driver or MockDriver() + +class MockDriver(object): + + dialect = DefaultDialect() + class MockAdapter(object): def __init__(self): + self.driver = MockDriver() self.connmanager = MockConnectionManager() self.packundo = MockPackUndo() self.oidallocator = MockOIDAllocator() + self.poller = MockPoller(self.driver)