Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable savepints for all engines. Add disable_savepoints param #45

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/zope/sqlalchemy/datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@
STATUS_READONLY = "readonly" # session joined to transaction, no writes allowed.
STATUS_INVALIDATED = STATUS_CHANGED # BBB

NO_SAVEPOINT_SUPPORT = {"sqlite"}

_SESSION_STATE = WeakKeyDictionary() # a mapping of session -> status
# This is thread safe because you are using scoped sessions

Expand All @@ -74,7 +72,7 @@ class SessionDataManager(object):
One phase variant.
"""

def __init__(self, session, status, transaction_manager, keep_session=False):
def __init__(self, session, status, transaction_manager, keep_session=False, disable_savepoints=False):
self.transaction_manager = transaction_manager

# Support both SQLAlchemy 1.0 and 1.1
Expand All @@ -90,6 +88,7 @@ def __init__(self, session, status, transaction_manager, keep_session=False):
_SESSION_STATE[session] = status
self.state = "init"
self.keep_session = keep_session
self.disable_savepoints = disable_savepoints

def _finish(self, final_state):
assert self.tx is not None
Expand Down Expand Up @@ -147,11 +146,7 @@ def savepoint(self):
# support savepoints but Postgres is whitelisted independent of its
# version. Possibly additional version information should be taken
# into account (ajung)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is outdated now you're no longer checking engine names.

if set(
engine.url.drivername
for engine in self.session.transaction._connections.keys()
if isinstance(engine, Engine)
).intersection(NO_SAVEPOINT_SUPPORT):
if self.disable_savepoints:
raise AttributeError("savepoint")
return self._savepoint

Expand Down Expand Up @@ -211,6 +206,7 @@ def join_transaction(
initial_state=STATUS_ACTIVE,
transaction_manager=zope_transaction.manager,
keep_session=False,
disable_savepoints=False,
):
"""Join a session to a transaction using the appropriate datamanager.

Expand All @@ -231,19 +227,19 @@ def join_transaction(
else:
DataManager = SessionDataManager
DataManager(
session, initial_state, transaction_manager, keep_session=keep_session
session, initial_state, transaction_manager, keep_session=keep_session, disable_savepoints=disable_savepoints,
)


def mark_changed(
session, transaction_manager=zope_transaction.manager, keep_session=False
session, transaction_manager=zope_transaction.manager, keep_session=False, disable_savepoints=False,
):
"""Mark a session as needing to be committed.
"""
assert (
_SESSION_STATE.get(session, None) is not STATUS_READONLY
), "Session already registered as read only"
join_transaction(session, STATUS_CHANGED, transaction_manager, keep_session)
join_transaction(session, STATUS_CHANGED, transaction_manager, keep_session, disable_savepoints)
_SESSION_STATE[session] = STATUS_CHANGED


Expand All @@ -257,31 +253,33 @@ def __init__(
initial_state=STATUS_ACTIVE,
transaction_manager=zope_transaction.manager,
keep_session=False,
disable_savepoints=False,
):
if initial_state == "invalidated":
initial_state = STATUS_CHANGED # BBB
self.initial_state = initial_state
self.transaction_manager = transaction_manager
self.keep_session = keep_session
self.disable_savepoints = disable_savepoints

def after_begin(self, session, transaction, connection):
join_transaction(
session, self.initial_state, self.transaction_manager, self.keep_session
session, self.initial_state, self.transaction_manager, self.keep_session, self.disable_savepoints
)

def after_attach(self, session, instance):
join_transaction(
session, self.initial_state, self.transaction_manager, self.keep_session
session, self.initial_state, self.transaction_manager, self.keep_session, self.disable_savepoints
)

def after_flush(self, session, flush_context):
mark_changed(session, self.transaction_manager, self.keep_session)
mark_changed(session, self.transaction_manager, self.keep_session, self.disable_savepoints)

def after_bulk_update(self, session, query, query_context, result):
mark_changed(session, self.transaction_manager, self.keep_session)
mark_changed(session, self.transaction_manager, self.keep_session, self.disable_savepoints)

def after_bulk_delete(self, session, query, query_context, result):
mark_changed(session, self.transaction_manager, self.keep_session)
mark_changed(session, self.transaction_manager, self.keep_session, self.disable_savepoints)

def before_commit(self, session):
assert (
Expand All @@ -295,6 +293,7 @@ def register(
initial_state=STATUS_ACTIVE,
transaction_manager=zope_transaction.manager,
keep_session=False,
disable_savepoints=False,
):
"""Register ZopeTransaction listener events on the
given Session or Session factory/class.
Expand All @@ -318,6 +317,7 @@ def register(
initial_state=initial_state,
transaction_manager=transaction_manager,
keep_session=keep_session,
disable_savepoints=disable_savepoints,
)

event.listen(session, "after_begin", ext.after_begin)
Expand Down
55 changes: 19 additions & 36 deletions src/zope/sqlalchemy/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import os
import re
import sys
import threading
import time
import transaction
Expand All @@ -42,6 +43,8 @@

TEST_TWOPHASE = bool(os.environ.get("TEST_TWOPHASE"))
TEST_DSN = os.environ.get("TEST_DSN", "sqlite:///:memory:")
SQLITE_NO_SAVEPOINT = TEST_DSN.startswith("sqlite:") and sys.version_info < (3, 6)
SQLITE_NO_SAVEPOINT_MSG = "SQLite savepoints unsupported by this Python version."


class SimpleModel(object):
Expand All @@ -64,28 +67,6 @@ class Skill(SimpleModel):
engine = sa.create_engine(TEST_DSN)
engine2 = sa.create_engine(TEST_DSN)

# See https://code.google.com/p/pysqlite-static-env/
HAS_PATCHED_PYSQLITE = False
if engine.url.drivername == "sqlite":
try:
from pysqlite2.dbapi2 import Connection
except ImportError:
pass
else:
if hasattr(Connection, "operation_needs_transaction_callback"):
HAS_PATCHED_PYSQLITE = True

if HAS_PATCHED_PYSQLITE:
from sqlalchemy import event
from zope.sqlalchemy.datamanager import NO_SAVEPOINT_SUPPORT

NO_SAVEPOINT_SUPPORT.remove("sqlite")

@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
dbapi_connection.operation_needs_transaction_callback = lambda x: True


Session = orm.scoped_session(orm.sessionmaker(bind=engine, twophase=TEST_TWOPHASE))
tx.register(Session)

Expand All @@ -98,6 +79,8 @@ def connect(dbapi_connection, connection_record):
KeepSession = orm.scoped_session(orm.sessionmaker(bind=engine, twophase=TEST_TWOPHASE))
tx.register(KeepSession, keep_session=True)

DisableSavepointsSession = orm.scoped_session(orm.sessionmaker(bind=engine, twophase=TEST_TWOPHASE))
tx.register(DisableSavepointsSession, disable_savepoints=True)

metadata = sa.MetaData() # best to use unbound metadata

Expand Down Expand Up @@ -336,19 +319,15 @@ def testTransactionJoiningUsingRegister(self):
"Not joined transaction",
)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testSavepoint(self):
use_savepoint = not engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT
t = transaction.get()
session = Session()
query = session.query(User)
self.assertFalse(query.all(), "Users table should be empty")

s0 = t.savepoint(optimistic=True) # this should always work

if not use_savepoint:
self.assertRaises(TypeError, t.savepoint)
return # sqlite databases do not support savepoints

s1 = t.savepoint()
session.add(User(id=1, firstname="udo", lastname="juergens"))
session.flush()
Expand All @@ -365,11 +344,18 @@ def testSavepoint(self):
s1.rollback()
self.assertFalse(query.all(), "Users table should be empty")

def testRollbackAttributes(self):
use_savepoint = not engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT
if not use_savepoint:
return # sqlite databases do not support savepoints
def testDisableSavepoints(self):
t = transaction.get()
session = DisableSavepointsSession()
query = session.query(User)
self.assertFalse(query.all(), "Users table should be empty")

s0 = t.savepoint(optimistic=True) # this should always work

self.assertRaises(TypeError, t.savepoint)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testRollbackAttributes(self):
t = transaction.get()
session = Session()
query = session.query(User)
Expand All @@ -393,7 +379,6 @@ def testRollbackAttributes(self):
def testCommit(self):
session = Session()

use_savepoint = not engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT
query = session.query(User)
rows = query.all()
self.assertEqual(len(rows), 0)
Expand Down Expand Up @@ -436,9 +421,8 @@ def testCommit(self):
results = engine.connect().execute(test_users.select())
self.assertEqual(len(results.fetchall()), 2)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testCommitWithSavepoint(self):
if engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT:
return
session = Session()
session.add(User(id=1, firstname="udo", lastname="juergens"))
session.add(User(id=2, firstname="heino", lastname="n/a"))
Expand All @@ -460,10 +444,9 @@ def testCommitWithSavepoint(self):
results = engine.connect().execute(test_users.select())
self.assertEqual(len(results.fetchall()), 1)

@unittest.skipIf(SQLITE_NO_SAVEPOINT, SQLITE_NO_SAVEPOINT_MSG)
def testNestedSessionCommitAllowed(self):
# Existing code might use nested transactions
if engine.url.drivername in tx.NO_SAVEPOINT_SUPPORT:
return
session = Session()
session.add(User(id=1, firstname="udo", lastname="juergens"))
session.begin_nested()
Expand Down