Skip to content

Commit

Permalink
Enable savepints for all engines. Add disable_savepoints configuratio…
Browse files Browse the repository at this point in the history
…n parameter.
  • Loading branch information
lrowe committed Mar 30, 2020
1 parent 2bb7718 commit 25a0bcb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 52 deletions.
32 changes: 16 additions & 16 deletions src/zope/sqlalchemy/datamanager.py
Expand Up @@ -56,8 +56,6 @@
STATUS_READONLY = "readonly" # session joined to transaction, no writes allowed.
STATUS_INVALIDATED = STATUS_CHANGED # BBB

NO_SAVEPOINT_SUPPORT = {"sqlite"}

_SESSION_STATE = WeakKeyDictionary() # a mapping of session -> status
# This is thread safe because you are using scoped sessions

Expand All @@ -74,7 +72,7 @@ class SessionDataManager(object):
One phase variant.
"""

def __init__(self, session, status, transaction_manager, keep_session=False):
def __init__(self, session, status, transaction_manager, keep_session=False, disable_savepoints=False):
self.transaction_manager = transaction_manager

# Support both SQLAlchemy 1.0 and 1.1
Expand All @@ -90,6 +88,7 @@ def __init__(self, session, status, transaction_manager, keep_session=False):
_SESSION_STATE[session] = status
self.state = "init"
self.keep_session = keep_session
self.disable_savepoints = disable_savepoints

def _finish(self, final_state):
assert self.tx is not None
Expand Down Expand Up @@ -147,11 +146,7 @@ def savepoint(self):
# support savepoints but Postgres is whitelisted independent of its
# version. Possibly additional version information should be taken
# into account (ajung)
if set(
engine.url.drivername
for engine in self.session.transaction._connections.keys()
if isinstance(engine, Engine)
).intersection(NO_SAVEPOINT_SUPPORT):
if self.disable_savepoints:
raise AttributeError("savepoint")
return self._savepoint

Expand Down Expand Up @@ -211,6 +206,7 @@ def join_transaction(
initial_state=STATUS_ACTIVE,
transaction_manager=zope_transaction.manager,
keep_session=False,
disable_savepoints=False,
):
"""Join a session to a transaction using the appropriate datamanager.
Expand All @@ -231,19 +227,19 @@ def join_transaction(
else:
DataManager = SessionDataManager
DataManager(
session, initial_state, transaction_manager, keep_session=keep_session
session, initial_state, transaction_manager, keep_session=keep_session, disable_savepoints=disable_savepoints,
)


def mark_changed(
session, transaction_manager=zope_transaction.manager, keep_session=False
session, transaction_manager=zope_transaction.manager, keep_session=False, disable_savepoints=False,
):
"""Mark a session as needing to be committed.
"""
assert (
_SESSION_STATE.get(session, None) is not STATUS_READONLY
), "Session already registered as read only"
join_transaction(session, STATUS_CHANGED, transaction_manager, keep_session)
join_transaction(session, STATUS_CHANGED, transaction_manager, keep_session, disable_savepoints)
_SESSION_STATE[session] = STATUS_CHANGED


Expand All @@ -257,31 +253,33 @@ def __init__(
initial_state=STATUS_ACTIVE,
transaction_manager=zope_transaction.manager,
keep_session=False,
disable_savepoints=False,
):
if initial_state == "invalidated":
initial_state = STATUS_CHANGED # BBB
self.initial_state = initial_state
self.transaction_manager = transaction_manager
self.keep_session = keep_session
self.disable_savepoints = disable_savepoints

def after_begin(self, session, transaction, connection):
join_transaction(
session, self.initial_state, self.transaction_manager, self.keep_session
session, self.initial_state, self.transaction_manager, self.keep_session, self.disable_savepoints
)

def after_attach(self, session, instance):
join_transaction(
session, self.initial_state, self.transaction_manager, self.keep_session
session, self.initial_state, self.transaction_manager, self.keep_session, self.disable_savepoints
)

def after_flush(self, session, flush_context):
mark_changed(session, self.transaction_manager, self.keep_session)
mark_changed(session, self.transaction_manager, self.keep_session, self.disable_savepoints)

def after_bulk_update(self, session, query, query_context, result):
mark_changed(session, self.transaction_manager, self.keep_session)
mark_changed(session, self.transaction_manager, self.keep_session, self.disable_savepoints)

def after_bulk_delete(self, session, query, query_context, result):
mark_changed(session, self.transaction_manager, self.keep_session)
mark_changed(session, self.transaction_manager, self.keep_session, self.disable_savepoints)

def before_commit(self, session):
assert (
Expand All @@ -295,6 +293,7 @@ def register(
initial_state=STATUS_ACTIVE,
transaction_manager=zope_transaction.manager,
keep_session=False,
disable_savepoints=False,
):
"""Register ZopeTransaction listener events on the
given Session or Session factory/class.
Expand All @@ -318,6 +317,7 @@ def register(
initial_state=initial_state,
transaction_manager=transaction_manager,
keep_session=keep_session,
disable_savepoints=disable_savepoints,
)

event.listen(session, "after_begin", ext.after_begin)
Expand Down
55 changes: 19 additions & 36 deletions src/zope/sqlalchemy/tests.py
Expand Up @@ -26,6 +26,7 @@

import os
import re
import sys
import threading
import time
import transaction
Expand All @@ -42,6 +43,8 @@

TEST_TWOPHASE = bool(os.environ.get("TEST_TWOPHASE"))
TEST_DSN = os.environ.get("TEST_DSN", "sqlite:///:memory:")
SQLITE_NO_SAVEPOINT = TEST_DSN.startswith("sqlite:") and sys.version_info < (3, 6)
SQLITE_NO_SAVEPOINT_MSG = "SQLite savepoints unsupported by this Python version."


class SimpleModel(object):
Expand All @@ -64,28 +67,6 @@ class Skill(SimpleModel):
engine = sa.create_engine(TEST_DSN)
engine2 = sa.create_engine(TEST_DSN)

# See https://code.google.com/p/pysqlite-static-env/
HAS_PATCHED_PYSQLITE = False
if engine.url.drivername == "sqlite":
try:
from pysqlite2.dbapi2 import Connection
except ImportError:
pass
else:
if hasattr(Connection, "operation_needs_transaction_callback"):
HAS_PATCHED_PYSQLITE = True

if HAS_PATCHED_PYSQLITE:
from sqlalchemy import event
from zope.sqlalchemy.datamanager import NO_SAVEPOINT_SUPPORT

NO_SAVEPOINT_SUPPORT.remove("sqlite")

@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
dbapi_connection.operation_needs_transaction_callback = lambda x: True


Session = orm.scoped_session(orm.sessionmaker(bind=engine, twophase=TEST_TWOPHASE))
tx.register(Session)

Expand All @@ -98,6 +79,8 @@ def connect(dbapi_connection, connection_record):
KeepSession = orm.scoped_session(orm.sessionmaker(bind=engine, twophase=TEST_TWOPHASE))
tx.register(KeepSession, keep_session=True)

DisableSavepointsSession = orm.scoped_session(orm.sessionmaker(bind=engine, twophase=TEST_TWOPHASE))
tx.register(DisableSavepointsSession, disable_savepoints=True)

metadata = sa.MetaData() # best to use unbound metadata

Expand Down Expand Up @@ -336,19 +319,15 @@ def testTransactionJoiningUsingRegister(self):
"Not joined transaction",
)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testSavepoint(self):
use_savepoint = not engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT
t = transaction.get()
session = Session()
query = session.query(User)
self.assertFalse(query.all(), "Users table should be empty")

s0 = t.savepoint(optimistic=True) # this should always work

if not use_savepoint:
self.assertRaises(TypeError, t.savepoint)
return # sqlite databases do not support savepoints

s1 = t.savepoint()
session.add(User(id=1, firstname="udo", lastname="juergens"))
session.flush()
Expand All @@ -365,11 +344,18 @@ def testSavepoint(self):
s1.rollback()
self.assertFalse(query.all(), "Users table should be empty")

def testRollbackAttributes(self):
use_savepoint = not engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT
if not use_savepoint:
return # sqlite databases do not support savepoints
def testDisableSavepoints(self):
t = transaction.get()
session = DisableSavepointsSession()
query = session.query(User)
self.assertFalse(query.all(), "Users table should be empty")

s0 = t.savepoint(optimistic=True) # this should always work

self.assertRaises(TypeError, t.savepoint)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testRollbackAttributes(self):
t = transaction.get()
session = Session()
query = session.query(User)
Expand All @@ -393,7 +379,6 @@ def testRollbackAttributes(self):
def testCommit(self):
session = Session()

use_savepoint = not engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT
query = session.query(User)
rows = query.all()
self.assertEqual(len(rows), 0)
Expand Down Expand Up @@ -436,9 +421,8 @@ def testCommit(self):
results = engine.connect().execute(test_users.select())
self.assertEqual(len(results.fetchall()), 2)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testCommitWithSavepoint(self):
if engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT:
return
session = Session()
session.add(User(id=1, firstname="udo", lastname="juergens"))
session.add(User(id=2, firstname="heino", lastname="n/a"))
Expand All @@ -460,10 +444,9 @@ def testCommitWithSavepoint(self):
results = engine.connect().execute(test_users.select())
self.assertEqual(len(results.fetchall()), 1)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testNestedSessionCommitAllowed(self):
# Existing code might use nested transactions
if engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT:
return
session = Session()
session.add(User(id=1, firstname="udo", lastname="juergens"))
session.begin_nested()
Expand Down

0 comments on commit 25a0bcb

Please sign in to comment.