diff --git a/CHANGELOG.md b/CHANGELOG.md index decbb671..4cfa110d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* BROKEN CHANGE: deny any action in transaction after commit/rollback + ## 3.0.1b6 ## * BROKEN CHANGES: remove writer.write(mess1, mess2) variant, use list instead: writer.write([mess1, mess2]) * BROKEN CHANGES: change names of public method in topic client diff --git a/tests/aio/test_tx.py b/tests/aio/test_tx.py index 2161ddeb..be5c6806 100644 --- a/tests/aio/test_tx.py +++ b/tests/aio/test_tx.py @@ -85,6 +85,7 @@ async def test_tx_snapshot_ro(driver, database): await ro_tx.commit() + ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly()) with pytest.raises(ydb.issues.GenericError) as exc_info: await ro_tx.execute("UPDATE `test` SET value = value + 1") assert "read only transaction" in exc_info.value.message @@ -94,3 +95,64 @@ async def test_tx_snapshot_ro(driver, database): commit_tx=True, ) assert data[0].rows == [{"value": 2}] + + +@pytest.mark.asyncio +async def test_split_transactions_deny_split(driver, table_name): + async with ydb.aio.SessionPool(driver, 1) as pool: + + async def check_transaction(s: ydb.aio.table.Session): + async with s.transaction(deny_split_transactions=True) as tx: + await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + await tx.commit() + + with pytest.raises(RuntimeError): + await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + await tx.commit() + + async with s.transaction() as tx: + rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + await pool.retry_operation(check_transaction) + + +@pytest.mark.asyncio +async def test_split_transactions_allow_split(driver, table_name): + async with ydb.aio.SessionPool(driver, 1) as pool: + + async def check_transaction(s: ydb.aio.table.Session): + async with s.transaction(deny_split_transactions=False) as tx: + await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + await tx.commit() + + await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + await tx.commit() + + async with s.transaction() as tx: + rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 2 + + await pool.retry_operation(check_transaction) + + +@pytest.mark.asyncio +async def test_split_transactions_default(driver, table_name): + async with ydb.aio.SessionPool(driver, 1) as pool: + + async def check_transaction(s: ydb.aio.table.Session): + async with s.transaction() as tx: + await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + await tx.commit() + + with pytest.raises(RuntimeError): + await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + await tx.commit() + + async with s.transaction() as tx: + rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + await pool.retry_operation(check_transaction) diff --git a/tests/conftest.py b/tests/conftest.py index 99123660..e7809847 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -105,6 +105,33 @@ async def driver_sync(endpoint, database, event_loop): driver.stop(timeout=10) +@pytest.fixture() +def table_name(driver_sync, database): + table_name = "table" + + with ydb.SessionPool(driver_sync) as pool: + + def create_table(s): + try: + s.drop_table(database + "/" + table_name) + except ydb.SchemeError: + pass + + s.execute_scheme( + """ +CREATE TABLE %s ( +id Int64 NOT NULL, +i64Val Int64, +PRIMARY KEY(id) +) +""" + % table_name + ) + + pool.retry_operation_sync(create_table) + return table_name + + @pytest.fixture() def topic_consumer(): return "fixture-consumer" diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py index 32d9b763..095fb72f 100644 --- a/tests/table/test_tx.py +++ b/tests/table/test_tx.py @@ -80,6 +80,7 @@ def test_tx_snapshot_ro(driver_sync, database): ro_tx.commit() + ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly()) with pytest.raises(ydb.issues.GenericError) as exc_info: ro_tx.execute("UPDATE `test` SET value = value + 1") assert "read only transaction" in exc_info.value.message @@ -89,3 +90,61 @@ def test_tx_snapshot_ro(driver_sync, database): commit_tx=True, ) assert data[0].rows == [{"value": 2}] + + +def test_split_transactions_deny_split(driver_sync, table_name): + with ydb.SessionPool(driver_sync, 1) as pool: + + def check_transaction(s: ydb.table.Session): + with s.transaction(deny_split_transactions=True) as tx: + tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + tx.commit() + + with pytest.raises(RuntimeError): + tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + tx.commit() + + with s.transaction() as tx: + rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + pool.retry_operation_sync(check_transaction) + + +def test_split_transactions_allow_split(driver_sync, table_name): + with ydb.SessionPool(driver_sync, 1) as pool: + + def check_transaction(s: ydb.table.Session): + with s.transaction(deny_split_transactions=False) as tx: + tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + tx.commit() + + tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + tx.commit() + + with s.transaction() as tx: + rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 2 + + pool.retry_operation_sync(check_transaction) + + +def test_split_transactions_default(driver_sync, table_name): + with ydb.SessionPool(driver_sync, 1) as pool: + + def check_transaction(s: ydb.table.Session): + with s.transaction() as tx: + tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + tx.commit() + + with pytest.raises(RuntimeError): + tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + tx.commit() + + with s.transaction() as tx: + rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + pool.retry_operation_sync(check_transaction) diff --git a/ydb/aio/table.py b/ydb/aio/table.py index 9df797ea..95e2723d 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -120,8 +120,14 @@ async def alter_table( set_read_replicas_settings, ) - def transaction(self, tx_mode=None): - return TxContext(self._driver, self._state, self, tx_mode) + def transaction(self, tx_mode=None, *, deny_split_transactions=True): + return TxContext( + self._driver, + self._state, + self, + tx_mode, + deny_split_transactions=deny_split_transactions, + ) async def describe_table(self, path, settings=None): # pylint: disable=W0236 return await super().describe_table(path, settings) @@ -184,6 +190,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def execute( self, query, parameters=None, commit_tx=False, settings=None ): # pylint: disable=W0236 + + self._check_split() + return await super().execute(query, parameters, commit_tx, settings) async def commit(self, settings=None): # pylint: disable=W0236 diff --git a/ydb/table.py b/ydb/table.py index d60f138a..eaee78ec 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -1173,7 +1173,7 @@ def execute_scheme(self, yql_text, settings=None): pass @abstractmethod - def transaction(self, tx_mode=None): + def transaction(self, tx_mode=None, deny_split_transactions=True): pass @abstractmethod @@ -1677,8 +1677,14 @@ def execute_scheme(self, yql_text, settings=None): self._state.endpoint, ) - def transaction(self, tx_mode=None): - return TxContext(self._driver, self._state, self, tx_mode) + def transaction(self, tx_mode=None, deny_split_transactions=True): + return TxContext( + self._driver, + self._state, + self, + tx_mode, + deny_split_transactions=deny_split_transactions, + ) def has_prepared(self, query): return query in self._state @@ -2189,9 +2195,27 @@ def begin(self, settings=None): class BaseTxContext(ITxContext): - __slots__ = ("_tx_state", "_session_state", "_driver", "session") + __slots__ = ( + "_tx_state", + "_session_state", + "_driver", + "session", + "_finished", + "_deny_split_transactions", + ) - def __init__(self, driver, session_state, session, tx_mode=None): + _COMMIT = "commit" + _ROLLBACK = "rollback" + + def __init__( + self, + driver, + session_state, + session, + tx_mode=None, + *, + deny_split_transactions=True + ): """ An object that provides a simple transaction context manager that allows statements execution in a transaction. You don't have to open transaction explicitly, because context manager encapsulates @@ -2214,6 +2238,8 @@ def __init__(self, driver, session_state, session, tx_mode=None): self._tx_state = _tx_ctx_impl.TxState(tx_mode) self._session_state = session_state self.session = session + self._finished = "" + self._deny_split_transactions = deny_split_transactions def __enter__(self): """ @@ -2271,6 +2297,9 @@ def execute(self, query, parameters=None, commit_tx=False, settings=None): :return: A result sets or exception in case of execution errors """ + + self._check_split() + return self._driver( _tx_ctx_impl.execute_request_factory( self._session_state, @@ -2297,8 +2326,12 @@ def commit(self, settings=None): :return: A committed transaction or exception if commit is failed """ + + self._set_finish(self._COMMIT) + if self._tx_state.tx_id is None and not self._tx_state.dead: return self + return self._driver( _tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2318,8 +2351,12 @@ def rollback(self, settings=None): :return: A rolled back transaction or exception if rollback is failed """ + + self._set_finish(self._ROLLBACK) + if self._tx_state.tx_id is None and not self._tx_state.dead: return self + return self._driver( _tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2340,6 +2377,9 @@ def begin(self, settings=None): """ if self._tx_state.tx_id is not None: return self + + self._check_split() + return self._driver( _tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2350,6 +2390,21 @@ def begin(self, settings=None): self._session_state.endpoint, ) + def _set_finish(self, val): + self._check_split(val) + self._finished = val + + def _check_split(self, allow=""): + """ + Deny all operaions with transaction after commit/rollback. + Exception: double commit and double rollbacks, because it is safe + """ + if not self._deny_split_transactions: + return + + if self._finished != "" and self._finished != allow: + raise RuntimeError("Any operation with finished transaction is denied") + class TxContext(BaseTxContext): @_utilities.wrap_async_call_exceptions @@ -2365,6 +2420,9 @@ def async_execute(self, query, parameters=None, commit_tx=False, settings=None): :return: A future of query execution """ + + self._check_split() + return self._driver.future( _tx_ctx_impl.execute_request_factory( self._session_state, @@ -2396,8 +2454,11 @@ def async_commit(self, settings=None): :return: A future of commit call """ + self._set_finish(self._COMMIT) + if self._tx_state.tx_id is None and not self._tx_state.dead: return _utilities.wrap_result_in_future(self) + return self._driver.future( _tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2418,8 +2479,11 @@ def async_rollback(self, settings=None): :return: A future of rollback call """ + self._set_finish(self._ROLLBACK) + if self._tx_state.tx_id is None and not self._tx_state.dead: return _utilities.wrap_result_in_future(self) + return self._driver.future( _tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2441,6 +2505,9 @@ def async_begin(self, settings=None): """ if self._tx_state.tx_id is not None: return _utilities.wrap_result_in_future(self) + + self._check_split() + return self._driver.future( _tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub,