diff --git a/tests/aio/test_tx.py b/tests/aio/test_tx.py index be5c6806..da66769e 100644 --- a/tests/aio/test_tx.py +++ b/tests/aio/test_tx.py @@ -156,3 +156,46 @@ async def check_transaction(s: ydb.aio.table.Session): assert rs[0].rows[0].cnt == 1 await pool.retry_operation(check_transaction) + + +@pytest.mark.asyncio +async def test_truncated_response(driver, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + await driver.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = driver.table_client # default table client with driver's settings + s = table_client.session() + await s.create() + t = s.transaction() + with pytest.raises(ydb.TruncatedResponseError): + await t.execute("SELECT * FROM %s" % table_name) + + +@pytest.mark.asyncio +async def test_truncated_response_allow(driver, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + await driver.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = ydb.TableClient( + driver, ydb.TableClientSettings().with_allow_truncated_result(True) + ) + s = table_client.session() + await s.create() + t = s.transaction() + result = await t.execute("SELECT * FROM %s" % table_name) + assert result[0].truncated + assert len(result[0].rows) == 1000 diff --git a/tests/conftest.py b/tests/conftest.py index e7809847..675ef7b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -132,6 +132,11 @@ def create_table(s): return table_name +@pytest.fixture() +def table_path(database, table_name) -> str: + return database + "/" + table_name + + @pytest.fixture() def topic_consumer(): return "fixture-consumer" diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py index 095fb72f..bd703fa8 100644 --- a/tests/table/test_tx.py +++ b/tests/table/test_tx.py @@ -148,3 +148,46 @@ def check_transaction(s: ydb.table.Session): assert rs[0].rows[0].cnt == 1 pool.retry_operation_sync(check_transaction) + + +def test_truncated_response(driver_sync, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + driver_sync.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = ( + driver_sync.table_client + ) # default table client with driver's settings + s = table_client.session() + s.create() + t = s.transaction() + with pytest.raises(ydb.TruncatedResponseError): + t.execute("SELECT * FROM %s" % table_name) + + +def test_truncated_response_allow(driver_sync, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + driver_sync.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = ydb.TableClient( + driver_sync, ydb.TableClientSettings().with_allow_truncated_result(True) + ) + s = table_client.session() + s.create() + t = s.transaction() + result = t.execute("SELECT * FROM %s" % table_name) + assert result[0].truncated + assert len(result[0].rows) == 1000 diff --git a/ydb/convert.py b/ydb/convert.py index 70bc638e..567900a1 100644 --- a/ydb/convert.py +++ b/ydb/convert.py @@ -489,5 +489,13 @@ def __init__(self, result_sets_pb, table_client_settings=None): _ResultSet.from_message if not make_lazy else _ResultSet.lazy_from_message ) for result_set in result_sets_pb: - result_sets.append(initializer(result_set, table_client_settings)) + result_set = initializer(result_set, table_client_settings) + if ( + result_set.truncated + and not table_client_settings._allow_truncated_result + ): + raise issues.TruncatedResponseError( + "Response for the request was truncated by server" + ) + result_sets.append(result_set) super(ResultSets, self).__init__(result_sets) diff --git a/ydb/issues.py b/ydb/issues.py index 5a57f4d2..55c14cea 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -52,6 +52,10 @@ def __init__(self, message, issues=None): self.message = message +class TruncatedResponseError(Error): + status = None + + class ConnectionError(Error): status = None diff --git a/ydb/table.py b/ydb/table.py index eaee78ec..40431c62 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -1002,6 +1002,7 @@ def __init__(self): self._native_json_in_result_sets = False self._native_interval_in_result_sets = False self._native_timestamp_in_result_sets = False + self._allow_truncated_result = False def with_native_timestamp_in_result_sets(self, enabled): # type:(bool) -> ydb.TableClientSettings @@ -1038,6 +1039,11 @@ def with_lazy_result_sets(self, enabled): self._make_result_sets_lazy = enabled return self + def with_allow_truncated_result(self, enabled): + # type:(bool) -> ydb.TableClientSettings + self._allow_truncated_result = enabled + return self + class ScanQueryResult(object): def __init__(self, result, table_client_settings):