Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* BROKEN CHANGE: deny any action in transaction after commit/rollback

## 3.0.1b6 ##
* BROKEN CHANGES: remove writer.write(mess1, mess2) variant, use list instead: writer.write([mess1, mess2])
* BROKEN CHANGES: change names of public method in topic client
Expand Down
62 changes: 62 additions & 0 deletions tests/aio/test_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def test_tx_snapshot_ro(driver, database):

await ro_tx.commit()

ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly())
with pytest.raises(ydb.issues.GenericError) as exc_info:
await ro_tx.execute("UPDATE `test` SET value = value + 1")
assert "read only transaction" in exc_info.value.message
Expand All @@ -94,3 +95,64 @@ async def test_tx_snapshot_ro(driver, database):
commit_tx=True,
)
assert data[0].rows == [{"value": 2}]


@pytest.mark.asyncio
async def test_split_transactions_deny_split(driver, table_name):
async with ydb.aio.SessionPool(driver, 1) as pool:

async def check_transaction(s: ydb.aio.table.Session):
async with s.transaction(deny_split_transactions=True) as tx:
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
await tx.commit()

with pytest.raises(RuntimeError):
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)

await tx.commit()

async with s.transaction() as tx:
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
assert rs[0].rows[0].cnt == 1

await pool.retry_operation(check_transaction)


@pytest.mark.asyncio
async def test_split_transactions_allow_split(driver, table_name):
async with ydb.aio.SessionPool(driver, 1) as pool:

async def check_transaction(s: ydb.aio.table.Session):
async with s.transaction(deny_split_transactions=False) as tx:
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
await tx.commit()

await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
await tx.commit()

async with s.transaction() as tx:
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
assert rs[0].rows[0].cnt == 2

await pool.retry_operation(check_transaction)


@pytest.mark.asyncio
async def test_split_transactions_default(driver, table_name):
async with ydb.aio.SessionPool(driver, 1) as pool:

async def check_transaction(s: ydb.aio.table.Session):
async with s.transaction() as tx:
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
await tx.commit()

with pytest.raises(RuntimeError):
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)

await tx.commit()

async with s.transaction() as tx:
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
assert rs[0].rows[0].cnt == 1

await pool.retry_operation(check_transaction)
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,33 @@ async def driver_sync(endpoint, database, event_loop):
driver.stop(timeout=10)


@pytest.fixture()
def table_name(driver_sync, database):
table_name = "table"

with ydb.SessionPool(driver_sync) as pool:

def create_table(s):
try:
s.drop_table(database + "/" + table_name)
except ydb.SchemeError:
pass

s.execute_scheme(
"""
CREATE TABLE %s (
id Int64 NOT NULL,
i64Val Int64,
PRIMARY KEY(id)
)
"""
% table_name
)

pool.retry_operation_sync(create_table)
return table_name


@pytest.fixture()
def topic_consumer():
return "fixture-consumer"
Expand Down
59 changes: 59 additions & 0 deletions tests/table/test_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def test_tx_snapshot_ro(driver_sync, database):

ro_tx.commit()

ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly())
with pytest.raises(ydb.issues.GenericError) as exc_info:
ro_tx.execute("UPDATE `test` SET value = value + 1")
assert "read only transaction" in exc_info.value.message
Expand All @@ -89,3 +90,61 @@ def test_tx_snapshot_ro(driver_sync, database):
commit_tx=True,
)
assert data[0].rows == [{"value": 2}]


def test_split_transactions_deny_split(driver_sync, table_name):
with ydb.SessionPool(driver_sync, 1) as pool:

def check_transaction(s: ydb.table.Session):
with s.transaction(deny_split_transactions=True) as tx:
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
tx.commit()

with pytest.raises(RuntimeError):
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)

tx.commit()

with s.transaction() as tx:
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
assert rs[0].rows[0].cnt == 1

pool.retry_operation_sync(check_transaction)


def test_split_transactions_allow_split(driver_sync, table_name):
with ydb.SessionPool(driver_sync, 1) as pool:

def check_transaction(s: ydb.table.Session):
with s.transaction(deny_split_transactions=False) as tx:
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
tx.commit()

tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
tx.commit()

with s.transaction() as tx:
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
assert rs[0].rows[0].cnt == 2

pool.retry_operation_sync(check_transaction)


def test_split_transactions_default(driver_sync, table_name):
with ydb.SessionPool(driver_sync, 1) as pool:

def check_transaction(s: ydb.table.Session):
with s.transaction() as tx:
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
tx.commit()

with pytest.raises(RuntimeError):
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)

tx.commit()

with s.transaction() as tx:
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
assert rs[0].rows[0].cnt == 1

pool.retry_operation_sync(check_transaction)
13 changes: 11 additions & 2 deletions ydb/aio/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,14 @@ async def alter_table(
set_read_replicas_settings,
)

def transaction(self, tx_mode=None):
return TxContext(self._driver, self._state, self, tx_mode)
def transaction(self, tx_mode=None, *, deny_split_transactions=True):
return TxContext(
self._driver,
self._state,
self,
tx_mode,
deny_split_transactions=deny_split_transactions,
)

async def describe_table(self, path, settings=None): # pylint: disable=W0236
return await super().describe_table(path, settings)
Expand Down Expand Up @@ -184,6 +190,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
async def execute(
self, query, parameters=None, commit_tx=False, settings=None
): # pylint: disable=W0236

self._check_split()

return await super().execute(query, parameters, commit_tx, settings)

async def commit(self, settings=None): # pylint: disable=W0236
Expand Down
77 changes: 72 additions & 5 deletions ydb/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ def execute_scheme(self, yql_text, settings=None):
pass

@abstractmethod
def transaction(self, tx_mode=None):
def transaction(self, tx_mode=None, deny_split_transactions=True):
pass

@abstractmethod
Expand Down Expand Up @@ -1677,8 +1677,14 @@ def execute_scheme(self, yql_text, settings=None):
self._state.endpoint,
)

def transaction(self, tx_mode=None):
return TxContext(self._driver, self._state, self, tx_mode)
def transaction(self, tx_mode=None, deny_split_transactions=True):
return TxContext(
self._driver,
self._state,
self,
tx_mode,
deny_split_transactions=deny_split_transactions,
)

def has_prepared(self, query):
return query in self._state
Expand Down Expand Up @@ -2189,9 +2195,27 @@ def begin(self, settings=None):


class BaseTxContext(ITxContext):
__slots__ = ("_tx_state", "_session_state", "_driver", "session")
__slots__ = (
"_tx_state",
"_session_state",
"_driver",
"session",
"_finished",
"_deny_split_transactions",
)

def __init__(self, driver, session_state, session, tx_mode=None):
_COMMIT = "commit"
_ROLLBACK = "rollback"

def __init__(
self,
driver,
session_state,
session,
tx_mode=None,
*,
deny_split_transactions=True
):
"""
An object that provides a simple transaction context manager that allows statements execution
in a transaction. You don't have to open transaction explicitly, because context manager encapsulates
Expand All @@ -2214,6 +2238,8 @@ def __init__(self, driver, session_state, session, tx_mode=None):
self._tx_state = _tx_ctx_impl.TxState(tx_mode)
self._session_state = session_state
self.session = session
self._finished = ""
self._deny_split_transactions = deny_split_transactions

def __enter__(self):
"""
Expand Down Expand Up @@ -2271,6 +2297,9 @@ def execute(self, query, parameters=None, commit_tx=False, settings=None):

:return: A result sets or exception in case of execution errors
"""

self._check_split()

return self._driver(
_tx_ctx_impl.execute_request_factory(
self._session_state,
Expand All @@ -2297,8 +2326,12 @@ def commit(self, settings=None):

:return: A committed transaction or exception if commit is failed
"""

self._set_finish(self._COMMIT)

if self._tx_state.tx_id is None and not self._tx_state.dead:
return self

return self._driver(
_tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state),
_apis.TableService.Stub,
Expand All @@ -2318,8 +2351,12 @@ def rollback(self, settings=None):

:return: A rolled back transaction or exception if rollback is failed
"""

self._set_finish(self._ROLLBACK)

if self._tx_state.tx_id is None and not self._tx_state.dead:
return self

return self._driver(
_tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state),
_apis.TableService.Stub,
Expand All @@ -2340,6 +2377,9 @@ def begin(self, settings=None):
"""
if self._tx_state.tx_id is not None:
return self

self._check_split()

return self._driver(
_tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state),
_apis.TableService.Stub,
Expand All @@ -2350,6 +2390,21 @@ def begin(self, settings=None):
self._session_state.endpoint,
)

def _set_finish(self, val):
self._check_split(val)
self._finished = val

def _check_split(self, allow=""):
"""
Deny all operaions with transaction after commit/rollback.
Exception: double commit and double rollbacks, because it is safe
"""
if not self._deny_split_transactions:
return

if self._finished != "" and self._finished != allow:
raise RuntimeError("Any operation with finished transaction is denied")


class TxContext(BaseTxContext):
@_utilities.wrap_async_call_exceptions
Expand All @@ -2365,6 +2420,9 @@ def async_execute(self, query, parameters=None, commit_tx=False, settings=None):

:return: A future of query execution
"""

self._check_split()

return self._driver.future(
_tx_ctx_impl.execute_request_factory(
self._session_state,
Expand Down Expand Up @@ -2396,8 +2454,11 @@ def async_commit(self, settings=None):

:return: A future of commit call
"""
self._set_finish(self._COMMIT)

if self._tx_state.tx_id is None and not self._tx_state.dead:
return _utilities.wrap_result_in_future(self)

return self._driver.future(
_tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state),
_apis.TableService.Stub,
Expand All @@ -2418,8 +2479,11 @@ def async_rollback(self, settings=None):

:return: A future of rollback call
"""
self._set_finish(self._ROLLBACK)

if self._tx_state.tx_id is None and not self._tx_state.dead:
return _utilities.wrap_result_in_future(self)

return self._driver.future(
_tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state),
_apis.TableService.Stub,
Expand All @@ -2441,6 +2505,9 @@ def async_begin(self, settings=None):
"""
if self._tx_state.tx_id is not None:
return _utilities.wrap_result_in_future(self)

self._check_split()

return self._driver.future(
_tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state),
_apis.TableService.Stub,
Expand Down