From 3cb1c3282a6582da3ac4d3734675270cb3eaf741 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 18 Jan 2023 19:54:53 +0300 Subject: [PATCH 001/147] start 3.0 beta branch --- CHANGELOG.md | 2 ++ setup.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 594eb330..f9103bc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* start 3.0 beta branch + ## 2.12.1 ## * Supported `TYPE_UNSPECIFIED` item type to scheme ls * Fixed error while request iam token with bad content type in metadata diff --git a/setup.py b/setup.py index ee69b6e7..06895627 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name="ydb", - version="2.12.1", # AUTOVERSION + version="3.0.0a0", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", From 7a2ece5f8dd4098c502c73958f7daa3b8cbdc5d2 Mon Sep 17 00:00:00 2001 From: robot Date: Thu, 19 Jan 2023 06:50:55 +0000 Subject: [PATCH 002/147] Release: --- CHANGELOG.md | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9103bc5..9c6f0e7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b1 ## * start 3.0 beta branch ## 2.12.1 ## diff --git a/setup.py b/setup.py index 06895627..992b0ee9 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name="ydb", - version="3.0.0a0", # AUTOVERSION + version="3.0.1b1", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", From f9c5d13aa468c7e1fb982396a211de31a1ad4ef5 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 19 Jan 2023 13:42:07 +0300 Subject: [PATCH 003/147] typos --- .github/scripts/increment_version_test.py | 19 ++++++++++--------- .github/workflows/python-publish.yml | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/scripts/increment_version_test.py b/.github/scripts/increment_version_test.py index a1c8bbf4..ee921a62 100644 --- a/.github/scripts/increment_version_test.py +++ b/.github/scripts/increment_version_test.py @@ -7,16 +7,17 @@ [ ("0.0.0", 'patch', False, "0.0.1"), ("0.0.1", 'patch', False, "0.0.2"), - ("0.0.1a1", 'patch', False, "0.0.1"), - ("0.0.0", 'patch', True, "0.0.1a1"), - ("0.0.1", 'patch', True, "0.0.2a1"), - ("0.0.2a1", 'patch', True, "0.0.2a2"), + ("0.0.1b1", 'patch', False, "0.0.1"), + ("0.0.0", 'patch', True, "0.0.1b1"), + ("0.0.1", 'patch', True, "0.0.2b1"), + ("0.0.2b1", 'patch', True, "0.0.2b2"), ("0.0.1", 'minor', False, "0.1.0"), - ("0.0.1a1", 'minor', False, "0.1.0"), - ("0.1.0a1", 'minor', False, "0.1.0"), - ("0.1.0", 'minor', True, "0.2.0a1"), - ("0.1.0a1", 'minor', True, "0.1.0a2"), - ("0.1.1a1", 'minor', True, "0.2.0a1"), + ("0.0.1b1", 'minor', False, "0.1.0"), + ("0.1.0b1", 'minor', False, "0.1.0"), + ("0.1.0", 'minor', True, "0.2.0b1"), + ("0.1.0b1", 'minor', True, "0.1.0b2"), + ("0.1.1b1", 'minor', True, "0.2.0b1"), + ("3.0.0b1", 'patch', True, "3.0.0b2"), ] ) def test_increment_version(source, inc_type, with_beta, result): diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 38c30de3..f3762395 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -93,7 +93,7 @@ jobs: git config --global user.email "robot@umbrella"; git config --global user.name "robot"; - git commit -m "Release: $NEW_VERSION"; + git commit -m "Release: $TAG"; git tag "$TAG" git push && git push --tags From 
3854f04dc7ffd35acca5908061d6a90202f352cb Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 20 Jan 2023 11:11:18 +0300 Subject: [PATCH 004/147] initial topic writer --- .gitignore | 3 +- AUTHORS | 3 +- examples/topic/reader_async_example.py | 172 +++++++ examples/topic/reader_example.py | 173 +++++++ examples/topic/writer_async_example.py | 116 +++++ examples/topic/writer_example.py | 116 +++++ requirements.txt | 7 +- test-requirements.txt | 8 +- tox.ini | 3 + ydb/__init__.py | 1 + ydb/_apis.py | 13 + ydb/_topic_reader/__init__.py | 1 + ydb/_topic_reader/topic_reader.py | 353 +++++++++++++ ydb/_topic_wrapper/__init__.py | 0 ydb/_topic_wrapper/common.py | 217 ++++++++ ydb/_topic_wrapper/control_plane.py | 14 + ydb/_topic_wrapper/reader.py | 117 +++++ ydb/_topic_wrapper/writer.py | 254 +++++++++ ydb/_topic_writer/__init__.py | 1 + ydb/_topic_writer/topic_writer.py | 294 +++++++++++ ydb/_topic_writer/topic_writer_asyncio.py | 425 +++++++++++++++ .../topic_writer_asyncio_test.py | 487 ++++++++++++++++++ ydb/_topic_writer/topic_writer_sync.py | 116 +++++ ydb/aio/connection.py | 8 + ydb/aio/driver.py | 9 +- ydb/driver.py | 6 + ydb/pool.py | 3 +- ydb/topic.py | 112 ++++ 28 files changed, 3020 insertions(+), 12 deletions(-) create mode 100644 examples/topic/reader_async_example.py create mode 100644 examples/topic/reader_example.py create mode 100644 examples/topic/writer_async_example.py create mode 100644 examples/topic/writer_example.py create mode 100644 ydb/_topic_reader/__init__.py create mode 100644 ydb/_topic_reader/topic_reader.py create mode 100644 ydb/_topic_wrapper/__init__.py create mode 100644 ydb/_topic_wrapper/common.py create mode 100644 ydb/_topic_wrapper/control_plane.py create mode 100644 ydb/_topic_wrapper/reader.py create mode 100644 ydb/_topic_wrapper/writer.py create mode 100644 ydb/_topic_writer/__init__.py create mode 100644 ydb/_topic_writer/topic_writer.py create mode 100644 ydb/_topic_writer/topic_writer_asyncio.py create mode 100644 ydb/_topic_writer/topic_writer_asyncio_test.py create mode 100644 ydb/_topic_writer/topic_writer_sync.py create mode 100644 ydb/topic.py diff --git a/.gitignore b/.gitignore index 12e29eac..45896947 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ ydb.egg-info/ /.idea /tox /venv -/ydb_certs \ No newline at end of file +/ydb_certs +/tmp diff --git a/AUTHORS b/AUTHORS index 200343e3..69fee17e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,4 +1,5 @@ -The following authors have created the source code of "YDB Python SDK" +The following authors have created the source code of "Yandex Database Python SDK" published and distributed by YANDEX LLC as the owner: Vitalii Gridnev +Timofey Koolin diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py new file mode 100644 index 00000000..8a5b84f6 --- /dev/null +++ b/examples/topic/reader_async_example.py @@ -0,0 +1,172 @@ +import asyncio +import json +import time + +import ydb + + +async def connect(): + db = ydb.aio.Driver(connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials()) + reader = ydb.TopicClientAsyncIO(db).topic_reader("/local/topic", consumer="consumer") + + +async def create_reader_and_close_with_context_manager(db: ydb.aio.Driver): + with ydb.TopicClientAsyncIO(db).topic_reader("/database/topic/path", consumer="consumer") as reader: + async for message in reader.messages(): + pass + + +async def print_message_content(reader: ydb.TopicReaderAsyncIO): + async for message in reader.messages(): 
+ print("text", message.data.read().decode("utf-8")) + # await and async_commit need only for sync commit mode - for wait ack from servr + await reader.commit(message) + + +async def process_messages_batch_explicit_commit(reader: ydb.TopicReaderAsyncIO): + # Explicit commit example + async for batch in reader.batches(max_messages=100, timeout=2): + async with asyncio.TaskGroup() as tg: + for message in batch.messages: + tg.create_task(_process(message)) + + # wait complete of process all messages from batch be taskgroup context manager + # and commit complete batch + await reader.commit(batch) + + +async def process_messages_batch_context_manager_commit(reader: ydb.TopicReaderAsyncIO): + # Commit with context manager + async for batch in reader.batches(): + async with reader.commit_on_exit(batch), asyncio.TaskGroup() as tg: + for message in batch.messages: + tg.create_task(_process(message)) + + +async def get_message_with_timeout(reader: ydb.TopicReaderAsyncIO): + try: + message = await asyncio.wait_for(reader.receive_message(), timeout=1) + except TimeoutError: + print("Have no new messages in a second") + return + + print("mess", message.data) + + +async def get_all_messages_with_small_wait(reader: ydb.TopicReaderAsyncIO): + async for message in reader.messages(timeout=1): + await _process(message) + print("Have no new messages in a second") + + +async def get_a_message_from_external_loop(reader: ydb.TopicReaderAsyncIO): + for i in range(10): + try: + message = await asyncio.wait_for(reader.receive_message(), timeout=1) + except TimeoutError: + return + await _process(message) + + +async def get_one_batch_from_external_loop_async(reader: ydb.TopicReaderAsyncIO): + for i in range(10): + try: + batch = await asyncio.wait_for(reader.receive_batch(), timeout=2) + except TimeoutError: + return + + for message in batch.messages: + await _process(message) + await reader.commit(batch) + + +async def auto_deserialize_message(db: ydb.aio.Driver): + # async, batch work similar to this + + async with ydb.TopicClientAsyncIO(db).topic_reader("/database/topic/path", consumer="asd", deserializer=json.loads) as reader: + async for message in reader.messages(): + print(message.data.Name) # message.data replaces by json.loads(message.data) of raw message + reader.commit(message) + + +async def commit_batch_with_context(reader: ydb.TopicReaderAsyncIO): + async for batch in reader.batches(): + async with reader.commit_on_exit(batch): + for message in batch.messages: + if not batch.is_alive: + break + await _process(message) + + +async def handle_partition_stop(reader: ydb.TopicReaderAsyncIO): + async for message in reader.messages(): + time.sleep(1) # some work + if message.is_alive: + time.sleep(123) # some other work + await reader.commit(message) + + +async def handle_partition_stop_batch(reader: ydb.TopicReaderAsyncIO): + def process_batch(batch): + for message in batch.messages: + if not batch.is_alive: + # no reason work with expired batch + # go read next - good batch + return + await _process(message) + await reader.commit(batch) + + async for batch in reader.batches(): + process_batch(batch) + + +async def connect_and_read_few_topics(db: ydb.aio.Driver): + with ydb.TopicClientAsyncIO(db).topic_reader( + ["/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3)]) as reader: + async for message in reader.messages(): + await _process(message) + await reader.commit(message) + + +async def handle_partition_graceful_stop_batch(reader: ydb.TopicReaderAsyncIO): + # no special 
+
+
+async def advanced_commit_notify(db: ydb.aio.Driver):
+    def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None:
+        print(event.topic)
+        print(event.offset)
+
+    async with ydb.TopicClientAsyncIO(db).topic_reader("/local",
+                                                       consumer="consumer",
+                                                       commit_batch_time=4,
+                                                       on_commit=on_commit) as reader:
+        async for message in reader.messages():
+            await _process(message)
+            await reader.commit(message)
+
+
+async def advanced_read_with_own_progress_storage(db: ydb.aio.Driver):
+    async def on_get_partition_start_offset(req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest) -> \
+            ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse:
+        # read current progress from database
+        resp = ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse()
+        resp.start_offset = 123
+        return resp
+
+    async with ydb.TopicClientAsyncIO(db).topic_reader("/local/test", consumer="consumer",
+                                                       on_get_partition_start_offset=on_get_partition_start_offset
+                                                       ) as reader:
+        async for mess in reader.messages():
+            await _process(mess)
+            # save progress to own database
+
+            # no commit of progress to the topic service
+            # reader.commit(mess)
+
+
+async def _process(msg):
+    raise NotImplementedError()
diff --git a/examples/topic/reader_example.py b/examples/topic/reader_example.py
new file mode 100644
index 00000000..0bb7bb8f
--- /dev/null
+++ b/examples/topic/reader_example.py
@@ -0,0 +1,173 @@
+import json
+import time
+
+import ydb
+
+
+def connect():
+    db = ydb.Driver(connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials())
+    reader = ydb.TopicClient(db).topic_reader("/local/topic", consumer="consumer")
+
+
+def create_reader_and_close_with_context_manager(db: ydb.Driver):
+    with ydb.TopicClient(db).topic_reader("/database/topic/path", consumer="consumer", buffer_size_bytes=123) as reader:
+        for message in reader:
+            pass
+
+
+def print_message_content(reader: ydb.TopicReader):
+    for message in reader.messages():
+        print("text", message.data.read().decode("utf-8"))
+        reader.commit(message)
+
+
+def process_messages_batch_explicit_commit(reader: ydb.TopicReader):
+    for batch in reader.batches(max_messages=100, timeout=2):
+        for message in batch.messages:
+            _process(message)
+        reader.commit(batch)
+
+
+def process_messages_batch_context_manager_commit(reader: ydb.TopicReader):
+    for batch in reader.batches(max_messages=100, timeout=2):
+        with reader.commit_on_exit(batch):
+            for message in batch.messages:
+                _process(message)
+
+
+def get_message_with_timeout(reader: ydb.TopicReader):
+    try:
+        message = reader.receive_message(timeout=1)
+    except TimeoutError:
+        print("Have no new messages in a second")
+        return
+
+    print("mess", message.data)
+
+
+def get_all_messages_with_small_wait(reader: ydb.TopicReader):
+    for message in reader.messages(timeout=1):
+        _process(message)
+    print("Have no new messages in a second")
+
+
+def get_a_message_from_external_loop(reader: ydb.TopicReader):
+    for i in range(10):
+        try:
+            message = reader.receive_message(timeout=1)
+        except TimeoutError:
+            return
+        _process(message)
+
+
+def get_one_batch_from_external_loop(reader: ydb.TopicReader):
+    for i in range(10):
+        try:
+            batch = reader.receive_batch(timeout=2)
+        except TimeoutError:
+            return
+
+        for message in batch.messages:
+            _process(message)
+        reader.commit(batch)
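+
+
+# A hypothetical helper (not part of the SDK): the receive_message docs later in
+# this patch state that timeout <= 0 is a non-blocking call which reads from the
+# internal buffer only, so polling can be done like this.
+def read_buffered_messages_only(reader: ydb.TopicReader):
+    while True:
+        try:
+            message = reader.receive_message(timeout=0)
+        except TimeoutError:
+            return  # internal buffer is empty
+        _process(message)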
+
+
+def auto_deserialize_message(db: ydb.Driver):
+    # async and batch reads work similar to this
+
+    reader = ydb.TopicClient(db).topic_reader("/database/topic/path", consumer="asd", deserializer=json.loads)
+    for message in reader.messages():
+        print(message.data.Name)  # message.data is replaced by json.loads(message.data) of the raw message
+        reader.commit(message)
+
+
+def commit_batch_with_context(reader: ydb.TopicReader):
+    for batch in reader.batches():
+        with reader.commit_on_exit(batch):
+            for message in batch.messages:
+                if not batch.is_alive:
+                    break
+                _process(message)
+
+
+def handle_partition_stop(reader: ydb.TopicReader):
+    for message in reader.messages():
+        time.sleep(1)  # some work
+        if message.is_alive:
+            time.sleep(123)  # some other work
+        reader.commit(message)
+
+
+def handle_partition_stop_batch(reader: ydb.TopicReader):
+    def process_batch(batch):
+        for message in batch.messages:
+            if not batch.is_alive:
+                # no reason to work with an expired batch
+                # go read next - good batch
+                return
+            _process(message)
+        reader.commit(batch)
+
+    for batch in reader.batches():
+        process_batch(batch)
+
+
+def connect_and_read_few_topics(db: ydb.Driver):
+    with ydb.TopicClient(db).topic_reader(["/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3)]) as reader:
+        for message in reader:
+            _process(message)
+            reader.commit(message)
+
+
+def handle_partition_graceful_stop_batch(reader: ydb.TopicReader):
+    # no special handling is needed, but the batch may contain fewer messages than the preferred count
+    for batch in reader.batches():
+        _process(batch)
+        reader.commit(batch)
+
+
+def advanced_commit_notify(db: ydb.Driver):
+    def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None:
+        print(event.topic)
+        print(event.offset)
+
+    with ydb.TopicClient(db).topic_reader("/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit) as reader:
+        for message in reader:
+            with reader.commit_on_exit(message):
+                _process(message)
+
+
+def advanced_read_with_own_progress_storage(db: ydb.Driver):
+    def on_get_partition_start_offset(req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest) -> \
+            ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse:
+
+        # read current progress from database
+        resp = ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse()
+        resp.start_offset = 123
+        return resp
+
+    with ydb.TopicClient(db).topic_reader("/local/test", consumer="consumer",
+                                          on_get_partition_start_offset=on_get_partition_start_offset
+                                          ) as reader:
+        for mess in reader:
+            _process(mess)
+            # save progress to own database
+
+            # no commit of progress to the topic service
+            # reader.commit(mess)
+
+
+def get_current_statistics(reader: ydb.TopicReader):
+    # sync
+    stat = reader.sessions_stat()
+    print(stat)
+
+    # with future
+    f = reader.async_sessions_stat()
+    stat = f.result()
+    print(stat)
+
+
+def _process(msg):
+    raise NotImplementedError()
diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py
new file mode 100644
index 00000000..4b26c702
--- /dev/null
+++ b/examples/topic/writer_async_example.py
@@ -0,0 +1,116 @@
+import asyncio
+import json
+import time
+from typing import Dict, List
+
+import ydb
+
+
+async def create_writer(db: ydb.aio.Driver):
+    async with ydb.TopicClientAsyncIO(db).topic_writer("/database/topic/path",
+                                                       producer_and_message_group_id="producer-id",
+                                                       ) as writer:
+        pass
+
+
+async def connect_and_wait(db: ydb.aio.Driver):
+    async with ydb.TopicClientAsyncIO(db).topic_writer("/database/topic/path",
+                                                       producer_and_message_group_id="producer-id",
+                                                       ) as writer:
+        await writer.wait_init()
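+
+
+# A sketch (assumption: topic_writer() forwards the partition_id option declared
+# in PublicWriterSettings later in this patch; messages then go to one explicit
+# partition instead of being routed by message group id).
+async def connect_to_explicit_partition(db: ydb.aio.Driver):
+    async with ydb.TopicClientAsyncIO(db).topic_writer(
+        "/database/topic/path",
+        producer_and_message_group_id="producer-id",
+        partition_id=1,
+    ) as writer:
+        await writer.write("mess")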
+
+
+async def connect_without_context_manager(db: ydb.aio.Driver):
+    writer = ydb.TopicClientAsyncIO(db).topic_writer("/database/topic/path",
+                                                     producer_and_message_group_id="producer-id",
+                                                     )
+    try:
+        pass  # some code
+    finally:
+        await writer.close()
+
+
+async def send_messages(writer: ydb.TopicWriterAsyncIO):
+    # simple str/bytes without additional metadata
+    await writer.write("mess")  # send text
+    await writer.write(bytes([1, 2, 3]))  # send bytes
+    await writer.write("mess-1", "mess-2")  # send two messages
+
+    # full forms
+    await writer.write(ydb.TopicWriterMessage("mess"))  # send text
+    await writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3])))  # send bytes
+    await writer.write(ydb.TopicWriterMessage("mess-1"),
+                       ydb.TopicWriterMessage("mess-2"))  # send few messages by one call
+
+    # with meta
+    await writer.write(ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns()))
+
+
+async def send_message_without_block_if_internal_buffer_is_full(writer: ydb.TopicWriterAsyncIO, msg) -> bool:
+    try:
+        # put the message into the internal queue for sending, but if the buffer is full - return fast
+        # without waiting
+        await asyncio.wait_for(writer.write(msg), 0)
+        return True
+    except TimeoutError:
+        return False
+
+
+async def send_messages_with_manual_seqno(writer: ydb.TopicWriterAsyncIO):
+    await writer.write(ydb.TopicWriterMessage("mess", seqno=1))  # send text
+
+
+async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO):
+    # future wait
+    await writer.write_with_ack(ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2))
+
+    # send with flush
+    await writer.write("1", "2", "3")
+    await writer.flush()
+
+
+async def send_json_message(db: ydb.aio.Driver):
+    async with ydb.TopicClientAsyncIO(db).topic_writer("/database/path/topic", serializer=json.dumps) as writer:
+        await writer.write({"a": 123})
+
+
+async def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriterAsyncIO):
+    for i in range(10):
+        await writer.write(ydb.TopicWriterMessage("%s" % i))
+    await writer.flush()
+
+
+async def send_messages_and_wait_all_commit_with_results(writer: ydb.TopicWriterAsyncIO):
+    last_future = None
+    for i in range(10):
+        content = "%s" % i
+        last_future = await writer.write_with_ack(content)
+
+    await asyncio.wait([last_future])
+    if last_future.exception() is not None:
+        raise last_future.exception()
+
+
+async def switch_messages_with_many_producers(writers: Dict[str, ydb.TopicWriterAsyncIO], messages: List[str]):
+    futures = []  # type: List[asyncio.Future]
+
+    for msg in messages:
+        # select writer for the msg
+        writer_idx = msg[:1]
+        writer = writers[writer_idx]
+        future = await writer.write_with_ack(msg)
+        futures.append(future)
+
+    # wait acks from all writes
+    await asyncio.wait(futures)
+    for future in futures:
+        if future.exception() is not None:
+            raise future.exception()
+
+    # all ok, explicit return - for clarity
+    return
+
+
+async def get_current_statistics(reader: ydb.TopicReaderAsyncIO):
+    stat = await reader.sessions_stat()
+    print(stat)
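+
+
+# A sketch of resuming a sequence after restart (assumption: wait_init() resolves
+# to the PublicWriterInitInfo declared later in this patch, whose last_seqno is
+# filled when the writer is created with get_last_seqno=True).
+async def continue_seqno_after_restart(db: ydb.aio.Driver):
+    async with ydb.TopicClientAsyncIO(db).topic_writer(
+        "/database/topic/path",
+        producer_and_message_group_id="producer-id",
+        get_last_seqno=True,
+    ) as writer:
+        info = await writer.wait_init()
+        next_seqno = (info.last_seqno or 0) + 1
+        await writer.write(ydb.TopicWriterMessage("mess", seqno=next_seqno))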
diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py
new file mode 100644
index 00000000..99966791
--- /dev/null
+++ b/examples/topic/writer_example.py
@@ -0,0 +1,116 @@
+import concurrent.futures
+import json
+import time
+from typing import Dict, List
+from concurrent.futures import Future, wait
+
+import ydb
+
+
+def connect():
+    db = ydb.Driver(connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials())
+    writer = ydb.TopicClient(db).topic_writer("/local/topic", producer_and_message_group_id="producer-id", )
+
+
+def create_writer(db: ydb.Driver):
+    with ydb.TopicClient(db).topic_writer("/database/topic/path",
+                                          producer_and_message_group_id="producer-id",
+                                          ) as writer:
+        pass
+
+
+def connect_and_wait(db: ydb.Driver):
+    with ydb.TopicClient(db).topic_writer("/database/topic/path",
+                                          producer_and_message_group_id="producer-id",
+                                          ) as writer:
+        writer.wait_init()
+
+
+def connect_without_context_manager(db: ydb.Driver):
+    writer = ydb.TopicClient(db).topic_writer("/database/topic/path",
+                                              producer_and_message_group_id="producer-id",
+                                              )
+    try:
+        pass  # some code
+    finally:
+        writer.close()
+
+
+def send_messages(writer: ydb.TopicWriter):
+    # simple str/bytes without additional metadata
+    writer.write("mess")  # send text
+    writer.write(bytes([1, 2, 3]))  # send bytes
+    writer.write("mess-1", "mess-2")  # send two messages
+
+    # full forms
+    writer.write(ydb.TopicWriterMessage("mess"))  # send text
+    writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3])))  # send bytes
+    writer.write(ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2"))  # send few messages by one call
+
+    # with meta
+    writer.write(ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns()))
+
+
+def send_message_without_block_if_internal_buffer_is_full(writer: ydb.TopicWriter, msg) -> bool:
+    try:
+        # put the message into the internal queue for sending, but if the buffer is full - return fast
+        # without waiting
+        writer.write(msg, timeout=0)
+        return True
+    except TimeoutError:
+        return False
+
+
+def send_messages_with_manual_seqno(writer: ydb.TopicWriter):
+    writer.write(ydb.TopicWriterMessage("mess", seqno=1))  # send text
+
+
+def send_messages_with_wait_ack(writer: ydb.TopicWriter):
+    # Explicit future wait
+    writer.async_write_with_ack(ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2)).result()
+
+    # implicit, by sync call
+    writer.write_with_ack(ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2))
+
+    # send with flush
+    writer.write("1", "2", "3")
+    writer.flush()
+
+
+def send_json_message(db: ydb.Driver):
+    with ydb.TopicClient(db).topic_writer("/database/path/topic", serializer=json.dumps) as writer:
+        writer.write({"a": 123})
+
+
+def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriter):
+    for i in range(10):
+        content = "%s" % i
+        writer.write(content)
+    writer.flush()
+
+
+def send_messages_and_wait_all_commit_with_results(writer: ydb.TopicWriter):
+    futures = []  # type: List[concurrent.futures.Future]
+    for i in range(10):
+        future = writer.async_write_with_ack("%s" % i)
+        futures.append(future)
+
+    concurrent.futures.wait(futures)
+    for future in futures:
+        if future.exception() is not None:
+            raise future.exception()
+
+
+def switch_messages_with_many_producers(writers: Dict[str, ydb.TopicWriter], messages: List[str]):
+    futures = []  # type: List[Future]
+
+    for msg in messages:
+        # select writer for the msg
+        writer_idx = msg[:1]
+        writer = writers[writer_idx]
+        future = writer.async_write_with_ack(msg)
+        futures.append(future)
+
+    # wait acks from all writes
+    wait(futures)
diff --git a/requirements.txt b/requirements.txt
index dddc1b23..57470a28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
-grpcio==1.39.0
+grpcio>=1.42.0
 packaging
-protobuf>3.13.0,<5.0.0
-pytest==6.2.4
-aiohttp==3.7.4
+protobuf>=3.13.0,<5.0.0
+aiohttp>=3.7.4,<4.0.0
diff --git a/test-requirements.txt b/test-requirements.txt
index 
627ad2f5..eb4cab95 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -10,8 +10,8 @@ docker==5.0.0 docker-compose==1.29.2 dockerpty==0.4.1 docopt==0.6.2 -grpcio -grpcio-tools +grpcio==1.42.0 +grpcio-tools==1.42.0 idna==3.2 importlib-metadata==4.6.1 iniconfig==1.1.1 @@ -19,7 +19,7 @@ jsonschema==3.2.0 packaging==21.0 paramiko==2.10.1 pluggy==0.13.1 -protobuf>3.13.0,<5.0.0 +protobuf==3.13.0 py==1.10.0 pycparser==2.20 PyNaCl==1.4.0 @@ -45,4 +45,4 @@ flake8==3.9.2 sqlalchemy==1.4.26 pylint-protobuf cython -grpcio-tools +freezegun==1.2.2 diff --git a/tox.ini b/tox.ini index 9ab926b0..1246209c 100644 --- a/tox.ini +++ b/tox.ini @@ -72,3 +72,6 @@ builtins = _ max-line-length = 160 ignore=E203,W503 exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,ydb/public/api/protos/*,docs/*,ydb/public/api/grpc/*,persqueue/*,client/*,dbapi/*,ydb/default_pem.py,*docs/conf.py + +[pytest] +asyncio_mode = auto diff --git a/ydb/__init__.py b/ydb/__init__.py index d8c23fee..7af0087b 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -12,6 +12,7 @@ from .scripting import * # noqa from .import_client import * # noqa from .tracing import * # noqa +from .topic import * # noqa try: import ydb.aio as aio # noqa diff --git a/ydb/_apis.py b/ydb/_apis.py index 8ee2d731..2820f482 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -16,6 +16,12 @@ from ydb._grpc.common.protos import ydb_common_pb2 from ydb._grpc.common import ydb_operation_v1_pb2_grpc +# Workaround for good IDE and universal runtime +if False: + from ydb._grpc.v4 import ydb_topic_v1_pb2_grpc +else: + from ydb._grpc.common import ydb_topic_v1_pb2_grpc + StatusIds = ydb_status_codes_pb2.StatusIds FeatureFlag = ydb_common_pb2.FeatureFlag @@ -74,3 +80,10 @@ class TableService(object): KeepAlive = "KeepAlive" StreamReadTable = "StreamReadTable" BulkUpsert = "BulkUpsert" + + +class TopicService(object): + Stub = ydb_topic_v1_pb2_grpc.TopicServiceStub + + StreamRead = "StreamRead" + StreamWrite = "StreamWrite" diff --git a/ydb/_topic_reader/__init__.py b/ydb/_topic_reader/__init__.py new file mode 100644 index 00000000..3aab85c2 --- /dev/null +++ b/ydb/_topic_reader/__init__.py @@ -0,0 +1 @@ +from .topic_reader import * diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py new file mode 100644 index 00000000..9c40f5c3 --- /dev/null +++ b/ydb/_topic_reader/topic_reader.py @@ -0,0 +1,353 @@ +import abc +import concurrent.futures +import enum +import io +import datetime +from typing import Union, Optional, List, Mapping, Callable, Iterable, AsyncIterable, AsyncContextManager, \ + Any + + +class Selector: + path: str + partitions: Union[None, int, List[int]] + read_from_timestamp_ms: Optional[int] + max_time_lag_ms: Optional[int] + + def __init__(self, path, *, partitions: Union[None, int, List[int]] = None): + self.path = path + self.partitions = partitions + + +class ReaderAsyncIO(object): + async def __aenter__(self): + raise NotImplementedError() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError() + + async def sessions_stat(self) -> List["SessionStat"]: + """ + Receive stat from the server + + use asyncio.wait_for for wait with timeout. 
+ """ + raise NotImplementedError() + + def messages(self, *, timeout: Union[float, None] = None) -> AsyncIterable["Message"]: + """ + Block until receive new message + + if no new messages in timeout seconds: stop iteration by raise StopAsyncIteration + """ + raise NotImplementedError() + + async def receive_message(self) -> Union["Message", None]: + """ + Block until receive new message + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + def batches(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None) -> AsyncIterable["Batch"]: + """ + Block until receive new batch. + All messages in a batch from same partition. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + """ + raise NotImplementedError() + + async def receive_batch(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None]) -> Union["Batch", None]: + """ + Get one messages batch from reader. + All messages in a batch from same partition. + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + async def commit_on_exit(self, mess: "ICommittable") -> AsyncContextManager: + """ + commit the mess match/message if exit from context manager without exceptions + + reader will close if exit from context manager with exception + """ + raise NotImplementedError() + + def commit(self, mess: "ICommittable"): + """ + Write commit message to a buffer. + + For the method no way check the commit result + (for example if lost connection - commits will not re-send and committed messages will receive again) + """ + raise NotImplementedError() + + async def commit_with_ack(self, mess: "ICommittable") -> Union["CommitResult", List["CommitResult"]]: + """ + write commit message to a buffer and wait ack from the server. + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + async def flush(self): + """ + force send all commit messages from internal buffers to server and wait acks for all of them. + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + async def close(self): + raise NotImplementedError() + + +class Reader(object): + def async_sessions_stat(self) -> concurrent.futures.Future: + """ + Receive stat from the server, return feature. + """ + raise NotImplementedError() + + async def sessions_stat(self) -> List["SessionStat"]: + """ + Receive stat from the server + + use async_sessions_stat for set explicit wait timeout + """ + raise NotImplementedError() + + def messages(self, *, timeout: Union[float, None] = None) -> Iterable["Message"]: + """ + todo? + + Block until receive new message + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def receive_message(self, *, timeout: Union[float, None] = None) -> "Message": + """ + Block until receive new message + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. 
+ """ + raise NotImplementedError() + + def async_wait_message(self) -> concurrent.futures.Future: + """ + Return future, which will completed when the reader has least one message in queue. + If reader already has message - future will return completed. + + Possible situation when receive signal about message available, but no messages when try to receive a message. + If message expired between send event and try to retrieve message (for example connection broken). + """ + raise NotImplementedError() + + def batches(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None) -> Iterable["Batch"]: + """ + Block until receive new batch. + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def receive_batch(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None], + timeout: Union[float, None] = None) -> Union["Batch", None]: + """ + Get one messages batch from reader + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def commit(self, mess: "ICommittable"): + """ + Put commit message to internal buffer. + + For the method no way check the commit result + (for example if lost connection - commits will not re-send and committed messages will receive again) + """ + raise NotImplementedError() + + def commit_with_ack(self, mess: "ICommittable") -> Union["CommitResult", List["CommitResult"]]: + """ + write commit message to a buffer and wait ack from the server. + + if receive in timeout seconds (default - infinite): raise TimeoutError() + """ + raise NotImplementedError() + + def async_commit_with_ack(self, mess: "ICommittable") -> Union["CommitResult", List["CommitResult"]]: + """ + write commit message to a buffer and return Future for wait result. + """ + raise NotImplementedError() + + def async_flush(self) -> concurrent.futures.Future: + """ + force send all commit messages from internal buffers to server and return Future for wait server acks. + """ + raise NotImplementedError() + + def flush(self): + """ + force send all commit messages from internal buffers to server and wait acks for all of them. + """ + raise NotImplementedError() + + def close(self): + raise NotImplementedError() + + +class ReaderSettings: + def __init__(self, *, + consumer: str, + buffer_size_bytes: int = 50 * 1024 * 1024, + on_commit: Callable[["OnCommitEvent"], None] = None, + on_get_partition_start_offset: Callable[ + ["OnPartitionGetStartOffsetRequest"], "OnPartitionGetStartOffsetResponse"] = None, + on_partition_session_start: Callable[["StubEvent"], None] = None, + on_partition_session_stop: Callable[["StubEvent"], None] = None, + on_partition_session_close: Callable[["StubEvent"], None] = None, # todo? 
+ decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + deserializer: Union[Callable[[bytes], Any], None] = None, + one_attempt_connection_timeout: Union[float, None] = 1, + connection_timeout: Union[float, None] = None, + retry_policy: Union["RetryPolicy", None] = None, + ): + raise NotImplementedError() + + +class ICommittable(abc.ABC): + @property + @abc.abstractmethod + def start_offset(self) -> int: + pass + + @property + @abc.abstractmethod + def end_offset(self) -> int: + pass + + +class ISessionAlive(abc.ABC): + @property + @abc.abstractmethod + def is_alive(self) -> bool: + pass + + +class Message(ICommittable, ISessionAlive): + seqno: int + created_at_ns: int + message_group_id: str + session_metadata: Mapping[str, str] + offset: int + written_at_ns: int + producer_id: int + data: Union[bytes, Any] # set as original decompressed bytes or deserialized object if deserializer set in reader + + def __init__(self): + self.seqno = -1 + self.created_at_ns = -1 + self.data = io.BytesIO() + + @property + def start_offset(self) -> int: + raise NotImplementedError() + + @property + def end_offset(self) -> int: + raise NotImplementedError() + + # ISessionAlive implementation + @property + def is_alive(self) -> bool: + raise NotImplementedError() + + +class Batch(ICommittable, ISessionAlive): + session_metadata: Mapping[str, str] + messages: List[Message] + + def __init__(self): + pass + + @property + def start_offset(self) -> int: + raise NotImplementedError() + + @property + def end_offset(self) -> int: + raise NotImplementedError() + + # ISessionAlive implementation + @property + def is_alive(self) -> bool: + raise NotImplementedError() + + +class Events: + class OnCommit: + topic: str + offset: int + + class OnPartitionGetStartOffsetRequest: + topic: str + partition_id: int + + class OnPartitionGetStartOffsetResponse: + start_offset: int + + class OnInitPartition: + pass + + class OnShutdownPatition: + pass + + +class RetryPolicy: + connection_timeout_sec: float + overload_timeout_sec: float + retry_access_denied: bool = False + + +class CommitResult: + topic: str + partition: int + offset: int + state: "CommitResult.State" + details: str # for humans only, content messages may be change in any time + + class State(enum.Enum): + UNSENT = 1 # commit didn't send to the server + SENT = 2 # commit was sent to server, but ack hasn't received + ACKED = 3 # ack from server is received + + +class SessionStat: + path: str + partition_id: str + partition_offsets: "OffsetRange" + committed_offset: int + write_time_high_watermark: datetime.datetime + write_time_high_watermark_timestamp_nano: int + + +class OffsetRange: + start: int + end: int diff --git a/ydb/_topic_wrapper/__init__.py b/ydb/_topic_wrapper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py new file mode 100644 index 00000000..bc5beb0b --- /dev/null +++ b/ydb/_topic_wrapper/common.py @@ -0,0 +1,217 @@ +import abc +import asyncio +import queue +import typing +from dataclasses import dataclass +from enum import Enum + +from google.protobuf.message import Message + +import ydb.aio + +# Workaround for good autocomplete in IDE and universal import at runtime +if False: + from ydb._grpc.v4.protos import ydb_status_codes_pb2, ydb_issue_message_pb2, ydb_topic_pb2 +else: + # noinspection PyUnresolvedReferences + from ydb._grpc.common.protos import ydb_status_codes_pb2, ydb_issue_message_pb2, ydb_topic_pb2 + + +class Codec(Enum): + CODEC_UNSPECIFIED 
= 0
+    CODEC_RAW = 1
+    CODEC_GZIP = 2
+    CODEC_LZOP = 3
+    CODEC_ZSTD = 4
+
+
+@dataclass
+class OffsetsRange:
+    start: int
+    end: int
+
+
+class IToProto(abc.ABC):
+    @abc.abstractmethod
+    def to_proto(self) -> Message:
+        pass
+
+
+class UnknownGrpcMessageError(ydb.Error):
+    pass
+
+
+class IFromProto(abc.ABC):
+    @staticmethod
+    @abc.abstractmethod
+    def from_proto(msg: Message) -> typing.Any:
+        pass
+
+
+class QueueToIteratorAsyncIO:
+    __slots__ = ("_queue",)
+
+    def __init__(self, q: asyncio.Queue):
+        self._queue = q
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            return await self._queue.get()
+        except asyncio.QueueEmpty:
+            raise StopAsyncIteration()
+
+
+class AsyncQueueToSyncIteratorAsyncIO:
+    __slots__ = (
+        "_loop",
+        "_queue",
+    )
+    _queue: asyncio.Queue
+
+    def __init__(self, q: asyncio.Queue):
+        self._loop = asyncio.get_running_loop()
+        self._queue = q
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            res = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()
+            return res
+        except asyncio.QueueEmpty:
+            raise StopIteration()
+
+
+class SyncIteratorToAsyncIterator:
+    def __init__(self, sync_iterator: typing.Iterator):
+        self._sync_iterator = sync_iterator
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            res = await asyncio.to_thread(self._sync_iterator.__next__)
+            return res
+        except StopIteration:
+            raise StopAsyncIteration()
+
+
+class IteratorToQueueAsyncIO:
+    __slots__ = ("_iterator",)
+
+    def __init__(self, iterator: typing.AsyncIterator[typing.Any]):
+        self._iterator = iterator
+
+    async def get(self) -> typing.Any:
+        try:
+            return await self._iterator.__anext__()
+        except StopAsyncIteration:
+            raise asyncio.QueueEmpty()
+
+
+class IGrpcWrapperAsyncIO(abc.ABC):
+    @abc.abstractmethod
+    async def receive(self) -> typing.Any: ...
+
+    @abc.abstractmethod
+    def write(self, wrap_message: IToProto): ...
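+
+
+# Usage sketch for the adapters above (illustrative comment): the sync gRPC
+# driver consumes a plain iterator, so requests put into an asyncio.Queue from
+# the event loop are handed to it through AsyncQueueToSyncIteratorAsyncIO, while
+# the responses of the resulting sync stream are read back inside the loop
+# through SyncIteratorToAsyncIterator (see GrpcWrapperAsyncIO below).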
+
+
+SupportedDriverType = typing.Union[ydb.Driver, ydb.aio.Driver]
+
+
+class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
+    from_client_grpc: asyncio.Queue
+    from_server_grpc: typing.AsyncIterator
+    convert_server_grpc_to_wrapper: typing.Callable[[typing.Any], typing.Any]
+
+    def __init__(self, convert_server_grpc_to_wrapper):
+        self.from_client_grpc = asyncio.Queue()
+        self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
+
+    async def start(self, driver: SupportedDriverType, stub, method):
+        if asyncio.iscoroutinefunction(driver.__call__):
+            await self._start_asyncio_driver(driver, stub, method)
+        else:
+            await self._start_sync_driver(driver, stub, method)
+
+    async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
+        requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
+        stream_call = await driver(
+            requests_iterator,
+            stub,
+            method,
+        )
+        self.from_server_grpc = stream_call.__aiter__()
+
+    async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
+        requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc)
+        stream_call = await asyncio.to_thread(driver,
+                                              requests_iterator,
+                                              stub,
+                                              method,
+                                              )
+        self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__())
+
+    async def receive(self) -> typing.Any:
+        # todo: handle grpc exceptions and convert them to internal exceptions
+        grpc_item = await self.from_server_grpc.__anext__()
+        return self.convert_server_grpc_to_wrapper(grpc_item)
+
+    def write(self, wrap_message: IToProto):
+        self.from_client_grpc.put_nowait(wrap_message.to_proto())
+
+
+@dataclass(init=False)
+class ServerStatus(IFromProto):
+    __slots__ = ("status", "_issues")
+
+    def __init__(self,
+                 status: ydb_status_codes_pb2.StatusIds.StatusCode,
+                 issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage]):
+        self.status = status
+        self._issues = issues
+
+    def __str__(self):
+        return self.__repr__()
+
+    @staticmethod
+    def from_proto(msg: Message) -> "ServerStatus":
+        return ServerStatus(
+            msg.status,
+            msg.issues,
+        )
+
+    def is_success(self) -> bool:
+        return self.status == ydb_status_codes_pb2.StatusIds.SUCCESS
+
+    @classmethod
+    def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage):
+        res = """code: %s message: "%s" """ % (issue.issue_code, issue.message)
+        if len(issue.issues) > 0:
+            d = ", "
+            res += d + d.join(str(sub_issue) for sub_issue in issue.issues)
+        return res
+
+
+@dataclass
+class UpdateTokenRequest(IToProto):
+    token: str
+
+    def to_proto(self) -> Message:
+        res = ydb_topic_pb2.UpdateTokenRequest()
+        res.token = self.token
+        return res
+
+
+@dataclass
+class UpdateTokenResponse(IFromProto):
+    @staticmethod
+    def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any:
+        return UpdateTokenResponse()
diff --git a/ydb/_topic_wrapper/control_plane.py b/ydb/_topic_wrapper/control_plane.py
new file mode 100644
index 00000000..b8bbdff0
--- /dev/null
+++ b/ydb/_topic_wrapper/control_plane.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import Union, List
+
+
+@dataclass
+class CreateTopicRequest:
+    path: str
+    consumers: Union[List["Consumer"], None] = None
+
+
+@dataclass
+class Consumer:
+    name: str
diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py
new file mode 100644
index 00000000..51a21c49
--- /dev/null
+++ b/ydb/_topic_wrapper/reader.py
@@ -0,0 +1,117 @@
+import abc
+import datetime
+import typing
+from dataclasses import dataclass, field
+from typing import List, Union, Dict
+
+from google.protobuf.message import Message
+
+from ydb._topic_wrapper.common import OffsetsRange
+
+
+class StreamReadMessage:
+    @dataclass
+    class PartitionSession:
+        partition_session_id: int
+        path: str
+        partition_id: int
+
+    @dataclass
+    class InitRequest:
+        topics_read_settings: List["TopicReadSettings"]
+        consumer: str
+
+    @dataclass
+    class TopicReadSettings:
+        path: str
+        partition_ids: List[int] = field(default_factory=list)
+        max_lag_seconds: Union[float, None] = None
+        read_from: Union[int, float, datetime.datetime, None] = None
+
+    @dataclass
+    class InitResponse:
+        session_id: str
+
+    @dataclass
+    class ReadRequest:
+        bytes_size: int
+
+    @dataclass
+    class ReadResponse:
+        partition_data: List["PartitionData"]
+        bytes_size: int
+
+    @dataclass
+    class MessageData:
+        offset: int
+        seq_no: int
+        created_at: float  # unix timestamp
+        data: bytes
+        uncompressed_size: int
+        message_group_id: str
+
+    @dataclass
+    class Batch:
+        message_data: List["MessageData"]
+        producer_id: str
+        write_session_meta: Dict[str, str]
+        codec: int
+        written_at: float  # unix timestamp
+
+    @dataclass
+    class PartitionData:
+        partition_session_id: int
+        batches: List["Batch"]
+
+    @dataclass
+    class CommitOffsetRequest:
+        commit_offsets: List["PartitionCommitOffset"]
+
+    @dataclass
+    class PartitionCommitOffset:
+        partition_session_id: int
+        offsets: List[OffsetsRange]
+
+    @dataclass
+    class CommitOffsetResponse:
+        partitions_committed_offsets: List["PartitionCommittedOffset"]
+
+    @dataclass
+    class PartitionCommittedOffset:
+        partition_session_id: int
+        committed_offset: int
+
+    @dataclass
+    class PartitionSessionStatusRequest:
+        partition_session_id: int
+
+    @dataclass
+    class PartitionSessionStatusResponse:
+        partition_session_id: int
+        partition_offsets: OffsetsRange
+        committed_offset: int
+        write_time_high_watermark: float
+
+    @dataclass
+    class StartPartitionSessionRequest:
+        partition_session: "PartitionSession"
+        committed_offset: int
+        partition_offsets: OffsetsRange
+
+    @dataclass
+    class StartPartitionSessionResponse:
+        partition_session_id: int
+        read_offset: int
+        commit_offset: int
+
+    @dataclass
+    class StopPartitionSessionRequest:
+        partition_session_id: int
+        graceful: bool
+        committed_offset: int
+
+    @dataclass
+    class StopPartitionSessionResponse:
+        partition_session_id: int
diff --git a/ydb/_topic_wrapper/writer.py b/ydb/_topic_wrapper/writer.py
new file mode 100644
index 00000000..784dc711
--- /dev/null
+++ b/ydb/_topic_wrapper/writer.py
@@ -0,0 +1,254 @@
+import asyncio
+import datetime
+import enum
+import typing
+from dataclasses import dataclass, field
+from typing import Union
+
+from google.protobuf.message import Message
+
+from ydb._topic_wrapper.common import IToProto, IFromProto, ServerStatus, UpdateTokenRequest, UpdateTokenResponse, \
+    UnknownGrpcMessageError
+
+# Workaround for good autocomplete in IDE and universal import at runtime
+if False:
+    from ydb._grpc.v4.protos import ydb_topic_pb2
+else:
+    from ydb._grpc.common.protos import ydb_topic_pb2
+
+
+class StreamWriteMessage:
+    @dataclass()
+    class InitRequest(IToProto):
+        path: str
+        producer_id: str
+        write_session_meta: typing.Dict[str, str]
+        partitioning: "StreamWriteMessage.PartitioningType"
+        get_last_seq_no: bool
+
+        def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.InitRequest:
+            proto = ydb_topic_pb2.StreamWriteMessage.InitRequest()
+            proto.path = self.path
+            proto.producer_id = self.producer_id
+
+            if self.partitioning is None:
+                pass
+            elif isinstance(self.partitioning, 
StreamWriteMessage.PartitioningMessageGroupID): + proto.message_group_id = self.partitioning.message_group_id + elif isinstance(self.partitioning, StreamWriteMessage.PartitioningPartitionID): + proto.partition_id = self.partitioning.partition_id + else: + raise Exception("Bad partitioning type at StreamWriteMessage.InitRequest") + + if self.write_session_meta: + for key in self.write_session_meta: + proto.write_session_meta[key] = self.write_session_meta[key] + + proto.get_last_seq_no = self.get_last_seq_no + return proto + + @dataclass + class InitResponse(IFromProto): + last_seq_no: Union[int, None] + session_id: str + partition_id: int + supported_codecs: typing.List[int] + status: ServerStatus = None + + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.InitResponse) -> "StreamWriteMessage.InitResponse": + codecs = [] # type: typing.List[int] + if msg.supported_codecs: + for codec in msg.supported_codecs.codecs: + codecs.append(codec) + + return StreamWriteMessage.InitResponse( + last_seq_no=msg.last_seq_no, + session_id=msg.session_id, + partition_id=msg.partition_id, + supported_codecs=codecs + ) + + @dataclass + class WriteRequest(IToProto): + messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] + codec: int + + @dataclass + class MessageData(IToProto): + seq_no: int + created_at: datetime.datetime + data: bytes + uncompressed_size: int + partitioning: "StreamWriteMessage.PartitioningType" + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData() + proto.seq_no = self.seq_no + proto.created_at.FromDatetime(self.created_at) + proto.data = self.data + proto.uncompressed_size = self.uncompressed_size + + if self.partitioning is None: + pass + elif isinstance(self.partitioning, StreamWriteMessage.PartitioningPartitionID): + proto.partition_id = self.partitioning.partition_id + elif isinstance(self.partitioning, StreamWriteMessage.PartitioningMessageGroupID): + proto.message_group_id = self.partitioning.message_group_id + else: + raise Exception("Bad partition at StreamWriteMessage.WriteRequest.MessageData") + + return proto + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() + proto.codec = self.codec + + for message in self.messages: + proto_mess = proto.messages.add() + proto_mess.CopyFrom(message.to_proto()) + + return proto + + @dataclass + class WriteResponse(IFromProto): + partition_id: int + acks: typing.List["StreamWriteMessage.WriteResponse.WriteAck"] + write_statistics: "StreamWriteMessage.WriteResponse.WriteStatistics" + status: ServerStatus = field(default=None) + + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse) -> "StreamWriteMessage.WriteResponse": + acks = [] + for proto_ack in msg.acks: + ack = StreamWriteMessage.WriteResponse.WriteAck.from_proto(proto_ack) + acks.append(ack) + write_statistics = StreamWriteMessage.WriteResponse.WriteStatistics( + persisting_time=msg.write_statistics.persisting_time.ToTimedelta(), + min_queue_wait_time=msg.write_statistics.min_queue_wait_time.ToTimedelta(), + max_queue_wait_time=msg.write_statistics.max_queue_wait_time.ToTimedelta(), + partition_quota_wait_time=msg.write_statistics.partition_quota_wait_time.ToTimedelta(), + topic_quota_wait_time=msg.write_statistics.topic_quota_wait_time.ToTimedelta(), + ) + return StreamWriteMessage.WriteResponse( + partition_id=msg.partition_id, + acks=acks, 
+ write_statistics=write_statistics, + status=None, + ) + + @dataclass + class WriteAck(IFromProto): + seq_no: int + message_write_status: Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusWritten", + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped", + int + ] + + @classmethod + def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck): + if proto_ack.HasField("written"): + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + proto_ack.written.offset + ) + elif proto_ack.HasField("skipped"): + reason = proto_ack.skipped.reason + try: + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped( + reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code(reason) + ) + except ValueError: + message_write_status = reason + else: + raise NotImplementedError("unexpected ack status") + + return StreamWriteMessage.WriteResponse.WriteAck( + seq_no=proto_ack.seq_no, + message_write_status=message_write_status, + ) + + @dataclass + class StatusWritten: + offset: int + + @dataclass + class StatusSkipped: + reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" + + class Reason(enum.Enum): + UNSPECIFIED = 0 + ALREADY_WRITTEN = 1 + + @classmethod + def from_protobuf_code(cls, code: int) -> Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason", + int + ]: + try: + return StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason(code) + except ValueError: + return code + + @dataclass + class WriteStatistics: + persisting_time: datetime.timedelta + min_queue_wait_time: datetime.timedelta + max_queue_wait_time: datetime.timedelta + partition_quota_wait_time: datetime.timedelta + topic_quota_wait_time: datetime.timedelta + + @dataclass + class PartitioningMessageGroupID: + message_group_id: str + + @dataclass + class PartitioningPartitionID: + partition_id: int + + PartitioningType = Union[PartitioningMessageGroupID, PartitioningPartitionID, None] + + @dataclass + class FromClient(IToProto): + value: "WriterMessagesFromClientToServer" + + def __init__(self, value: "WriterMessagesFromClientToServer"): + self.value = value + + def to_proto(self) -> Message: + res = ydb_topic_pb2.StreamWriteMessage.FromClient() + value = self.value + if isinstance(value, StreamWriteMessage.WriteRequest): + res.write_request.CopyFrom(value.to_proto()) + elif isinstance(value, StreamWriteMessage.InitRequest): + res.init_request.CopyFrom(value.to_proto()) + elif isinstance(value, UpdateTokenRequest): + res.update_token_request.CopyFrom(value.to_proto()) + else: + raise Exception("Unknown outcoming grpc message: %s" % value) + return res + + class FromServer(IFromProto): + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.FromServer) -> typing.Any: + message_type = msg.WhichOneof("server_message") + if message_type == "write_response": + res = StreamWriteMessage.WriteResponse.from_proto(msg.write_response) + elif message_type == "init_response": + res = StreamWriteMessage.InitResponse.from_proto(msg.init_response) + elif message_type == "update_token_response": + res = UpdateTokenResponse.from_proto(msg.update_token_response) + else: + # todo log instead of exception - for allow add messages in the future + raise UnknownGrpcMessageError("Unexpected proto message: %s" % msg) + + res.status = ServerStatus(msg.status, msg.issues) + return res + + +WriterMessagesFromClientToServer = Union[ + StreamWriteMessage.InitRequest, 
StreamWriteMessage.WriteRequest, UpdateTokenRequest
+]
+WriterMessagesFromServerToClient = Union[
+    StreamWriteMessage.InitResponse, StreamWriteMessage.WriteResponse, UpdateTokenResponse
+]
diff --git a/ydb/_topic_writer/__init__.py b/ydb/_topic_writer/__init__.py
new file mode 100644
index 00000000..87216032
--- /dev/null
+++ b/ydb/_topic_writer/__init__.py
@@ -0,0 +1 @@
+from .topic_writer import *
diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py
new file mode 100644
index 00000000..0e7231c2
--- /dev/null
+++ b/ydb/_topic_writer/topic_writer.py
@@ -0,0 +1,294 @@
+import asyncio
+import concurrent.futures
+import datetime
+import enum
+import time
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Union, TextIO, BinaryIO, Optional, Callable, Mapping, Any, Dict
+
+import typing
+
+import ydb.aio
+from .._topic_wrapper.common import IToProto, Codec
+from .._topic_wrapper.writer import StreamWriteMessage
+
+
+class Writer:
+    @property
+    def last_seqno(self) -> int:
+        raise NotImplementedError()
+
+    def __init__(self, db: ydb.Driver):
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def close(self):
+        pass
+
+    MessageType = typing.Union["PublicMessage", "Message.SimpleMessageSourceType"]
+
+    def write(self, message: Union[MessageType, List[MessageType]], *args: Optional[MessageType],
+              timeout: [float, None] = None):
+        """
+        Send one or several messages to the server.
+        It quickly puts the messages into the internal buffer, without waiting for the send result,
+        and returns None.
+
+        The messages will be sent regardless of waiting for the result.
+
+        timeout - time to wait until the message is put into the internal queue.
+        If 0 or negative - a non-blocking call.
+        If None or not set - wait infinitely.
+        It will raise a TimeoutError() exception if it can't put the message into the internal queue within the timeout.
+        """
+        raise NotImplementedError()
+
+    def async_write_with_ack(self, message: Union[MessageType, List[MessageType]], *args: Optional[MessageType],
+                             timeout: [float, None] = None) -> concurrent.futures.Future:
+        """
+        Send one or several messages to the server.
+        Return a future, which can be awaited to check the send result: ack/duplicate/error.
+
+        Usually it is a fast method, but it can block if the internal buffer is full.
+
+        timeout - time to wait until the message is put into the internal queue.
+        The method can block up to timeout seconds before returning the future.
+
+        If 0 or negative - a non-blocking call.
+        If None or not set - wait infinitely.
+        It will raise a TimeoutError() exception if it can't put the message into the internal queue within the timeout.
+        """
+        raise NotImplementedError()
+
+    def write_with_ack(self, message: Union[MessageType, List[MessageType]], *args: Optional[MessageType],
+                       buffer_timeout: [float, None] = None) -> Union[
+            "MessageWriteStatus", List["MessageWriteStatus"]]:
+        """
+        IT IS A SLOW WAY. IT IS A BAD CHOICE IN MOST CASES.
+        It is recommended to use write with an optional flush, or async_write_with_ack and receive acks by waiting for the future.
+
+        Send one or several messages to the server.
+        Blocks until the server ack for the message/messages is received.
+
+        The messages will be sent regardless of waiting for the result.
+
+        buffer_timeout - time to send the message to the server and receive the ack.
+        If 0 or negative - a non-blocking call.
+        If None or not set - wait infinitely.
+        It will raise a TimeoutError() exception if the ack is not received within the timeout.
+        """
+        raise NotImplementedError()
+@dataclass
+class PublicWriterSettings:
+    topic: str
+    producer_and_message_group_id: str
+    session_metadata: Optional[Dict[str, str]] = None
+    encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None
+    serializer: Union[Callable[[Any], bytes], None] = None
+    send_buffer_count: Union[int, None] = 10000
+    send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024
+    partition_id: Optional[int] = None
+    codec: Union[int, None] = None
+    codec_autoselect: bool = True
+    auto_seqno: bool = True
+    auto_created_at: bool = True
+    get_last_seqno: bool = False
+    retry_policy: Union["RetryPolicy", None] = None
+    update_token_interval: Union[int, float] = 3600
+
+
+@dataclass
+class PublicWriteResult:
+    @dataclass(eq=True)
+    class Written:
+        __slots__ = ("offset",)
+        offset: int
+
+    @dataclass(eq=True)
+    class Skipped:
+        pass
+
+
+class WriterSettings(PublicWriterSettings):
+    def __init__(self, settings: PublicWriterSettings):
+        self.__dict__ = settings.__dict__.copy()
+
+    def create_init_request(self) -> StreamWriteMessage.InitRequest:
+        return StreamWriteMessage.InitRequest(
+            path=self.topic,
+            producer_id=self.producer_and_message_group_id,
+            write_session_meta=self.session_metadata,
+            partitioning=self.get_partitioning(),
+            get_last_seq_no=self.get_last_seqno,
+        )
+
+    def get_partitioning(self) -> StreamWriteMessage.PartitioningType:
+        if self.partition_id is not None:
+            return StreamWriteMessage.PartitioningPartitionID(self.partition_id)
+        return StreamWriteMessage.PartitioningMessageGroupID(self.producer_and_message_group_id)
+
+
+class SendMode(Enum):
+    ASYNC = 1
+    SYNC = 2
+
+
+@dataclass
+class PublicWriterInitInfo:
+    __slots__ = ("last_seqno",)
+    last_seqno: Optional[int]
+
+
+class PublicMessage:
+    seqno: Optional[int]
+    created_at: Optional[datetime.datetime]
+    data: Union[str, bytes, TextIO, BinaryIO]
+
+    SimpleMessageSourceType = Union[str, bytes, TextIO, BinaryIO]
+
+    def __init__(self,
+                 data: SimpleMessageSourceType, *,
+                 seqno: Optional[int] = None,
+                 created_at: Optional[datetime.datetime] = None,
+                 ):
+        self.seqno = seqno
+        self.created_at = created_at
+        self.data = data
+
+
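An editorial illustration of how auto_seqno/auto_created_at interact with the PublicMessage fields (the rules are enforced in WriterAsyncIOReconnector._prepare_internal_messages_locked below): with the automatic modes on, the fields must be left unset; with them off, the caller must supply valid values itself.

    import datetime

    # defaults: auto_seqno=True, auto_created_at=True - leave the fields empty
    auto_message = PublicMessage("payload")  # seqno/created_at filled by the writer

    # manual mode: the caller owns both fields
    manual_settings = PublicWriterSettings(
        topic="/local/topic",
        producer_and_message_group_id="producer-id",
        auto_seqno=False,
        auto_created_at=False,
    )
    manual_message = PublicMessage(
        "payload", seqno=1, created_at=datetime.datetime.now()
    )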
+class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto):
+    def __init__(self, mess: PublicMessage):
+        StreamWriteMessage.WriteRequest.MessageData.__init__(
+            self,
+            seq_no=mess.seqno,
+            created_at=mess.created_at,
+            data=mess.data,
+            uncompressed_size=len(mess.data),
+            partitioning=None,
+        )
+
+    def get_bytes(self) -> bytes:
+        if self.data is None:
+            return bytes()
+        if isinstance(self.data, bytes):
+            return self.data
+        if isinstance(self.data, str):
+            return self.data.encode("utf-8")
+        raise ValueError("Bad data type")
+
+    def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData:
+        data = self.get_bytes()
+        return StreamWriteMessage.WriteRequest.MessageData(
+            seq_no=self.seq_no,
+            created_at=self.created_at,
+            data=data,
+            uncompressed_size=len(data),
+            partitioning=None,  # unsupported by the server for now
+        )
+
+
+class MessageSendResult:
+    offset: Union[None, int]
+    write_status: "MessageWriteStatus"
+
+
+class MessageWriteStatus(enum.Enum):
+    Written = 1
+    AlreadyWritten = 2
+
+
+class RetryPolicy:
+    connection_timeout_sec: float
+    overload_timeout_sec: float
+    retry_access_denied: bool = False
+
+
+class TopicWriterError(ydb.Error):
+    def __init__(self, message: str):
+        super(TopicWriterError, self).__init__(message)
+
+
+class TopicWriterRepeatableError(TopicWriterError):
+    pass
+
+
+class TopicWriterStopped(TopicWriterError):
+    def __init__(self):
+        super(TopicWriterStopped, self).__init__("topic writer was stopped by a close call")
+
+
+def default_serializer_message_content(data: Any) -> bytes:
+    if data is None:
+        return bytes()
+    if isinstance(data, bytes):
+        return data
+    if isinstance(data, bytearray):
+        return bytes(data)
+    if isinstance(data, str):
+        return data.encode(encoding='utf-8')
+    raise ValueError("can't serialize type %s to bytes" % type(data))
+
+
+def messages_to_proto_requests(messages: List[InternalMessage]) -> List[StreamWriteMessage.FromClient]:
+    # todo split by proto message size and codec
+    res = []
+    for msg in messages:
+        req = StreamWriteMessage.FromClient(
+            StreamWriteMessage.WriteRequest(
+                messages=[msg.to_message_data()],
+                codec=Codec.CODEC_RAW.value,
+            )
+        )
+        res.append(req)
+    return res
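An editorial note on messages_to_proto_requests above: until the batching/compression todo is resolved, it emits exactly one FromClient request per message, always with CODEC_RAW. A minimal sketch:

    msgs = [
        InternalMessage(PublicMessage("a", seqno=1)),
        InternalMessage(PublicMessage("b", seqno=2)),
    ]
    requests = messages_to_proto_requests(msgs)
    assert len(requests) == 2  # no batching yet: one request per message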
diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
new file mode 100644
index 00000000..daa4705a
--- /dev/null
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -0,0 +1,425 @@
+import asyncio
+import threading
+from collections import deque
+from typing import Dict, Awaitable, Deque, AsyncIterator
+
+import ydb
+from .topic_writer import *
+from .. import _apis, YDB_AUTH_TICKET_HEADER, issues, check_retriable_error, RetrySettings
+from .._topic_wrapper.common import UpdateTokenResponse, UpdateTokenRequest, QueueToIteratorAsyncIO, Codec, \
+    GrpcWrapperAsyncIO, IGrpcWrapperAsyncIO, SupportedDriverType
+from .._topic_wrapper.writer import StreamWriteMessage, WriterMessagesFromServerToClient
+
+# Workaround for good autocomplete in IDE and universal import at runtime
+if False:
+    from .._grpc.v4.protos import ydb_topic_pb2
+else:
+    # noinspection PyUnresolvedReferences
+    from .._grpc.common.protos import ydb_topic_pb2
+
+
+class WriterAsyncIO:
+    _loop: asyncio.AbstractEventLoop
+    _reconnector: "WriterAsyncIOReconnector"
+    _lock: threading.Lock
+    _closed: bool
+
+    @property
+    def last_seqno(self) -> int:
+        raise NotImplementedError()
+
+    def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
+        self._loop = asyncio.get_running_loop()
+        self._lock = threading.Lock()  # protects _closed
+        self._closed = False
+        self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings))
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
+
+    def __del__(self):
+        if self._closed or self._loop.is_closed():
+            return
+
+        self._loop.call_soon(self.close)
+
+    async def close(self):
+        with self._lock:
+            if self._closed:
+                return
+            self._closed = True
+
+        await self._reconnector.close()
+
+    async def write_with_ack(self,
+                             messages: Union[Writer.MessageType, List[Writer.MessageType]],
+                             *args: Optional[Writer.MessageType],
+                             ) -> Union[PublicWriteResult, List[PublicWriteResult]]:
+        """
+        THIS IS THE SLOW WAY. IT IS A BAD CHOICE IN MOST CASES.
+        It is recommended to use write with an optional flush, or write_with_ack_future and
+        receive the acks by waiting the futures.
+
+        Send one or several messages to the server and wait for the acks.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        if isinstance(messages, PublicMessage):
+            futures = await self._reconnector.write_with_ack([messages])
+            return await futures[0]
+        if isinstance(messages, list):
+            for m in messages:
+                if not isinstance(m, PublicMessage):
+                    raise NotImplementedError()
+
+            futures = await self._reconnector.write_with_ack(messages)
+            await asyncio.wait(futures)
+
+            results = [f.result() for f in futures]
+            return results
+
+        raise NotImplementedError()
+
+    async def write_with_ack_future(self,
+                                    messages: Union[Writer.MessageType, List[Writer.MessageType]],
+                                    *args: Optional[Writer.MessageType],
+                                    ) -> Union[asyncio.Future, List[asyncio.Future]]:
+        """
+        Send one or several messages to the server.
+        Return a future (or a list of futures) that can be waited to check the send result.
+
+        Usually this is a fast method, but it can block while the internal buffer is full.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        if isinstance(messages, PublicMessage):
+            futures = await self._reconnector.write_with_ack([messages])
+            return futures[0]
+        if isinstance(messages, list):
+            for m in messages:
+                if not isinstance(m, PublicMessage):
+                    raise NotImplementedError()
+            return await self._reconnector.write_with_ack(messages)
+        raise NotImplementedError()
+
+    async def write(self,
+                    messages: Union[Writer.MessageType, List[Writer.MessageType]],
+                    *args: Optional[Writer.MessageType],
+                    ):
+        """
+        Send one or several messages to the server.
+        The call puts the messages into the internal buffer.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        await self.write_with_ack_future(messages)
+
+    async def flush(self):
+        """
+        Force send all messages from the internal buffer and wait for the server acks
+        for all of them.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        raise NotImplementedError()
+
+    async def wait_init(self) -> PublicWriterInitInfo:
+        """
+        Wait until the connection to the server is established.
+
+        For a wait with timeout use asyncio.wait_for().
+        """
+        return await self._reconnector.wait_init()
+
+
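The asyncio writer in a usage sketch (editorial, not part of the patch; it mirrors examples/topic/writer_async_example.py added by this patch). It assumes a connected ydb.aio.Driver:

    import asyncio
    import ydb

    async def produce(driver: ydb.aio.Driver):
        async with ydb.TopicClientAsyncIO(driver).topic_writer(
            "/local/topic", producer_and_message_group_id="producer-id"
        ) as writer:
            await writer.write(ydb.TopicWriterMessage("payload-1"))  # buffered
            result = await writer.write_with_ack(ydb.TopicWriterMessage("payload-2"))
            # bound any wait with asyncio.wait_for, as the docstrings suggest:
            await asyncio.wait_for(
                writer.write_with_ack(ydb.TopicWriterMessage("payload-3")), timeout=10
            )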
+class WriterAsyncIOReconnector:
+    _credentials: Union[ydb.Credentials, None]
+    _driver: ydb.aio.Driver
+    _update_token_interval: int
+    _token_get_function: "TokenGetter"
+    _init_message: StreamWriteMessage.InitRequest
+    _new_messages: asyncio.Queue
+    _init_info: asyncio.Future
+    _stream_connected: asyncio.Event
+    _settings: WriterSettings
+
+    _lock: asyncio.Lock
+    _last_known_seq_no: int
+    _messages: Deque[InternalMessage]
+    _messages_future: Deque[asyncio.Future]
+    _stop_reason: Optional[Exception]
+    _background_tasks: List[asyncio.Task]
+
+    def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
+        self._driver = driver
+        self._credentials = driver._credentials
+        self._init_message = settings.create_init_request()
+        self._new_messages = asyncio.Queue()
+        self._init_info = asyncio.Future()
+        self._stream_connected = asyncio.Event()
+        self._settings = settings
+
+        self._lock = asyncio.Lock()
+        self._last_known_seq_no = 0
+        self._messages = deque()
+        self._messages_future = deque()
+        self._stop_reason = None
+        self._background_tasks = [
+            asyncio.create_task(self._connection_loop(), name="connection_loop")
+        ]
+
+    async def close(self):
+        await self._check_stop()
+        await self._stop(TopicWriterStopped())
+
+    async def wait_init(self) -> PublicWriterInitInfo:
+        return await self._init_info
+
+    async def write_with_ack(self, messages: List[PublicMessage]) -> List[asyncio.Future]:
+        # todo check internal buffer limit
+        await self._check_stop()
+
+        if self._settings.auto_seqno:
+            await self.wait_init()
+
+        async with self._lock:
+            internal_messages = self._prepare_internal_messages_locked(messages)
+            messages_future = [asyncio.Future() for _ in internal_messages]
+
+            self._messages.extend(internal_messages)
+            self._messages_future.extend(messages_future)
+
+            for m in internal_messages:
+                self._new_messages.put_nowait(m)
+
+        return messages_future
+
+    def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
+        if self._settings.auto_created_at:
+            now = datetime.datetime.now()
+        else:
+            now = None
+
+        res = []
+        for m in messages:
+            internal_message = InternalMessage(m)
+            if self._settings.auto_seqno:
+                if internal_message.seq_no is None:
+                    self._last_known_seq_no += 1
+                    internal_message.seq_no = self._last_known_seq_no
+                else:
+                    raise TopicWriterError("An explicit seqno and the auto_seqno setting are mutually exclusive")
+            else:
+                if internal_message.seq_no is None or internal_message.seq_no == 0:
+                    raise TopicWriterError("Empty seqno while the auto_seqno setting is disabled")
+                elif internal_message.seq_no <= self._last_known_seq_no:
+                    raise TopicWriterError("Message seqno is duplicated: %s" % internal_message.seq_no)
+                else:
+                    self._last_known_seq_no = internal_message.seq_no
+
+            if self._settings.auto_created_at:
+                if internal_message.created_at is not None:
+                    raise TopicWriterError(
+                        "An explicit created_at and the auto_created_at setting are mutually exclusive"
+                    )
+                else:
+                    internal_message.created_at = now
+
+            res.append(internal_message)
+
+        return res
+
+    async def _check_stop(self):
+        async with self._lock:
+            if self._stop_reason is not None:
+                raise self._stop_reason
+
+    async def _connection_loop(self):
+        retry_settings = RetrySettings()  # todo
+
+        while True:
+            attempt = 0  # todo calc and reset
+            pending = []
+
+            async def on_stop(reason: Exception):
+                for t in pending:
+                    self._background_tasks.append(t)
+                pending.clear()
+                await self._stop(reason)
+
+            # noinspection PyBroadException
+            try:
+                stream_writer = await WriterAsyncIOStream.create(self._driver, self._init_message, self._get_token)
+                try:
+                    async with self._lock:
+                        self._last_known_seq_no = stream_writer.last_seqno
+                        self._init_info.set_result(PublicWriterInitInfo(last_seqno=stream_writer.last_seqno))
+                except asyncio.InvalidStateError:
+                    pass
+
+                self._stream_connected.set()
+
+                send_loop = asyncio.create_task(self._send_loop(stream_writer), name="writer send loop")
+                receive_loop = asyncio.create_task(self._read_loop(stream_writer), name="writer receive loop")
+
+                pending = [send_loop, receive_loop]
+
+                done, pending = await asyncio.wait([send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED)
+                done.pop().result()
+            except issues.Error as err:
+                # todo log the error
+                print(err)
+
+                err_info = check_retriable_error(err, retry_settings, attempt)
+                if not err_info.is_retriable:
+                    await on_stop(err)
+                    return
+
+                await asyncio.sleep(err_info.sleep_timeout_seconds)
+
+            except Exception as e:
+                await on_stop(e)
+                return
+            finally:
+                if len(pending) > 0:
+                    for task in pending:
+                        task.cancel()
+                    await asyncio.wait(pending)
+
+    async def _read_loop(self, writer: "WriterAsyncIOStream"):
+        while True:
+            resp = await writer.receive()
+            async with self._lock:
+                for ack in resp.acks:
+                    self._handle_receive_ack_need_lock(ack)
+
+    def _handle_receive_ack_need_lock(self, ack):
+        current_message = self._messages.popleft()
+        message_future = self._messages_future.popleft()
+        if current_message.seq_no != ack.seq_no:
+            raise TopicWriterError(
+                "internal error - received an unexpected ack. Expected seqno: %s, received seqno: %s" %
+                (current_message.seq_no, ack.seq_no)
+            )
+        message_future.set_result(None)  # todo - return a result with the offset or skip status
+
+    async def _send_loop(self, writer: "WriterAsyncIOStream"):
+        try:
+            async with self._lock:
+                messages = list(self._messages)
+
+            last_seq_no = 0
+            for m in messages:
+                writer.write([m])
+                last_seq_no = m.seq_no
+
+            while True:
+                m = await self._new_messages.get()  # type: InternalMessage
+                if m.seq_no > last_seq_no:
+                    writer.write([m])
+        finally:
+            pass
+
+    async def _stop(self, reason: Exception):
+        if reason is None:
+            raise Exception("writer stop reason can not be None")
+
+        async with self._lock:
+            if self._stop_reason is not None:
+                return
+            self._stop_reason = reason
+            background_tasks = self._background_tasks
+
+        for task in background_tasks:
+            task.cancel()
+
+        await asyncio.wait(self._background_tasks)
+
+    def _get_token(self) -> str:
+        raise NotImplementedError()
+
+
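The retry decision inside _connection_loop above, pulled out as an editorial sketch (it uses the same helpers the module imports: issues, check_retriable_error, RetrySettings):

    err = issues.Overloaded("simulated server overload")
    err_info = check_retriable_error(err, RetrySettings(), 0)
    if err_info.is_retriable:
        # sleep err_info.sleep_timeout_seconds, rebuild the stream and
        # resend every message that has not been acked yet
        pass
    else:
        # non-retriable: stop the writer and re-raise the reason to all callers
        pass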
+class WriterAsyncIOStream:
+    # todo slots
+
+    last_seqno: int
+
+    _stream: IGrpcWrapperAsyncIO
+    _token_getter: "TokenGetter"
+    _requests: asyncio.Queue
+    _responses: AsyncIterator
+
+    def __init__(self,
+                 token_getter: "TokenGetter",
+                 ):
+        self._token_getter = token_getter
+
+    @staticmethod
+    async def create(driver: SupportedDriverType, init_request: StreamWriteMessage.InitRequest, token_getter: "TokenGetter")\
+            -> "WriterAsyncIOStream":
+        stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
+
+        await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)
+
+        writer = WriterAsyncIOStream(token_getter)
+        await writer._start(
+            stream,
+            init_request
+        )
+        return writer
+
+    @staticmethod
+    async def _create_stream_from_async(driver: ydb.aio.Driver, 
init_request: StreamWriteMessage.InitRequest, token_getter: "TokenGetter")\ + -> "WriterAsyncIOStream": + return GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) + + @staticmethod + async def _create_from_sync(driver: ydb.Driver, init_request: StreamWriteMessage.InitRequest, token_getter: "TokenGetter")\ + -> "WriterAsyncIOStream": + stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) + await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) + + writer = WriterAsyncIOStream(token_getter) + await writer._start( + stream, + init_request + ) + return writer + + async def receive(self) -> StreamWriteMessage.WriteResponse: + while True: + item = await self._stream.receive() + + if isinstance(item, StreamWriteMessage.WriteResponse): + return item + if isinstance(item, UpdateTokenResponse): + continue + + # todo log unknown messages instead of raise exception + raise Exception("Unknown message while read writer answers: %s" % item) + + async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest): + stream.write(StreamWriteMessage.FromClient(init_message)) + + resp = await stream.receive() + self._ensure_ok(resp) + if not isinstance(resp, StreamWriteMessage.InitResponse): + raise TopicWriterError("Unexpected answer for init request: %s" % resp) + + self.last_seqno = resp.last_seq_no + + self._stream = stream + + @staticmethod + def _ensure_ok(message: WriterMessagesFromServerToClient): + if not message.status.is_success(): + raise TopicWriterError("status error from server in writer: %s", message.status) + + def write(self, messages: List[InternalMessage]): + for request in messages_to_proto_requests(messages): + self._stream.write(request) + + +TokenGetter = Optional[Callable[[], str]] diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py new file mode 100644 index 00000000..8c5c721a --- /dev/null +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -0,0 +1,487 @@ +from __future__ import annotations + +import asyncio +import copy +import dataclasses +import datetime +import typing +from queue import Queue, Empty +from unittest import mock + +import freezegun +import pytest + + +import ydb.aio +from ydb import _apis, StatusCode, issues +from ydb._topic_wrapper.common import QueueToIteratorAsyncIO, ServerStatus, IGrpcWrapperAsyncIO, IToProto, Codec +from ydb._topic_writer import InternalMessage, PublicMessage, WriterSettings, PublicWriterSettings, \ + PublicWriterInitInfo, PublicWriteResult, TopicWriterError + +# Workaround for good IDE and universal runtime +if False: + from ydb._grpc.v4.protos import ydb_topic_pb2, ydb_status_codes_pb2 +else: + from ydb._grpc.common.protos import ydb_topic_pb2, ydb_status_codes_pb2 + +from .._topic_wrapper.writer import StreamWriteMessage +from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIOStream, WriterAsyncIOReconnector, TokenGetter, \ + WriterAsyncIO + + +@pytest.fixture +def default_driver() -> ydb.aio.Driver: + driver = mock.Mock(spec=ydb.aio.Driver) + driver._credentials = mock.Mock() + return driver + + +@pytest.mark.asyncio +class TestWriterAsyncIOStream: + class StreamMock(IGrpcWrapperAsyncIO): + from_server: asyncio.Queue + from_client: asyncio.Queue + + def __init__(self): + self.from_server = asyncio.Queue() + self.from_client = asyncio.Queue() + + async def receive(self) -> typing.Any: + item = await self.from_server.get() + if isinstance(item, Exception): + raise item + return item + + 
def write(self, wrap_message: IToProto): + self.from_client.put_nowait(wrap_message) + + @dataclasses.dataclass + class WriterWithMockedStream: + writer: WriterAsyncIOStream + stream: "TestWriterAsyncIOStream.StreamMock" + + @pytest.fixture + def stream(self): + return TestWriterAsyncIOStream.StreamMock() + + @pytest.fixture + async def writer_and_stream(self, stream) -> WriterWithMockedStream: + stream.from_server.put_nowait(StreamWriteMessage.InitResponse( + last_seq_no=4, + session_id="123", + partition_id=3, + supported_codecs=[Codec.CODEC_RAW.value, Codec.CODEC_GZIP.value], + status=ServerStatus(StatusCode.SUCCESS, []) + )) + + writer = WriterAsyncIOStream(None) + await writer._start(stream, init_message=StreamWriteMessage.InitRequest( + path="/local/test", + producer_id="producer-id", + write_session_meta={"a": "b"}, + partitioning=StreamWriteMessage.PartitioningMessageGroupID(message_group_id="message-group-id"), + get_last_seq_no=False, + )) + await stream.from_client.get() + + return TestWriterAsyncIOStream.WriterWithMockedStream( + stream=stream, + writer=writer, + ) + + async def test_init_writer(self, stream): + init_seqno = 4 + init_message = StreamWriteMessage.InitRequest( + path="/local/test", + producer_id="producer-id", + write_session_meta={"a": "b"}, + partitioning=StreamWriteMessage.PartitioningMessageGroupID(message_group_id="message-group-id"), + get_last_seq_no=False, + ) + stream.from_server.put_nowait(StreamWriteMessage.InitResponse( + last_seq_no=init_seqno, + session_id="123", + partition_id=0, + supported_codecs=[], + status=ServerStatus(StatusCode.SUCCESS, []) + )) + + writer = WriterAsyncIOStream(None) + await writer._start(stream, init_message) + + sent_message = await stream.from_client.get() + expected_message = StreamWriteMessage.FromClient(init_message) + + assert expected_message == sent_message + assert writer.last_seqno == init_seqno + + async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): + data = "123".encode() + now = datetime.datetime.now() + writer_and_stream.writer.write([InternalMessage( + PublicMessage( + seqno=1, + created_at=now, + data=data, + ) + )]) + + expected_message = StreamWriteMessage.FromClient(StreamWriteMessage.WriteRequest( + codec=Codec.CODEC_RAW.value, + messages=[ + StreamWriteMessage.WriteRequest.MessageData( + seq_no=1, + created_at=now, + data=data, + uncompressed_size=len(data), + partitioning=None, + ) + ] + )) + + sent_message = await writer_and_stream.stream.from_client.get() + assert expected_message == sent_message + + +@pytest.mark.asyncio +class TestWriterAsyncIOReconnector: + init_last_seqno = 0 + + class StreamWriterMock: + last_seqno: int + + from_client: asyncio.Queue + from_server: asyncio.Queue + + def __init__(self): + self.last_seqno = 0 + self.from_server = asyncio.Queue() + self.from_client = asyncio.Queue() + + def write(self, messages: typing.List[InternalMessage]): + self.from_client.put_nowait(messages) + + async def receive(self) -> StreamWriteMessage.WriteResponse: + item = await self.from_server.get() + if isinstance(item, Exception): + raise item + return item + + @pytest.fixture(autouse=True) + async def stream_writer_double_queue(self, monkeypatch): + + class DoubleQueueWriters: + _first: Queue + _second: Queue + + def __init__(self): + self._first = Queue() + self._second = Queue() + + def get_first(self): + try: + return self._first.get_nowait() + except Empty: + self._create() + return self.get_first() + + def get_second(self): + try: + return 
self._second.get_nowait() + except Empty: + self._create() + return self.get_second() + + def _create(self): + writer = TestWriterAsyncIOReconnector.StreamWriterMock() + writer.last_seqno = TestWriterAsyncIOReconnector.init_last_seqno + self._first.put_nowait(writer) + self._second.put_nowait(writer) + + res = DoubleQueueWriters() + + async def async_create(driver, init_message, token_getter): + return res.get_first() + + monkeypatch.setattr(WriterAsyncIOStream, "create", async_create) + return res + + @pytest.fixture + def get_stream_writer(self, stream_writer_double_queue) -> typing.Callable[[], "TestWriterAsyncIOReconnector.StreamWriterMock"]: + return stream_writer_double_queue.get_second + + @pytest.fixture + def default_settings(self) -> WriterSettings: + return WriterSettings(PublicWriterSettings( + topic="/local/topic", + producer_and_message_group_id="test-producer", + auto_seqno=False, + auto_created_at=False, + )) + + @pytest.fixture + def default_write_statistic(self) -> StreamWriteMessage.WriteResponse.WriteStatistics: + return StreamWriteMessage.WriteResponse.WriteStatistics( + persisting_time=datetime.timedelta(milliseconds=1), + min_queue_wait_time=datetime.timedelta(milliseconds=2), + max_queue_wait_time=datetime.timedelta(milliseconds=3), + partition_quota_wait_time=datetime.timedelta(milliseconds=4), + topic_quota_wait_time=datetime.timedelta(milliseconds=5), + ) + + @pytest.fixture + async def reconnector(self, default_driver, default_settings) -> WriterAsyncIOReconnector: + return WriterAsyncIOReconnector(default_driver, default_settings) + + async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( + self, + reconnector: WriterAsyncIOReconnector, + get_stream_writer, + default_write_statistic, + ): + now = datetime.datetime.now() + data = "123".encode() + + message1 = PublicMessage( + data=data, + seqno=1, + created_at=now, + ) + message2 = PublicMessage( + data=data, + seqno=2, + created_at=now, + ) + await reconnector.write_with_ack([message1, message2]) + + # sent to first stream + stream_writer = get_stream_writer() + + messages = await stream_writer.from_client.get() + assert [InternalMessage(message1)] == messages + messages = await stream_writer.from_client.get() + assert [InternalMessage(message2)] == messages + + # ack first message + stream_writer.from_server.put_nowait(StreamWriteMessage.WriteResponse( + partition_id=1, + acks=[ + StreamWriteMessage.WriteResponse.WriteAck( + seq_no=1, + message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + offset=1 + ) + ) + ], + write_statistics=default_write_statistic, + )) + + stream_writer.from_server.put_nowait(issues.Overloaded("test")) + + second_writer = get_stream_writer() + second_sent_msg = await second_writer.from_client.get() + + expected_messages = [InternalMessage(message2)] + assert second_sent_msg == expected_messages + await reconnector.close() + + async def test_stop_on_unexpected_exception(self, reconnector: WriterAsyncIOReconnector, get_stream_writer): + class TestException(Exception): + pass + + stream_writer = get_stream_writer() + stream_writer.from_server.put_nowait(TestException()) + + message = PublicMessage( + data="123", + seqno=3, + ) + + with pytest.raises(TestException): + async def wait_stop(): + while True: + await reconnector.write_with_ack([message]) + await asyncio.sleep(0.1) + + await asyncio.wait_for(wait_stop(), 1) + + with pytest.raises(TestException): + await reconnector.close() + + async def test_wait_init(self, default_driver, 
default_settings, get_stream_writer): + init_seqno = 100 + expected_init_info = PublicWriterInitInfo(init_seqno) + with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno): + reconnector = WriterAsyncIOReconnector(default_driver, default_settings) + info = await reconnector.wait_init() + assert info == expected_init_info + + reconnector._stream_connected.clear() + + # force reconnect + with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno+1): + stream_writer = get_stream_writer() + stream_writer.from_server.put_nowait(issues.Overloaded("test")) # some retriable error + await reconnector._stream_connected.wait() + + info = await reconnector.wait_init() + assert info == expected_init_info + + await reconnector.close() + + async def test_write_message(self, reconnector: WriterAsyncIOReconnector, get_stream_writer): + stream_writer = get_stream_writer() + message = PublicMessage( + data="123", + seqno=3, + ) + await reconnector.write_with_ack([message]) + + sent_messages = await asyncio.wait_for(stream_writer.from_client.get(), 1) + assert sent_messages == [InternalMessage(message)] + + await reconnector.close() + + async def test_auto_seq_no(self, default_driver, default_settings, get_stream_writer): + last_seq_no = 100 + with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", last_seq_no): + settings = copy.deepcopy(default_settings) + settings.auto_seqno = True + + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + await reconnector.write_with_ack([PublicMessage(data="123")]) + await reconnector.write_with_ack([PublicMessage(data="456")]) + + stream_writer = get_stream_writer() + + sent = await stream_writer.from_client.get() + assert [InternalMessage(PublicMessage(seqno=last_seq_no+1, data="123"))] == sent + + sent = await stream_writer.from_client.get() + assert [InternalMessage(PublicMessage(seqno=last_seq_no+2, data="456"))] == sent + + with pytest.raises(TopicWriterError): + await reconnector.write_with_ack([PublicMessage(seqno=last_seq_no+3, data="123")]) + + await reconnector.close() + + async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector): + await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")]) + + with pytest.raises(TopicWriterError): + await reconnector.write_with_ack([PublicMessage(seqno=9, data="123")]) + + with pytest.raises(TopicWriterError): + await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")]) + + await reconnector.write_with_ack([PublicMessage(seqno=11, data="123")]) + + await reconnector.close() + + @freezegun.freeze_time("2022-01-13 20:50:00", tz_offset=0) + async def test_auto_created_at(self, default_driver, default_settings, get_stream_writer): + now = datetime.datetime.now() + + settings = copy.deepcopy(default_settings) + settings.auto_created_at = True + reconnector = WriterAsyncIOReconnector(default_driver, settings) + await reconnector.write_with_ack([PublicMessage(seqno=4, data="123")]) + + stream_writer = get_stream_writer() + sent = await stream_writer.from_client.get() + + assert [InternalMessage(PublicMessage(seqno=4, data="123", created_at=now))] == sent + await reconnector.close() + + +@pytest.mark.asyncio +class TestWriterAsyncIO: + class ReconnectorMock: + lock: asyncio.Lock + messages: typing.List[InternalMessage] + futures: typing.List[asyncio.Future] + messages_writted: asyncio.Event + + def __init__(self): + self.lock = asyncio.Lock() + self.messages = [] + self.futures = [] + self.messages_writted = 
asyncio.Event() + + async def write_with_ack(self, messages: typing.List[InternalMessage]): + async with self.lock: + futures = [asyncio.Future() for _ in messages] + self.messages.extend(messages) + self.futures.extend(futures) + self.messages_writted.set() + return futures + + async def close(self): + pass + + @pytest.fixture + def default_settings(self) -> PublicWriterSettings: + return PublicWriterSettings( + topic="/local/topic", + producer_and_message_group_id="producer-id", + ) + + @pytest.fixture(autouse=True) + def mock_reconnector_init(self, monkeypatch, reconnector): + def t(cls, driver, settings): + return reconnector + monkeypatch.setattr(WriterAsyncIOReconnector, "__new__", t) + + @pytest.fixture + def reconnector(self, monkeypatch) -> TestWriterAsyncIO.ReconnectorMock: + reconnector = TestWriterAsyncIO.ReconnectorMock() + return reconnector + + @pytest.fixture + async def writer(self, default_driver, default_settings): + return WriterAsyncIO(default_driver, default_settings) + + async def test_write(self, writer: WriterAsyncIO, reconnector): + m = PublicMessage(seqno=1, data="123") + res = await writer.write(m) + assert res is None + + assert reconnector.messages == [m] + + async def test_write_with_futures(self, writer: WriterAsyncIO, reconnector): + m = PublicMessage(seqno=1, data="123") + res = await writer.write_with_ack_future(m) + + assert reconnector.messages == [m] + assert asyncio.isfuture(res) + + async def test_write_with_ack(self, writer: WriterAsyncIO, reconnector): + reconnector.messages_writted.clear() + + async def ack_first_message(): + await reconnector.messages_writted.wait() + async with reconnector.lock: + reconnector.futures[0].set_result(PublicWriteResult.Written(offset=1)) + asyncio.create_task(ack_first_message()) + + m = PublicMessage(seqno=1, data="123") + res = await writer.write_with_ack(m) + + assert res == PublicWriteResult.Written(offset=1) + + reconnector.messages_writted.clear() + async with reconnector.lock: + reconnector.messages.clear() + reconnector.futures.clear() + + async def ack_next_messages(): + await reconnector.messages_writted.wait() + async with reconnector.lock: + reconnector.futures[0].set_result(PublicWriteResult.Written(offset=2)) + reconnector.futures[1].set_result(PublicWriteResult.Skipped()) + asyncio.create_task(ack_next_messages()) + + res = await writer.write_with_ack([PublicMessage(seqno=2, data="123"), PublicMessage(seqno=3, data="123")]) + assert res == [PublicWriteResult.Written(offset=2), PublicWriteResult.Skipped()] + diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py new file mode 100644 index 00000000..97ecee2a --- /dev/null +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import asyncio +from concurrent.futures import Future +import threading +from typing import Union, List, Optional, Coroutine + +import ydb +from .._topic_wrapper.common import SupportedDriverType +from ydb._topic_writer import PublicWriterSettings, TopicWriterError, PublicWriterInitInfo, PublicMessage, Writer, \ + PublicWriteResult +from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO + +_shared_event_loop_lock = threading.Lock() +_shared_event_loop = None # type: Optional[asyncio.AbstractEventLoop] + + +def _get_shared_event_loop() -> asyncio.AbstractEventLoop: + global _shared_event_loop + + if _shared_event_loop is not None: + return _shared_event_loop + + with _shared_event_loop_lock: + if _shared_event_loop is not None: + return 
_shared_event_loop + + event_loop_set_done = Future() + + def start_event_loop(): + global _shared_event_loop + _shared_event_loop = asyncio.new_event_loop() + event_loop_set_done.set_result(None) + asyncio.set_event_loop(_shared_event_loop) + _shared_event_loop.run_forever() + + t = threading.Thread(target=start_event_loop, name="Common ydb topic writer event loop", daemon=True) + t.start() + + event_loop_set_done.result() + return _shared_event_loop + + +class WriterSync: + _loop: asyncio.AbstractEventLoop + _async_writer: WriterAsyncIO + _closed: bool + + def __init__(self, + driver: SupportedDriverType, + settings: PublicWriterSettings, + *, + eventloop: asyncio.AbstractEventLoop = None): + + self._closed = False + + if eventloop: + self._loop = eventloop + else: + self._loop = _get_shared_event_loop() + + async def create_async_writer(): + return WriterAsyncIO(driver, settings) + + self._async_writer = asyncio.run_coroutine_threadsafe(create_async_writer(), self._loop).result() + + def _call(self, coro, *args, **kwargs): + if self._closed: + raise TopicWriterError("writer is closed") + + return asyncio.run_coroutine_threadsafe(coro, self._loop) + + def _call_sync(self, coro: Coroutine, timeout, *args, **kwargs): + f = self._call(coro, *args, **kwargs) + try: + return f.result() + except TimeoutError: + f.cancel() + raise + + def close(self): + if self._closed: + return + self._closed = True + asyncio.run_coroutine_threadsafe(self._async_writer.close(), self._loop).result() + + def async_flush(self) -> Future: + if self._closed: + raise TopicWriterError("writer is closed") + return self._call(self._async_writer.flush()) + + def flush(self, timeout=None): + self._call_sync(self._async_writer.flush(), timeout) + + def async_wait_init(self) -> Future[PublicWriterInitInfo]: + return self._call(self._async_writer.wait_init()) + + def wait_init(self, timeout) -> PublicWriterInitInfo: + return self._call_sync(self._async_writer.wait_init(), timeout) + + def write(self, message: Union[PublicMessage, List[PublicMessage]], *args: Optional[PublicMessage], + timeout: Union[float, None] = None): + self._call_sync(self._async_writer.write(message, *args), timeout=timeout) + + def async_write_with_ack(self, + messages: Union[Writer.MessageType, List[Writer.MessageType]], + *args: Optional[Writer.MessageType], + ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: + return self._call(self._async_writer.write_with_ack(messages, *args)) + + def write_with_ack(self, + messages: Union[Writer.MessageType, List[Writer.MessageType]], + *args: Optional[Writer.MessageType], + timeout: Union[float, None] = None, + ) -> Union[PublicWriteResult, List[PublicWriteResult]]: + return self._call_sync(self._async_writer.write_with_ack(messages, *args), timeout=timeout) diff --git a/ydb/aio/connection.py b/ydb/aio/connection.py index 88ab738c..85c22638 100644 --- a/ydb/aio/connection.py +++ b/ydb/aio/connection.py @@ -24,11 +24,19 @@ from ydb.settings import BaseRequestSettings from ydb import issues +# Workaround for good IDE and universal runtime +if False: + from ydb._grpc.v4 import ydb_topic_v1_pb2_grpc +else: + from ydb._grpc.common import ydb_topic_v1_pb2_grpc + + _stubs_list = ( _apis.TableService.Stub, _apis.SchemeService.Stub, _apis.DiscoveryService.Stub, _apis.CmsService.Stub, + ydb_topic_v1_pb2_grpc.TopicServiceStub, ) logger = logging.getLogger(__name__) diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py index 3bf6cca8..b6641e27 100644 --- a/ydb/aio/driver.py +++ b/ydb/aio/driver.py @@ -2,7 +2,9 
@@ from . import pool, scheme, table import ydb +from .. import _utilities from ydb.driver import get_config +from .. import topic def default_credentials(credentials=None): @@ -56,7 +58,7 @@ def default_from_endpoint_and_database( def default_from_connection_string( cls, connection_string, root_certificates=None, credentials=None, **kwargs ): - endpoint, database = ydb.parse_connection_string(connection_string) + endpoint, database = _utilities.parse_connection_string(connection_string) return cls( endpoint, database, @@ -67,6 +69,8 @@ def default_from_connection_string( class Driver(pool.ConnectionPool): + _credentials: ydb.Credentials # used for topic clients + def __init__( self, driver_config=None, @@ -89,5 +93,8 @@ def __init__( super(Driver, self).__init__(config) + self._credentials = config.credentials + self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, config.table_client_settings) + self.topic_client = topic.TopicClientAsyncIO(self, config.topic_client_settings) diff --git a/ydb/driver.py b/ydb/driver.py index 9b3fa99c..e66a5fc9 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -70,6 +70,7 @@ class DriverConfig(object): "grpc_keep_alive_timeout", "secure_channel", "table_client_settings", + "topic_client_settings", "endpoints", "primary_user_agent", "tracer", @@ -92,6 +93,7 @@ def __init__( private_key=None, grpc_keep_alive_timeout=None, table_client_settings=None, + topic_client_settings=None, endpoints=None, primary_user_agent="python-library", tracer=None, @@ -138,6 +140,7 @@ def __init__( self.private_key = private_key self.grpc_keep_alive_timeout = grpc_keep_alive_timeout self.table_client_settings = table_client_settings + self.topic_client_settings = topic_client_settings self.primary_user_agent = primary_user_agent self.tracer = tracer if tracer is not None else tracing.Tracer(None) self.grpc_lb_policy_name = grpc_lb_policy_name @@ -238,5 +241,8 @@ def __init__( ) super(Driver, self).__init__(driver_config) + + self._credentials = driver_config.credentials + self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, driver_config.table_client_settings) diff --git a/ydb/pool.py b/ydb/pool.py index dfda0adf..73cd1681 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -10,6 +10,7 @@ from . 
import connection as connection_impl, issues, resolver, _utilities, tracing from abc import abstractmethod, ABCMeta +from .connection import Connection logger = logging.getLogger(__name__) @@ -127,7 +128,7 @@ def subscribe(self): return subscription @tracing.with_trace() - def get(self, preferred_endpoint=None): + def get(self, preferred_endpoint=None) -> Connection: with self.lock: if ( preferred_endpoint is not None diff --git a/ydb/topic.py b/ydb/topic.py new file mode 100644 index 00000000..d6bfe0f7 --- /dev/null +++ b/ydb/topic.py @@ -0,0 +1,112 @@ +from typing import List, Callable, Union, Mapping, Any + +import ydb._topic_writer + +from ydb._topic_reader import ( + Reader as TopicReader, ReaderAsyncIO as TopicReaderAsyncIO, + Selector as TopicSelector, +) + +from ydb._topic_writer import ( + Writer as TopicWriter, + PublicWriterSettings as TopicWriterSettings, +) + +from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO + + +class TopicClientAsyncIO: + _driver: ydb.aio.Driver + _credentials: Union[ydb.Credentials, None] + + def __init__(self, driver: ydb.aio.Driver, settings: "TopicClientSettings" = None): + self._driver = driver + + def topic_reader(self, topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]], + consumer: str, + commit_batch_time: Union[float, None] = 0.1, + commit_batch_count: Union[int, None] = 1000, + buffer_size_bytes: int = 50 * 1024 * 1024, + sync_commit: bool = False, # reader.commit(...) will wait commit ack from server + on_commit: Callable[["OnCommitEvent"], None] = None, + on_get_partition_start_offset: Callable[ + ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"], "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse"] = None, + on_init_partition: Callable[["StubEvent"], None] = None, + on_shutdown_partition: Callable[["StubEvent"], None] = None, + decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + deserializer: Union[Callable[[bytes], Any], None] = None, + one_attempt_connection_timeout: Union[float, None] = 1, + connection_timeout: Union[float, None] = None, + retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None, + ) -> TopicReaderAsyncIO: + raise NotImplementedError() + + def topic_writer(self, topic, + *, + producer_and_message_group_id: str, + session_metadata: Mapping[str, str] = None, + encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + serializer: Union[Callable[[Any], bytes], None] = None, + send_buffer_count: Union[int, None] = 10000, + send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024, + partition_id: Union[int, None] = None, + codec: Union[int, None] = None, + codec_autoselect: bool = True, + auto_seqno: bool = True, + auto_created_at: bool = True, + get_last_seqno: bool = False, + retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None, + ) -> TopicWriterAsyncIO: + args = locals() + del args['self'] + settings = TopicWriterSettings(**args) + return TopicWriterAsyncIO(self._driver, settings) + + +class TopicClient: + def __init__(self, driver, topic_client_settings: "TopicClientSettings" = None): + pass + + def topic_reader(self, topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]], + consumer: str, + commit_batch_time: Union[float, None] = 0.1, + commit_batch_count: Union[int, None] = 1000, + buffer_size_bytes: int = 50 * 1024 * 1024, + sync_commit: bool = False, # reader.commit(...) 
will wait commit ack from server + on_commit: Callable[["OnCommitEvent"], None] = None, + on_get_partition_start_offset: Callable[ + ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"], "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse"] = None, + on_init_partition: Callable[["StubEvent"], None] = None, + on_shutdown_partition: Callable[["StubEvent"], None] = None, + decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + deserializer: Union[Callable[[bytes], Any], None] = None, + one_attempt_connection_timeout: Union[float, None] = 1, + connection_timeout: Union[float, None] = None, + retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None, + ) -> TopicReader: + raise NotImplementedError() + + def topic_writer(self, topic, + producer_and_message_group_id: str, + session_metadata: Mapping[str, str] = None, + encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + serializer: Union[Callable[[Any], bytes], None] = None, + send_buffer_count: Union[int, None] = 10000, + send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024, + partition_id: Union[int, None] = None, + codec: Union[int, None] = None, + codec_autoselect: bool = True, + auto_seqno: bool = True, + auto_created_at: bool = True, + get_last_seqno: bool = False, + retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None, + ) -> TopicWriter: + raise NotImplementedError() + + +class TopicClientSettings: + pass + + +class StubEvent: + pass From d2534c99cc243561e18665667163961bdecea73c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 24 Jan 2023 14:19:11 +0300 Subject: [PATCH 005/147] fix imports --- ydb/_topic_writer/__init__.py | 1 - .../topic_writer_asyncio_test.py | 20 +++++++------------ ydb/_topic_writer/topic_writer_sync.py | 6 +++--- ydb/topic.py | 3 ++- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/ydb/_topic_writer/__init__.py b/ydb/_topic_writer/__init__.py index 87216032..e69de29b 100644 --- a/ydb/_topic_writer/__init__.py +++ b/ydb/_topic_writer/__init__.py @@ -1 +0,0 @@ -from .topic_writer import * diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 8c5c721a..4ea16764 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -12,26 +12,20 @@ import pytest -import ydb.aio -from ydb import _apis, StatusCode, issues -from ydb._topic_wrapper.common import QueueToIteratorAsyncIO, ServerStatus, IGrpcWrapperAsyncIO, IToProto, Codec -from ydb._topic_writer import InternalMessage, PublicMessage, WriterSettings, PublicWriterSettings, \ +from .. import aio +from .. 
import StatusCode, issues +from .._topic_wrapper.common import ServerStatus, IGrpcWrapperAsyncIO, IToProto, Codec +from .topic_writer import InternalMessage, PublicMessage, WriterSettings, PublicWriterSettings, \ PublicWriterInitInfo, PublicWriteResult, TopicWriterError -# Workaround for good IDE and universal runtime -if False: - from ydb._grpc.v4.protos import ydb_topic_pb2, ydb_status_codes_pb2 -else: - from ydb._grpc.common.protos import ydb_topic_pb2, ydb_status_codes_pb2 - from .._topic_wrapper.writer import StreamWriteMessage -from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIOStream, WriterAsyncIOReconnector, TokenGetter, \ +from .topic_writer_asyncio import WriterAsyncIOStream, WriterAsyncIOReconnector, TokenGetter, \ WriterAsyncIO @pytest.fixture -def default_driver() -> ydb.aio.Driver: - driver = mock.Mock(spec=ydb.aio.Driver) +def default_driver() -> aio.Driver: + driver = mock.Mock(spec=aio.Driver) driver._credentials = mock.Mock() return driver diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 97ecee2a..bbd6d71e 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -5,11 +5,11 @@ import threading from typing import Union, List, Optional, Coroutine -import ydb from .._topic_wrapper.common import SupportedDriverType -from ydb._topic_writer import PublicWriterSettings, TopicWriterError, PublicWriterInitInfo, PublicMessage, Writer, \ +from .topic_writer import PublicWriterSettings, TopicWriterError, PublicWriterInitInfo, PublicMessage, Writer, \ PublicWriteResult -from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO + +from .topic_writer_asyncio import WriterAsyncIO _shared_event_loop_lock = threading.Lock() _shared_event_loop = None # type: Optional[asyncio.AbstractEventLoop] diff --git a/ydb/topic.py b/ydb/topic.py index d6bfe0f7..e644a567 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -7,9 +7,10 @@ Selector as TopicSelector, ) -from ydb._topic_writer import ( +from ydb._topic_writer.topic_writer import ( Writer as TopicWriter, PublicWriterSettings as TopicWriterSettings, + PublicMessage as TopicWriterMessage, ) from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO From b4f2f23f45300fe03bb25369efb55dfbce760cac Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 24 Jan 2023 14:24:12 +0300 Subject: [PATCH 006/147] black format --- examples/topic/reader_async_example.py | 47 ++-- examples/topic/reader_example.py | 42 +++- examples/topic/writer_async_example.py | 50 ++-- examples/topic/writer_example.py | 55 +++-- ydb/__init__.py | 2 +- ydb/_topic_reader/topic_reader.py | 101 +++++--- ydb/_topic_wrapper/common.py | 47 ++-- ydb/_topic_wrapper/control_plane.py | 1 - ydb/_topic_wrapper/reader.py | 1 - ydb/_topic_wrapper/writer.py | 78 ++++-- ydb/_topic_writer/topic_writer.py | 61 +++-- ydb/_topic_writer/topic_writer_asyncio.py | 151 ++++++++---- .../topic_writer_asyncio_test.py | 232 +++++++++++------- ydb/_topic_writer/topic_writer_sync.py | 68 +++-- ydb/topic.py | 147 ++++++----- 15 files changed, 703 insertions(+), 380 deletions(-) diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py index 8a5b84f6..540e780f 100644 --- a/examples/topic/reader_async_example.py +++ b/examples/topic/reader_async_example.py @@ -6,12 +6,19 @@ async def connect(): - db = ydb.aio.Driver(connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials()) - reader = 
ydb.TopicClientAsyncIO(db).topic_reader("/local/topic", consumer="consumer") + db = ydb.aio.Driver( + connection_string="grpc://localhost:2135?database=/local", + credentials=ydb.credentials.AnonymousCredentials(), + ) + reader = ydb.TopicClientAsyncIO(db).topic_reader( + "/local/topic", consumer="consumer" + ) async def create_reader_and_close_with_context_manager(db: ydb.aio.Driver): - with ydb.TopicClientAsyncIO(db).topic_reader("/database/topic/path", consumer="consumer") as reader: + with ydb.TopicClientAsyncIO(db).topic_reader( + "/database/topic/path", consumer="consumer" + ) as reader: async for message in reader.messages(): pass @@ -83,9 +90,13 @@ async def get_one_batch_from_external_loop_async(reader: ydb.TopicReaderAsyncIO) async def auto_deserialize_message(db: ydb.aio.Driver): # async, batch work similar to this - async with ydb.TopicClientAsyncIO(db).topic_reader("/database/topic/path", consumer="asd", deserializer=json.loads) as reader: + async with ydb.TopicClientAsyncIO(db).topic_reader( + "/database/topic/path", consumer="asd", deserializer=json.loads + ) as reader: async for message in reader.messages(): - print(message.data.Name) # message.data replaces by json.loads(message.data) of raw message + print( + message.data.Name + ) # message.data replaces by json.loads(message.data) of raw message reader.commit(message) @@ -122,7 +133,11 @@ def process_batch(batch): async def connect_and_read_few_topics(db: ydb.aio.Driver): with ydb.TopicClientAsyncIO(db).topic_reader( - ["/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3)]) as reader: + [ + "/database/topic/path", + ydb.TopicSelector("/database/second-topic", partitions=3), + ] + ) as reader: async for message in reader.messages(): await _process(message) await reader.commit(message) @@ -140,26 +155,28 @@ def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: print(event.topic) print(event.offset) - async with ydb.TopicClientAsyncIO(db).topic_reader("/local", - consumer="consumer", - commit_batch_time=4, - on_commit=on_commit) as reader: + async with ydb.TopicClientAsyncIO(db).topic_reader( + "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit + ) as reader: async for message in reader.messages(): await _process(message) await reader.commit(message) async def advanced_read_with_own_progress_storage(db: ydb.TopicReaderAsyncIO): - async def on_get_partition_start_offset(req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest) -> \ - ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse: + async def on_get_partition_start_offset( + req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest, + ) -> ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse: # read current progress from database resp = ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse() resp.start_offset = 123 return resp - async with ydb.TopicClient(db).topic_reader("/local/test", consumer="consumer", - on_get_partition_start_offset=on_get_partition_start_offset - ) as reader: + async with ydb.TopicClient(db).topic_reader( + "/local/test", + consumer="consumer", + on_get_partition_start_offset=on_get_partition_start_offset, + ) as reader: async for mess in reader.messages(): await _process(mess) # save progress to own database diff --git a/examples/topic/reader_example.py b/examples/topic/reader_example.py index 0bb7bb8f..130679c1 100644 --- a/examples/topic/reader_example.py +++ b/examples/topic/reader_example.py @@ -5,12 +5,17 @@ def connect(): - db = 
ydb.Driver(connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials()) + db = ydb.Driver( + connection_string="grpc://localhost:2135?database=/local", + credentials=ydb.credentials.AnonymousCredentials(), + ) reader = ydb.TopicClient(db).topic_reader("/local/topic", consumer="consumer") def create_reader_and_close_with_context_manager(db: ydb.Driver): - with ydb.TopicClient(db).topic_reader("/database/topic/path", consumer="consumer", buffer_size_bytes=123) as reader: + with ydb.TopicClient(db).topic_reader( + "/database/topic/path", consumer="consumer", buffer_size_bytes=123 + ) as reader: for message in reader: pass @@ -75,9 +80,13 @@ def get_one_batch_from_external_loop(reader: ydb.TopicReader): def auto_deserialize_message(db: ydb.Driver): # async, batch work similar to this - reader = ydb.TopicClient(db).topic_reader("/database/topic/path", consumer="asd", deserializer=json.loads) + reader = ydb.TopicClient(db).topic_reader( + "/database/topic/path", consumer="asd", deserializer=json.loads + ) for message in reader.messages(): - print(message.data.Name) # message.data replaces by json.loads(message.data) of raw message + print( + message.data.Name + ) # message.data replaces by json.loads(message.data) of raw message reader.commit(message) @@ -113,7 +122,12 @@ def process_batch(batch): def connect_and_read_few_topics(db: ydb.Driver): - with ydb.TopicClient(db).topic_reader(["/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3)]) as reader: + with ydb.TopicClient(db).topic_reader( + [ + "/database/topic/path", + ydb.TopicSelector("/database/second-topic", partitions=3), + ] + ) as reader: for message in reader: _process(message) reader.commit(message) @@ -131,24 +145,29 @@ def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: print(event.topic) print(event.offset) - with ydb.TopicClient(db).topic_reader("/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit) as reader: + with ydb.TopicClient(db).topic_reader( + "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit + ) as reader: for message in reader: with reader.commit_on_exit(message): _process(message) def advanced_read_with_own_progress_storage(db: ydb.TopicReader): - def on_get_partition_start_offset(req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest) -> \ - ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse: + def on_get_partition_start_offset( + req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest, + ) -> ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse: # read current progress from database resp = ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse() resp.start_offset = 123 return resp - with ydb.TopicClient(db).topic_reader("/local/test", consumer="consumer", - on_get_partition_start_offset=on_get_partition_start_offset - ) as reader: + with ydb.TopicClient(db).topic_reader( + "/local/test", + consumer="consumer", + on_get_partition_start_offset=on_get_partition_start_offset, + ) as reader: for mess in reader: _process(mess) # save progress to own database @@ -170,4 +189,3 @@ def get_current_statistics(reader: ydb.TopicReader): def _process(msg): raise NotImplementedError() - diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index 4b26c702..1db7ce39 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -7,23 +7,26 @@ async def create_writer(db: ydb.aio.Driver): - async with 
ydb.TopicClientAsyncIO(db).topic_writer("/database/topic/path", - producer_and_message_group_id="producer-id", - ) as writer: + async with ydb.TopicClientAsyncIO(db).topic_writer( + "/database/topic/path", + producer_and_message_group_id="producer-id", + ) as writer: pass async def connect_and_wait(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).topic_writer("/database/topic/path", - producer_and_message_group_id="producer-id", - ) as writer: + async with ydb.TopicClientAsyncIO(db).topic_writer( + "/database/topic/path", + producer_and_message_group_id="producer-id", + ) as writer: writer.wait_init() async def connect_without_context_manager(db: ydb.aio.Driver): - writer = ydb.TopicClientAsyncIO(db).topic_writer("/database/topic/path", - producer_and_message_group_id="producer-id", - ) + writer = ydb.TopicClientAsyncIO(db).topic_writer( + "/database/topic/path", + producer_and_message_group_id="producer-id", + ) try: pass # some code finally: @@ -39,14 +42,19 @@ async def send_messages(writer: ydb.TopicWriterAsyncIO): # full forms await writer.write(ydb.TopicWriterMessage("mess")) # send text await writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes - await writer.write(ydb.TopicWriterMessage("mess-1"), - ydb.TopicWriterMessage("mess-2")) # send few messages by one call + await writer.write( + ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2") + ) # send few messages by one call # with meta - await writer.write(ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns())) + await writer.write( + ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns()) + ) -async def send_message_without_block_if_internal_buffer_is_full(writer: ydb.TopicWriterAsyncIO, msg) -> bool: +async def send_message_without_block_if_internal_buffer_is_full( + writer: ydb.TopicWriterAsyncIO, msg +) -> bool: try: # put message to internal queue for send, but if buffer is full - fast return # without wait @@ -62,7 +70,9 @@ def send_messages_with_manual_seqno(writer: ydb.TopicWriter): async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): # future wait - await writer.write_with_result(ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2)) + await writer.write_with_result( + ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) + ) # send with flush await writer.write("1", "2", "3") @@ -70,7 +80,9 @@ async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): async def send_json_message(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).topic_writer("/database/path/topic", serializer=json.dumps) as writer: + async with ydb.TopicClientAsyncIO(db).topic_writer( + "/database/path/topic", serializer=json.dumps + ) as writer: writer.write({"a": 123}) @@ -80,7 +92,9 @@ async def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriterAs await writer.flush() -async def send_messages_and_wait_all_commit_with_results(writer: ydb.TopicWriterAsyncIO): +async def send_messages_and_wait_all_commit_with_results( + writer: ydb.TopicWriterAsyncIO, +): last_future = None for i in range(10): content = "%s" % i @@ -91,7 +105,9 @@ async def send_messages_and_wait_all_commit_with_results(writer: ydb.TopicWriter raise last_future.exception() -async def switch_messages_with_many_producers(writers: Dict[str, ydb.TopicWriterAsyncIO], messages: List[str]): +async def switch_messages_with_many_producers( + writers: Dict[str, ydb.TopicWriterAsyncIO], messages: List[str] +): futures 
= [] # type: List[asyncio.Future] for msg in messages: diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 99966791..bb9e1bea 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -8,28 +8,37 @@ async def connect(): - db = ydb.aio.Driver(connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials()) - reader = ydb.TopicClientAsyncIO(db).topic_writer("/local/topic", producer_and_message_group_id="producer-id", ) + db = ydb.aio.Driver( + connection_string="grpc://localhost:2135?database=/local", + credentials=ydb.credentials.AnonymousCredentials(), + ) + reader = ydb.TopicClientAsyncIO(db).topic_writer( + "/local/topic", + producer_and_message_group_id="producer-id", + ) def create_writer(db: ydb.Driver): - with ydb.TopicClient(db).topic_writer("/database/topic/path", - producer_and_message_group_id="producer-id", - ) as writer: + with ydb.TopicClient(db).topic_writer( + "/database/topic/path", + producer_and_message_group_id="producer-id", + ) as writer: pass def connect_and_wait(db: ydb.Driver): - with ydb.TopicClient(db).topic_writer("/database/topic/path", - producer_and_message_group_id="producer-id", - ) as writer: + with ydb.TopicClient(db).topic_writer( + "/database/topic/path", + producer_and_message_group_id="producer-id", + ) as writer: writer.wait() def connect_without_context_manager(db: ydb.Driver): - writer = ydb.TopicClient(db).topic_writer("/database/topic/path", - producer_and_message_group_id="producer-id", - ) + writer = ydb.TopicClient(db).topic_writer( + "/database/topic/path", + producer_and_message_group_id="producer-id", + ) try: pass # some code finally: @@ -45,13 +54,17 @@ def send_messages(writer: ydb.TopicWriter): # full forms writer.write(ydb.TopicWriterMessage("mess")) # send text writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes - writer.write(ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2")) # send few messages by one call + writer.write( + ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2") + ) # send few messages by one call # with meta writer.write(ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns())) -def send_message_without_block_if_internal_buffer_is_full(writer: ydb.TopicWriter, msg) -> bool: +def send_message_without_block_if_internal_buffer_is_full( + writer: ydb.TopicWriter, msg +) -> bool: try: # put message to internal queue for send, but if buffer is full - fast return # without wait @@ -67,10 +80,14 @@ def send_messages_with_manual_seqno(writer: ydb.TopicWriter): def send_messages_with_wait_ack(writer: ydb.TopicWriter): # Explicit future wait - writer.async_write_with_ack(ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2)).result() + writer.async_write_with_ack( + ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) + ).result() # implicit, by sync call - writer.write_with_ack(ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2)) + writer.write_with_ack( + ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) + ) # write_with_ack # send with flush @@ -79,7 +96,9 @@ def send_messages_with_wait_ack(writer: ydb.TopicWriter): def send_json_message(db: ydb.Driver): - with ydb.TopicClient(db).topic_writer("/database/path/topic", serializer=json.dumps) as writer: + with ydb.TopicClient(db).topic_writer( + "/database/path/topic", serializer=json.dumps + ) as 
writer: writer.write({"a": 123}) @@ -102,7 +121,9 @@ def send_messages_and_wait_all_commit_with_results(writer: ydb.TopicWriter): raise future.exception() -def switch_messages_with_many_producers(writers: Dict[str, ydb.TopicWriter], messages: List[str]): +def switch_messages_with_many_producers( + writers: Dict[str, ydb.TopicWriter], messages: List[str] +): futures = [] # type: List[Future] for msg in messages: diff --git a/ydb/__init__.py b/ydb/__init__.py index 7af0087b..6607e1a4 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -12,7 +12,7 @@ from .scripting import * # noqa from .import_client import * # noqa from .tracing import * # noqa -from .topic import * # noqa +from .topic import * # noqa try: import ydb.aio as aio # noqa diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 9c40f5c3..4f65b2fc 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -3,8 +3,17 @@ import enum import io import datetime -from typing import Union, Optional, List, Mapping, Callable, Iterable, AsyncIterable, AsyncContextManager, \ - Any +from typing import ( + Union, + Optional, + List, + Mapping, + Callable, + Iterable, + AsyncIterable, + AsyncContextManager, + Any, +) class Selector: @@ -33,7 +42,9 @@ async def sessions_stat(self) -> List["SessionStat"]: """ raise NotImplementedError() - def messages(self, *, timeout: Union[float, None] = None) -> AsyncIterable["Message"]: + def messages( + self, *, timeout: Union[float, None] = None + ) -> AsyncIterable["Message"]: """ Block until receive new message @@ -49,8 +60,13 @@ async def receive_message(self) -> Union["Message", None]: """ raise NotImplementedError() - def batches(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None] = None, - timeout: Union[float, None] = None) -> AsyncIterable["Batch"]: + def batches( + self, + *, + max_messages: Union[int, None] = None, + max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> AsyncIterable["Batch"]: """ Block until receive new batch. All messages in a batch from same partition. @@ -59,7 +75,9 @@ def batches(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int """ raise NotImplementedError() - async def receive_batch(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None]) -> Union["Batch", None]: + async def receive_batch( + self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None] + ) -> Union["Batch", None]: """ Get one messages batch from reader. All messages in a batch from same partition. @@ -85,7 +103,9 @@ def commit(self, mess: "ICommittable"): """ raise NotImplementedError() - async def commit_with_ack(self, mess: "ICommittable") -> Union["CommitResult", List["CommitResult"]]: + async def commit_with_ack( + self, mess: "ICommittable" + ) -> Union["CommitResult", List["CommitResult"]]: """ write commit message to a buffer and wait ack from the server. @@ -152,8 +172,13 @@ def async_wait_message(self) -> concurrent.futures.Future: """ raise NotImplementedError() - def batches(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None] = None, - timeout: Union[float, None] = None) -> Iterable["Batch"]: + def batches( + self, + *, + max_messages: Union[int, None] = None, + max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Iterable["Batch"]: """ Block until receive new batch. It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. 
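# --- editor's note --------------------------------------------------------
# A minimal sketch of how the synchronous Reader API above is intended to be
# used, assembled from the method stubs in this patch. The topic path,
# consumer name, batch limits and the `batch.messages` attribute are
# illustrative assumptions, not part of the patch; `db` is a connected
# ydb.Driver as in the examples, and _process is the placeholder callback
# the example files define.
import ydb

def read_batches(db: ydb.Driver):
    reader = ydb.TopicClient(db).topic_reader("/local/topic", consumer="consumer")
    while True:
        # receive_batch is keyword-only; a None result is assumed to mean
        # that no batch arrived before the timeout expired
        batch = reader.receive_batch(max_messages=100, max_bytes=1024 * 1024, timeout=5)
        if batch is None:
            continue
        for message in batch.messages:  # all messages in a batch come from one partition
            _process(message)
        reader.commit(batch)  # buffered commit; commit_with_ack waits for the server
# --- end editor's note ----------------------------------------------------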
@@ -163,8 +188,13 @@ def batches(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int """ raise NotImplementedError() - def receive_batch(self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None], - timeout: Union[float, None] = None) -> Union["Batch", None]: + def receive_batch( + self, + *, + max_messages: Union[int, None] = None, + max_bytes: Union[int, None], + timeout: Union[float, None] = None, + ) -> Union["Batch", None]: """ Get one messages batch from reader It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. @@ -183,7 +213,9 @@ def commit(self, mess: "ICommittable"): """ raise NotImplementedError() - def commit_with_ack(self, mess: "ICommittable") -> Union["CommitResult", List["CommitResult"]]: + def commit_with_ack( + self, mess: "ICommittable" + ) -> Union["CommitResult", List["CommitResult"]]: """ write commit message to a buffer and wait ack from the server. @@ -191,7 +223,9 @@ def commit_with_ack(self, mess: "ICommittable") -> Union["CommitResult", List["C """ raise NotImplementedError() - def async_commit_with_ack(self, mess: "ICommittable") -> Union["CommitResult", List["CommitResult"]]: + def async_commit_with_ack( + self, mess: "ICommittable" + ) -> Union["CommitResult", List["CommitResult"]]: """ write commit message to a buffer and return Future for wait result. """ @@ -214,21 +248,24 @@ def close(self): class ReaderSettings: - def __init__(self, *, - consumer: str, - buffer_size_bytes: int = 50 * 1024 * 1024, - on_commit: Callable[["OnCommitEvent"], None] = None, - on_get_partition_start_offset: Callable[ - ["OnPartitionGetStartOffsetRequest"], "OnPartitionGetStartOffsetResponse"] = None, - on_partition_session_start: Callable[["StubEvent"], None] = None, - on_partition_session_stop: Callable[["StubEvent"], None] = None, - on_partition_session_close: Callable[["StubEvent"], None] = None, # todo? - decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, - deserializer: Union[Callable[[bytes], Any], None] = None, - one_attempt_connection_timeout: Union[float, None] = 1, - connection_timeout: Union[float, None] = None, - retry_policy: Union["RetryPolicy", None] = None, - ): + def __init__( + self, + *, + consumer: str, + buffer_size_bytes: int = 50 * 1024 * 1024, + on_commit: Callable[["OnCommitEvent"], None] = None, + on_get_partition_start_offset: Callable[ + ["OnPartitionGetStartOffsetRequest"], "OnPartitionGetStartOffsetResponse" + ] = None, + on_partition_session_start: Callable[["StubEvent"], None] = None, + on_partition_session_stop: Callable[["StubEvent"], None] = None, + on_partition_session_close: Callable[["StubEvent"], None] = None, # todo? 
+ decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + deserializer: Union[Callable[[bytes], Any], None] = None, + one_attempt_connection_timeout: Union[float, None] = 1, + connection_timeout: Union[float, None] = None, + retry_policy: Union["RetryPolicy", None] = None, + ): raise NotImplementedError() @@ -259,7 +296,9 @@ class Message(ICommittable, ISessionAlive): offset: int written_at_ns: int producer_id: int - data: Union[bytes, Any] # set as original decompressed bytes or deserialized object if deserializer set in reader + data: Union[ + bytes, Any + ] # set as original decompressed bytes or deserialized object if deserializer set in reader def __init__(self): self.seqno = -1 @@ -335,8 +374,8 @@ class CommitResult: class State(enum.Enum): UNSENT = 1 # commit didn't send to the server - SENT = 2 # commit was sent to server, but ack hasn't received - ACKED = 3 # ack from server is received + SENT = 2 # commit was sent to server, but ack hasn't received + ACKED = 3 # ack from server is received class SessionStat: diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index bc5beb0b..50291c75 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -11,10 +11,18 @@ # Workaround for good autocomplete in IDE and universal import at runtime if False: - from ydb._grpc.v4.protos import ydb_status_codes_pb2, ydb_issue_message_pb2, ydb_topic_pb2 + from ydb._grpc.v4.protos import ( + ydb_status_codes_pb2, + ydb_issue_message_pb2, + ydb_topic_pb2, + ) else: # noinspection PyUnresolvedReferences - from ydb._grpc.common.protos import ydb_status_codes_pb2, ydb_issue_message_pb2, ydb_topic_pb2 + from ydb._grpc.common.protos import ( + ydb_status_codes_pb2, + ydb_issue_message_pb2, + ydb_topic_pb2, + ) class Codec(Enum): @@ -32,7 +40,6 @@ class OffsetsRange: class IToProto(abc.ABC): - @abc.abstractmethod def to_proto(self) -> Message: pass @@ -81,14 +88,15 @@ def __iter__(self): def __next__(self): try: - res = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result() + res = asyncio.run_coroutine_threadsafe( + self._queue.get(), self._loop + ).result() return res except asyncio.QueueEmpty: raise StopIteration() class SyncIteratorToAsyncIterator: - def __init__(self, sync_iterator: typing.Iterator): self._sync_iterator = sync_iterator @@ -118,10 +126,12 @@ async def get(self) -> typing.Any: class IGrpcWrapperAsyncIO(abc.ABC): @abc.abstractmethod - async def receive(self) -> typing.Any: ... + async def receive(self) -> typing.Any: + ... @abc.abstractmethod - def write(self, wrap_message: IToProto): ... + def write(self, wrap_message: IToProto): + ... 
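# --- editor's note --------------------------------------------------------
# IGrpcWrapperAsyncIO is the seam that lets the writer logic run against
# either a real bidirectional gRPC stream or an in-memory double. A minimal
# sketch of such a double, matching the from_server / from_client queue usage
# in the writer tests later in this patch series (the class name and the
# error-injection behaviour are assumptions for illustration):
import asyncio
import typing

class StreamMock(IGrpcWrapperAsyncIO):
    def __init__(self):
        self.from_server = asyncio.Queue()  # items the fake "server" will emit
        self.from_client = asyncio.Queue()  # items written by the code under test

    async def receive(self) -> typing.Any:
        item = await self.from_server.get()
        if isinstance(item, Exception):  # lets tests inject stream failures
            raise item
        return item

    def write(self, wrap_message: IToProto):
        self.from_client.put_nowait(wrap_message)
# --- end editor's note ----------------------------------------------------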
SupportedDriverType = typing.Union[ydb.Driver, ydb.aio.Driver] @@ -153,11 +163,12 @@ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): async def _start_sync_driver(self, driver: ydb.Driver, stub, method): requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc) - stream_call = await asyncio.to_thread(driver, - requests_iterator, - stub, - method, - ) + stream_call = await asyncio.to_thread( + driver, + requests_iterator, + stub, + method, + ) self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__()) async def receive(self) -> typing.Any: @@ -173,9 +184,11 @@ def write(self, wrap_message: IToProto): class ServerStatus(IFromProto): __slots__ = ("status", "_issues") - def __init__(self, - status: ydb_status_codes_pb2.StatusIds.StatusCode, - issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage]): + def __init__( + self, + status: ydb_status_codes_pb2.StatusIds.StatusCode, + issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage], + ): self.status = status self._issues = issues @@ -184,9 +197,7 @@ def __str__(self): @staticmethod def from_proto(msg: Message) -> "ServerStatus": - return ServerStatus( - msg.status - ) + return ServerStatus(msg.status) def is_success(self) -> bool: return self.status == ydb_status_codes_pb2.StatusIds.SUCCESS diff --git a/ydb/_topic_wrapper/control_plane.py b/ydb/_topic_wrapper/control_plane.py index b8bbdff0..052e8aeb 100644 --- a/ydb/_topic_wrapper/control_plane.py +++ b/ydb/_topic_wrapper/control_plane.py @@ -11,4 +11,3 @@ class CreateTopicRequest: @dataclass class Consumer: name: str - diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index 51a21c49..9fb091bd 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -114,4 +114,3 @@ class StopPartitionSessionRequest: @dataclass class StopPartitionSessionResponse: partition_session_id: int - diff --git a/ydb/_topic_wrapper/writer.py b/ydb/_topic_wrapper/writer.py index 784dc711..18f821fc 100644 --- a/ydb/_topic_wrapper/writer.py +++ b/ydb/_topic_wrapper/writer.py @@ -7,8 +7,14 @@ from google.protobuf.message import Message -from ydb._topic_wrapper.common import IToProto, IFromProto, ServerStatus, UpdateTokenRequest, UpdateTokenResponse, \ - UnknownGrpcMessageError +from ydb._topic_wrapper.common import ( + IToProto, + IFromProto, + ServerStatus, + UpdateTokenRequest, + UpdateTokenResponse, + UnknownGrpcMessageError, +) # Workaround for good autocomplete in IDE and universal import at runtime if False: @@ -33,12 +39,18 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.InitRequest: if self.partitioning is None: pass - elif isinstance(self.partitioning, StreamWriteMessage.PartitioningMessageGroupID): + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): proto.message_group_id = self.partitioning.message_group_id - elif isinstance(self.partitioning, StreamWriteMessage.PartitioningPartitionID): + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): proto.partition_id = self.partitioning.partition_id else: - raise Exception("Bad partitioning type at StreamWriteMessage.InitRequest") + raise Exception( + "Bad partitioning type at StreamWriteMessage.InitRequest" + ) if self.write_session_meta: for key in self.write_session_meta: @@ -56,7 +68,9 @@ class InitResponse(IFromProto): status: ServerStatus = None @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.InitResponse) -> "StreamWriteMessage.InitResponse": + 
def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.InitResponse, + ) -> "StreamWriteMessage.InitResponse": codecs = [] # type: typing.List[int] if msg.supported_codecs: for codec in msg.supported_codecs.codecs: @@ -66,7 +80,7 @@ def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.InitResponse) -> "StreamWri last_seq_no=msg.last_seq_no, session_id=msg.session_id, partition_id=msg.partition_id, - supported_codecs=codecs + supported_codecs=codecs, ) @dataclass @@ -82,7 +96,9 @@ class MessageData(IToProto): uncompressed_size: int partitioning: "StreamWriteMessage.PartitioningType" - def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: + def to_proto( + self, + ) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData() proto.seq_no = self.seq_no proto.created_at.FromDatetime(self.created_at) @@ -91,12 +107,18 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: if self.partitioning is None: pass - elif isinstance(self.partitioning, StreamWriteMessage.PartitioningPartitionID): + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): proto.partition_id = self.partitioning.partition_id - elif isinstance(self.partitioning, StreamWriteMessage.PartitioningMessageGroupID): + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): proto.message_group_id = self.partitioning.message_group_id else: - raise Exception("Bad partition at StreamWriteMessage.WriteRequest.MessageData") + raise Exception( + "Bad partition at StreamWriteMessage.WriteRequest.MessageData" + ) return proto @@ -118,7 +140,9 @@ class WriteResponse(IFromProto): status: ServerStatus = field(default=None) @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse) -> "StreamWriteMessage.WriteResponse": + def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse, + ) -> "StreamWriteMessage.WriteResponse": acks = [] for proto_ack in msg.acks: ack = StreamWriteMessage.WriteResponse.WriteAck.from_proto(proto_ack) @@ -143,20 +167,26 @@ class WriteAck(IFromProto): message_write_status: Union[ "StreamWriteMessage.WriteResponse.WriteAck.StatusWritten", "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped", - int + int, ] @classmethod - def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck): + def from_proto( + cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck + ): if proto_ack.HasField("written"): - message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( - proto_ack.written.offset + message_write_status = ( + StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + proto_ack.written.offset + ) ) elif proto_ack.HasField("skipped"): reason = proto_ack.skipped.reason try: message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped( - reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code(reason) + reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code( + reason + ) ) except ValueError: message_write_status = reason @@ -181,12 +211,16 @@ class Reason(enum.Enum): ALREADY_WRITTEN = 1 @classmethod - def from_protobuf_code(cls, code: int) -> Union[ + def from_protobuf_code( + cls, code: int + ) -> Union[ "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason", - int + int, ]: try: - return 
StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason(code)
+                    return StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason(
+                        code
+                    )
                 except ValueError:
                     return code
@@ -250,5 +284,7 @@ def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.FromServer) -> typing.Any:
     StreamWriteMessage.InitRequest, StreamWriteMessage.WriteRequest, UpdateTokenRequest
 ]
 WriterMessagesFromServerToClient = Union[
-    StreamWriteMessage.InitResponse, StreamWriteMessage.WriteResponse, UpdateTokenResponse
+    StreamWriteMessage.InitResponse,
+    StreamWriteMessage.WriteResponse,
+    UpdateTokenResponse,
 ]
diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py
index 0e7231c2..7f734173 100644
--- a/ydb/_topic_writer/topic_writer.py
+++ b/ydb/_topic_writer/topic_writer.py
@@ -33,8 +33,12 @@ def close(self):
     MessageType = typing.Union["PublicMessage", "Message.SimpleMessageSourceType"]
 
-    def write(self, message: Union[MessageType, List[MessageType]], *args: Optional[MessageType],
-              timeout: [float, None] = None):
+    def write(
+        self,
+        message: Union[MessageType, List[MessageType]],
+        *args: Optional[MessageType],
+        timeout: [float, None] = None,
+    ):
         """
         send one or more messages to the server.
         it quickly puts the messages into the internal buffer, without waiting for the send result
@@ -49,8 +53,12 @@ def write(self, message: Union[MessageType, List[MessageType]], *args: Optional[
         """
         raise NotImplementedError()
 
-    def async_write_with_ack(self, message: Union[MessageType, List[MessageType]], *args: Optional[MessageType],
-                             timeout: [float, None] = None) -> concurrent.futures.Future:
+    def async_write_with_ack(
+        self,
+        message: Union[MessageType, List[MessageType]],
+        *args: Optional[MessageType],
+        timeout: [float, None] = None,
+    ) -> concurrent.futures.Future:
         """
         send one or more messages to the server.
         return a future, which can be awaited to check the send result: ack/duplicate/error
@@ -66,9 +74,12 @@ def async_write_with_ack(self, message: Union[MessageType, List[MessageType]], *
         """
         raise NotImplementedError()
 
-    def write_with_ack(self, message: Union[MessageType, List[MessageType]], *args: Optional[MessageType],
-                       buffer_timeout: [float, None] = None) -> Union[
-        "MessageWriteStatus", List["MessageWriteStatus"]]:
+    def write_with_ack(
+        self,
+        message: Union[MessageType, List[MessageType]],
+        *args: Optional[MessageType],
+        buffer_timeout: [float, None] = None,
+    ) -> Union["MessageWriteStatus", List["MessageWriteStatus"]]:
         """
         THIS IS THE SLOW WAY AND A BAD CHOICE IN MOST CASES.
         It is recommended to use write (optionally with flush) or async_write_with_ack and receive acks by waiting on the future.
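# --- editor's note --------------------------------------------------------
# The three send methods above trade latency for delivery feedback. A short
# usage sketch mirroring the calls in examples/topic/writer_example.py
# (payloads are illustrative; `writer` is an opened ydb.TopicWriter):
writer.write(ydb.TopicWriterMessage("fire-and-forget"))  # buffered, returns at once

future = writer.async_write_with_ack(ydb.TopicWriterMessage("tracked"))
status = future.result()  # blocks until the server acks the message

status = writer.write_with_ack(ydb.TopicWriterMessage("slow"))  # waits per call
# --- end editor's note ----------------------------------------------------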
@@ -148,9 +159,7 @@ class PublicWriterSettings: class PublicWriteResult: @dataclass(eq=True) class Written: - __slots__ = ( - "offset" - ) + __slots__ = "offset" offset: int @dataclass(eq=True) @@ -174,7 +183,9 @@ def create_init_request(self) -> StreamWriteMessage.InitRequest: def get_partitioning(self) -> StreamWriteMessage.PartitioningType: if self.partition_id is not None: return StreamWriteMessage.PartitioningPartitionID(self.partition_id) - return StreamWriteMessage.PartitioningMessageGroupID(self.producer_and_message_group_id) + return StreamWriteMessage.PartitioningMessageGroupID( + self.producer_and_message_group_id + ) class SendMode(Enum): @@ -184,9 +195,7 @@ class SendMode(Enum): @dataclass class PublicWriterInitInfo: - __slots__ = ( - "last_seqno" - ) + __slots__ = "last_seqno" last_seqno: Optional[int] @@ -197,11 +206,13 @@ class PublicMessage: SimpleMessageSourceType = Union[str, bytes, TextIO, BinaryIO] - def __init__(self, - data: SimpleMessageSourceType, *, - seqno: Optional[int] = None, - created_at: Optional[datetime.datetime] = None, - ): + def __init__( + self, + data: SimpleMessageSourceType, + *, + seqno: Optional[int] = None, + created_at: Optional[datetime.datetime] = None, + ): self.seqno = seqno self.created_at = created_at self.data = data @@ -215,7 +226,7 @@ def __init__(self, mess: PublicMessage): created_at=mess.created_at, data=mess.data, uncompressed_size=len(mess.data), - partitioning = None, + partitioning=None, ) def get_bytes(self) -> bytes: @@ -265,7 +276,9 @@ class TopicWriterRepeatableError(TopicWriterError): class TopicWriterStopped(TopicWriterError): def __init__(self): - super(TopicWriterStopped, self).__init__("topic writer was stopped by call close") + super(TopicWriterStopped, self).__init__( + "topic writer was stopped by call close" + ) def default_serializer_message_content(data: Any) -> bytes: @@ -276,11 +289,13 @@ def default_serializer_message_content(data: Any) -> bytes: if isinstance(data, bytearray): return bytes(data) if isinstance(data, str): - return data.encode(encoding='utf-8') + return data.encode(encoding="utf-8") raise ValueError("can't serialize type %s to bytes" % type(data)) -def messages_to_proto_requests(messages: List[InternalMessage]) -> List[StreamWriteMessage.FromClient]: +def messages_to_proto_requests( + messages: List[InternalMessage], +) -> List[StreamWriteMessage.FromClient]: # todo split by proto message size and codec res = [] for msg in messages: diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index daa4705a..f0292a81 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -5,9 +5,22 @@ import ydb from .topic_writer import * -from .. import _apis, YDB_AUTH_TICKET_HEADER, issues, check_retriable_error, RetrySettings -from .._topic_wrapper.common import UpdateTokenResponse, UpdateTokenRequest, QueueToIteratorAsyncIO, Codec, \ - GrpcWrapperAsyncIO, IGrpcWrapperAsyncIO, SupportedDriverType +from .. 
import (
+    _apis,
+    YDB_AUTH_TICKET_HEADER,
+    issues,
+    check_retriable_error,
+    RetrySettings,
+)
+from .._topic_wrapper.common import (
+    UpdateTokenResponse,
+    UpdateTokenRequest,
+    QueueToIteratorAsyncIO,
+    Codec,
+    GrpcWrapperAsyncIO,
+    IGrpcWrapperAsyncIO,
+    SupportedDriverType,
+)
 from .._topic_wrapper.writer import StreamWriteMessage, WriterMessagesFromServerToClient
 
 # Workaround for good autocomplete in IDE and universal import at runtime
@@ -31,7 +44,9 @@ def last_seqno(self) -> int:
     def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
         self._loop = asyncio.get_running_loop()
         self._closed = False
-        self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings))
+        self._reconnector = WriterAsyncIOReconnector(
+            driver=driver, settings=WriterSettings(settings)
+        )
 
     async def __aenter__(self):
         return self
@@ -53,10 +68,11 @@ async def close(self):
 
         await self._reconnector.close()
 
-    async def write_with_ack(self,
-                             messages: Union[Writer.MessageType, List[Writer.MessageType]],
-                             *args: Optional[Writer.MessageType],
-                             ) -> Union[PublicWriteResult, List[PublicWriteResult]]:
+    async def write_with_ack(
+        self,
+        messages: Union[Writer.MessageType, List[Writer.MessageType]],
+        *args: Optional[Writer.MessageType],
+    ) -> Union[PublicWriteResult, List[PublicWriteResult]]:
         """
         THIS IS THE SLOW WAY AND A BAD CHOICE IN MOST CASES.
         It is recommended to use write (optionally with flush) or write_with_ack_futures and receive acks by waiting on the futures.
@@ -81,10 +97,11 @@ async def write_with_ack(self,
 
         raise NotImplementedError()
 
-    async def write_with_ack_future(self,
-                                    messages: Union[Writer.MessageType, List[Writer.MessageType]],
-                                    *args: Optional[Writer.MessageType],
-                                    ) -> Union[asyncio.Future, List[asyncio.Future]]:
+    async def write_with_ack_future(
+        self,
+        messages: Union[Writer.MessageType, List[Writer.MessageType]],
+        *args: Optional[Writer.MessageType],
+    ) -> Union[asyncio.Future, List[asyncio.Future]]:
         """
         send one or more messages to the server.
         return a future, which can be awaited to check the send result.
@@ -103,10 +120,11 @@ async def write_with_ack_future(self,
         return await self._reconnector.write_with_ack(messages)
         raise NotImplementedError()
 
-    async def write(self,
-                    messages: Union[Writer.MessageType, List[Writer.MessageType]],
-                    *args: Optional[Writer.MessageType],
-                    ):
+    async def write(
+        self,
+        messages: Union[Writer.MessageType, List[Writer.MessageType]],
+        *args: Optional[Writer.MessageType],
+    ):
         """
         send one or more messages to the server.
         it puts the messages into the internal buffer
@@ -176,7 +194,9 @@ async def close(self):
     async def wait_init(self) -> PublicWriterInitInfo:
        return await self._init_info

-    async def write_with_ack(self, messages: List[PublicMessage]) -> List[asyncio.Future]:
+    async def write_with_ack(
+        self, messages: List[PublicMessage]
+    ) -> List[asyncio.Future]:
         # todo check internal buffer limit
         await self._check_stop()
@@ -209,12 +229,18 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
                 self._last_known_seq_no += 1
                 internal_message.seq_no = self._last_known_seq_no
             else:
-                raise TopicWriterError("Explicit seqno and auto_seq setting are mutually exclusive")
+                raise TopicWriterError(
+                    "Explicit seqno and auto_seq setting are mutually exclusive"
+                )
         else:
             if internal_message.seq_no is None or internal_message.seq_no == 0:
-                raise TopicWriterError("Empty seqno and auto_seq setting is disabled")
+                raise TopicWriterError(
+                    "Empty seqno and auto_seq setting is disabled"
+                )
             elif internal_message.seq_no <= self._last_known_seq_no:
-                raise TopicWriterError("Message seqno is duplicated: %s" % internal_message.seq_no)
+                raise TopicWriterError(
+                    "Message seqno is duplicated: %s" % internal_message.seq_no
+                )
             else:
                 self._last_known_seq_no = internal_message.seq_no
@@ -236,7 +262,7 @@ async def _check_stop(self):
             raise self._stop_reason
 
     async def _connection_loop(self):
-        retry_settings = RetrySettings() # todo
+        retry_settings = RetrySettings()  # todo
 
         while True:
             attempt = 0  # todo calc and reset
@@ -250,22 +276,32 @@ async def on_stop():
 
             # noinspection PyBroadException
             try:
-                stream_writer = await WriterAsyncIOStream.create(self._driver, self._init_message, self._get_token)
+                stream_writer = await WriterAsyncIOStream.create(
+                    self._driver, self._init_message, self._get_token
+                )
                 try:
                     async with self._lock:
                         self._last_known_seq_no = stream_writer.last_seqno
-                        self._init_info.set_result(PublicWriterInitInfo(last_seqno=stream_writer.last_seqno))
+                        self._init_info.set_result(
+                            PublicWriterInitInfo(last_seqno=stream_writer.last_seqno)
+                        )
                 except asyncio.InvalidStateError:
                     pass
 
                 self._stream_connected.set()
 
-                send_loop = asyncio.create_task(self._send_loop(stream_writer), name="writer send loop")
-                receive_loop = asyncio.create_task(self._read_loop(stream_writer), name="writer receive loop")
+                send_loop = asyncio.create_task(
+                    self._send_loop(stream_writer), name="writer send loop"
+                )
+                receive_loop = asyncio.create_task(
+                    self._read_loop(stream_writer), name="writer receive loop"
+                )
 
                 pending = [send_loop, receive_loop]
 
-                done, pending = await asyncio.wait([send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED)
+                done, pending = await asyncio.wait(
+                    [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED
+                )
                 done.pop().result()
             except issues.Error as err:
                 # todo log error
@@ -299,10 +335,12 @@ def _handle_receive_ack_need_lock(self, ack):
             message_future = self._messages_future.popleft()
             if current_message.seq_no != ack.seq_no:
                 raise TopicWriterError(
-                    "internal error - receive unexpected ack. 
Expected seqno: %s, received seqno: %s" + % (current_message.seq_no, ack.seq_no) ) - message_future.set_result(None) # todo - return result with offset or skip status + message_future.set_result( + None + ) # todo - return result with offset or skip status async def _send_loop(self, writer: "WriterAsyncIOStream"): try: @@ -350,42 +388,49 @@ class WriterAsyncIOStream: _requests: asyncio.Queue _responses: AsyncIterator - def __init__(self, - token_getter: "TokenGetter", - ): + def __init__( + self, + token_getter: "TokenGetter", + ): self._token_getter = token_getter @staticmethod - async def create(driver: SupportedDriverType, init_request: StreamWriteMessage.InitRequest, token_getter: "TokenGetter")\ - -> "WriterAsyncIOStream": + async def create( + driver: SupportedDriverType, + init_request: StreamWriteMessage.InitRequest, + token_getter: "TokenGetter", + ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) - await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) + await stream.start( + driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite + ) writer = WriterAsyncIOStream(token_getter) - await writer._start( - stream, - init_request - ) + await writer._start(stream, init_request) return writer - @staticmethod - async def _create_stream_from_async(driver: ydb.aio.Driver, init_request: StreamWriteMessage.InitRequest, token_getter: "TokenGetter")\ - -> "WriterAsyncIOStream": + async def _create_stream_from_async( + driver: ydb.aio.Driver, + init_request: StreamWriteMessage.InitRequest, + token_getter: "TokenGetter", + ) -> "WriterAsyncIOStream": return GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) @staticmethod - async def _create_from_sync(driver: ydb.Driver, init_request: StreamWriteMessage.InitRequest, token_getter: "TokenGetter")\ - -> "WriterAsyncIOStream": + async def _create_from_sync( + driver: ydb.Driver, + init_request: StreamWriteMessage.InitRequest, + token_getter: "TokenGetter", + ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) - await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) + await stream.start( + driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite + ) writer = WriterAsyncIOStream(token_getter) - await writer._start( - stream, - init_request - ) + await writer._start(stream, init_request) return writer async def receive(self) -> StreamWriteMessage.WriteResponse: @@ -400,7 +445,9 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: # todo log unknown messages instead of raise exception raise Exception("Unknown message while read writer answers: %s" % item) - async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest): + async def _start( + self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest + ): stream.write(StreamWriteMessage.FromClient(init_message)) resp = await stream.receive() @@ -415,7 +462,9 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes @staticmethod def _ensure_ok(message: WriterMessagesFromServerToClient): if not message.status.is_success(): - raise TopicWriterError("status error from server in writer: %s", message.status) + raise TopicWriterError( + "status error from server in writer: %s", message.status + ) def write(self, messages: List[InternalMessage]): for request in messages_to_proto_requests(messages): diff --git 
a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 4ea16764..154e0fea 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -15,12 +15,23 @@ from .. import aio from .. import StatusCode, issues from .._topic_wrapper.common import ServerStatus, IGrpcWrapperAsyncIO, IToProto, Codec -from .topic_writer import InternalMessage, PublicMessage, WriterSettings, PublicWriterSettings, \ - PublicWriterInitInfo, PublicWriteResult, TopicWriterError +from .topic_writer import ( + InternalMessage, + PublicMessage, + WriterSettings, + PublicWriterSettings, + PublicWriterInitInfo, + PublicWriteResult, + TopicWriterError, +) from .._topic_wrapper.writer import StreamWriteMessage -from .topic_writer_asyncio import WriterAsyncIOStream, WriterAsyncIOReconnector, TokenGetter, \ - WriterAsyncIO +from .topic_writer_asyncio import ( + WriterAsyncIOStream, + WriterAsyncIOReconnector, + TokenGetter, + WriterAsyncIO, +) @pytest.fixture @@ -60,22 +71,29 @@ def stream(self): @pytest.fixture async def writer_and_stream(self, stream) -> WriterWithMockedStream: - stream.from_server.put_nowait(StreamWriteMessage.InitResponse( - last_seq_no=4, - session_id="123", - partition_id=3, - supported_codecs=[Codec.CODEC_RAW.value, Codec.CODEC_GZIP.value], - status=ServerStatus(StatusCode.SUCCESS, []) - )) + stream.from_server.put_nowait( + StreamWriteMessage.InitResponse( + last_seq_no=4, + session_id="123", + partition_id=3, + supported_codecs=[Codec.CODEC_RAW.value, Codec.CODEC_GZIP.value], + status=ServerStatus(StatusCode.SUCCESS, []), + ) + ) writer = WriterAsyncIOStream(None) - await writer._start(stream, init_message=StreamWriteMessage.InitRequest( - path="/local/test", - producer_id="producer-id", - write_session_meta={"a": "b"}, - partitioning=StreamWriteMessage.PartitioningMessageGroupID(message_group_id="message-group-id"), - get_last_seq_no=False, - )) + await writer._start( + stream, + init_message=StreamWriteMessage.InitRequest( + path="/local/test", + producer_id="producer-id", + write_session_meta={"a": "b"}, + partitioning=StreamWriteMessage.PartitioningMessageGroupID( + message_group_id="message-group-id" + ), + get_last_seq_no=False, + ), + ) await stream.from_client.get() return TestWriterAsyncIOStream.WriterWithMockedStream( @@ -89,16 +107,20 @@ async def test_init_writer(self, stream): path="/local/test", producer_id="producer-id", write_session_meta={"a": "b"}, - partitioning=StreamWriteMessage.PartitioningMessageGroupID(message_group_id="message-group-id"), + partitioning=StreamWriteMessage.PartitioningMessageGroupID( + message_group_id="message-group-id" + ), get_last_seq_no=False, ) - stream.from_server.put_nowait(StreamWriteMessage.InitResponse( - last_seq_no=init_seqno, - session_id="123", - partition_id=0, - supported_codecs=[], - status=ServerStatus(StatusCode.SUCCESS, []) - )) + stream.from_server.put_nowait( + StreamWriteMessage.InitResponse( + last_seq_no=init_seqno, + session_id="123", + partition_id=0, + supported_codecs=[], + status=ServerStatus(StatusCode.SUCCESS, []), + ) + ) writer = WriterAsyncIOStream(None) await writer._start(stream, init_message) @@ -112,26 +134,32 @@ async def test_init_writer(self, stream): async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): data = "123".encode() now = datetime.datetime.now() - writer_and_stream.writer.write([InternalMessage( - PublicMessage( - seqno=1, - created_at=now, - data=data, - ) - )]) - - expected_message 
= StreamWriteMessage.FromClient(StreamWriteMessage.WriteRequest( - codec=Codec.CODEC_RAW.value, - messages=[ - StreamWriteMessage.WriteRequest.MessageData( - seq_no=1, - created_at=now, - data=data, - uncompressed_size=len(data), - partitioning=None, + writer_and_stream.writer.write( + [ + InternalMessage( + PublicMessage( + seqno=1, + created_at=now, + data=data, + ) ) ] - )) + ) + + expected_message = StreamWriteMessage.FromClient( + StreamWriteMessage.WriteRequest( + codec=Codec.CODEC_RAW.value, + messages=[ + StreamWriteMessage.WriteRequest.MessageData( + seq_no=1, + created_at=now, + data=data, + uncompressed_size=len(data), + partitioning=None, + ) + ], + ) + ) sent_message = await writer_and_stream.stream.from_client.get() assert expected_message == sent_message @@ -163,7 +191,6 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: @pytest.fixture(autouse=True) async def stream_writer_double_queue(self, monkeypatch): - class DoubleQueueWriters: _first: Queue _second: Queue @@ -201,20 +228,26 @@ async def async_create(driver, init_message, token_getter): return res @pytest.fixture - def get_stream_writer(self, stream_writer_double_queue) -> typing.Callable[[], "TestWriterAsyncIOReconnector.StreamWriterMock"]: + def get_stream_writer( + self, stream_writer_double_queue + ) -> typing.Callable[[], "TestWriterAsyncIOReconnector.StreamWriterMock"]: return stream_writer_double_queue.get_second @pytest.fixture def default_settings(self) -> WriterSettings: - return WriterSettings(PublicWriterSettings( - topic="/local/topic", - producer_and_message_group_id="test-producer", - auto_seqno=False, - auto_created_at=False, - )) + return WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_and_message_group_id="test-producer", + auto_seqno=False, + auto_created_at=False, + ) + ) @pytest.fixture - def default_write_statistic(self) -> StreamWriteMessage.WriteResponse.WriteStatistics: + def default_write_statistic( + self, + ) -> StreamWriteMessage.WriteResponse.WriteStatistics: return StreamWriteMessage.WriteResponse.WriteStatistics( persisting_time=datetime.timedelta(milliseconds=1), min_queue_wait_time=datetime.timedelta(milliseconds=2), @@ -224,14 +257,16 @@ def default_write_statistic(self) -> StreamWriteMessage.WriteResponse.WriteStati ) @pytest.fixture - async def reconnector(self, default_driver, default_settings) -> WriterAsyncIOReconnector: + async def reconnector( + self, default_driver, default_settings + ) -> WriterAsyncIOReconnector: return WriterAsyncIOReconnector(default_driver, default_settings) async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( - self, - reconnector: WriterAsyncIOReconnector, - get_stream_writer, - default_write_statistic, + self, + reconnector: WriterAsyncIOReconnector, + get_stream_writer, + default_write_statistic, ): now = datetime.datetime.now() data = "123".encode() @@ -257,18 +292,20 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( assert [InternalMessage(message2)] == messages # ack first message - stream_writer.from_server.put_nowait(StreamWriteMessage.WriteResponse( - partition_id=1, - acks=[ - StreamWriteMessage.WriteResponse.WriteAck( - seq_no=1, - message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( - offset=1 + stream_writer.from_server.put_nowait( + StreamWriteMessage.WriteResponse( + partition_id=1, + acks=[ + StreamWriteMessage.WriteResponse.WriteAck( + seq_no=1, + message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( 
+ offset=1 + ), ) - ) - ], - write_statistics=default_write_statistic, - )) + ], + write_statistics=default_write_statistic, + ) + ) stream_writer.from_server.put_nowait(issues.Overloaded("test")) @@ -279,7 +316,9 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( assert second_sent_msg == expected_messages await reconnector.close() - async def test_stop_on_unexpected_exception(self, reconnector: WriterAsyncIOReconnector, get_stream_writer): + async def test_stop_on_unexpected_exception( + self, reconnector: WriterAsyncIOReconnector, get_stream_writer + ): class TestException(Exception): pass @@ -292,6 +331,7 @@ class TestException(Exception): ) with pytest.raises(TestException): + async def wait_stop(): while True: await reconnector.write_with_ack([message]) @@ -305,7 +345,9 @@ async def wait_stop(): async def test_wait_init(self, default_driver, default_settings, get_stream_writer): init_seqno = 100 expected_init_info = PublicWriterInitInfo(init_seqno) - with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno): + with mock.patch.object( + TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno + ): reconnector = WriterAsyncIOReconnector(default_driver, default_settings) info = await reconnector.wait_init() assert info == expected_init_info @@ -313,9 +355,13 @@ async def test_wait_init(self, default_driver, default_settings, get_stream_writ reconnector._stream_connected.clear() # force reconnect - with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno+1): + with mock.patch.object( + TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno + 1 + ): stream_writer = get_stream_writer() - stream_writer.from_server.put_nowait(issues.Overloaded("test")) # some retriable error + stream_writer.from_server.put_nowait( + issues.Overloaded("test") + ) # some retriable error await reconnector._stream_connected.wait() info = await reconnector.wait_init() @@ -323,7 +369,9 @@ async def test_wait_init(self, default_driver, default_settings, get_stream_writ await reconnector.close() - async def test_write_message(self, reconnector: WriterAsyncIOReconnector, get_stream_writer): + async def test_write_message( + self, reconnector: WriterAsyncIOReconnector, get_stream_writer + ): stream_writer = get_stream_writer() message = PublicMessage( data="123", @@ -336,9 +384,13 @@ async def test_write_message(self, reconnector: WriterAsyncIOReconnector, get_st await reconnector.close() - async def test_auto_seq_no(self, default_driver, default_settings, get_stream_writer): + async def test_auto_seq_no( + self, default_driver, default_settings, get_stream_writer + ): last_seq_no = 100 - with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", last_seq_no): + with mock.patch.object( + TestWriterAsyncIOReconnector, "init_last_seqno", last_seq_no + ): settings = copy.deepcopy(default_settings) settings.auto_seqno = True @@ -350,13 +402,19 @@ async def test_auto_seq_no(self, default_driver, default_settings, get_stream_wr stream_writer = get_stream_writer() sent = await stream_writer.from_client.get() - assert [InternalMessage(PublicMessage(seqno=last_seq_no+1, data="123"))] == sent + assert [ + InternalMessage(PublicMessage(seqno=last_seq_no + 1, data="123")) + ] == sent sent = await stream_writer.from_client.get() - assert [InternalMessage(PublicMessage(seqno=last_seq_no+2, data="456"))] == sent + assert [ + InternalMessage(PublicMessage(seqno=last_seq_no + 2, data="456")) + ] == sent with 
pytest.raises(TopicWriterError): - await reconnector.write_with_ack([PublicMessage(seqno=last_seq_no+3, data="123")]) + await reconnector.write_with_ack( + [PublicMessage(seqno=last_seq_no + 3, data="123")] + ) await reconnector.close() @@ -374,7 +432,9 @@ async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector): await reconnector.close() @freezegun.freeze_time("2022-01-13 20:50:00", tz_offset=0) - async def test_auto_created_at(self, default_driver, default_settings, get_stream_writer): + async def test_auto_created_at( + self, default_driver, default_settings, get_stream_writer + ): now = datetime.datetime.now() settings = copy.deepcopy(default_settings) @@ -385,7 +445,9 @@ async def test_auto_created_at(self, default_driver, default_settings, get_strea stream_writer = get_stream_writer() sent = await stream_writer.from_client.get() - assert [InternalMessage(PublicMessage(seqno=4, data="123", created_at=now))] == sent + assert [ + InternalMessage(PublicMessage(seqno=4, data="123", created_at=now)) + ] == sent await reconnector.close() @@ -425,6 +487,7 @@ def default_settings(self) -> PublicWriterSettings: def mock_reconnector_init(self, monkeypatch, reconnector): def t(cls, driver, settings): return reconnector + monkeypatch.setattr(WriterAsyncIOReconnector, "__new__", t) @pytest.fixture @@ -457,6 +520,7 @@ async def ack_first_message(): await reconnector.messages_writted.wait() async with reconnector.lock: reconnector.futures[0].set_result(PublicWriteResult.Written(offset=1)) + asyncio.create_task(ack_first_message()) m = PublicMessage(seqno=1, data="123") @@ -474,8 +538,10 @@ async def ack_next_messages(): async with reconnector.lock: reconnector.futures[0].set_result(PublicWriteResult.Written(offset=2)) reconnector.futures[1].set_result(PublicWriteResult.Skipped()) + asyncio.create_task(ack_next_messages()) - res = await writer.write_with_ack([PublicMessage(seqno=2, data="123"), PublicMessage(seqno=3, data="123")]) + res = await writer.write_with_ack( + [PublicMessage(seqno=2, data="123"), PublicMessage(seqno=3, data="123")] + ) assert res == [PublicWriteResult.Written(offset=2), PublicWriteResult.Skipped()] - diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index bbd6d71e..9c39e5e6 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -6,8 +6,14 @@ from typing import Union, List, Optional, Coroutine from .._topic_wrapper.common import SupportedDriverType -from .topic_writer import PublicWriterSettings, TopicWriterError, PublicWriterInitInfo, PublicMessage, Writer, \ - PublicWriteResult +from .topic_writer import ( + PublicWriterSettings, + TopicWriterError, + PublicWriterInitInfo, + PublicMessage, + Writer, + PublicWriteResult, +) from .topic_writer_asyncio import WriterAsyncIO @@ -34,7 +40,11 @@ def start_event_loop(): asyncio.set_event_loop(_shared_event_loop) _shared_event_loop.run_forever() - t = threading.Thread(target=start_event_loop, name="Common ydb topic writer event loop", daemon=True) + t = threading.Thread( + target=start_event_loop, + name="Common ydb topic writer event loop", + daemon=True, + ) t.start() event_loop_set_done.result() @@ -46,11 +56,13 @@ class WriterSync: _async_writer: WriterAsyncIO _closed: bool - def __init__(self, - driver: SupportedDriverType, - settings: PublicWriterSettings, - *, - eventloop: asyncio.AbstractEventLoop = None): + def __init__( + self, + driver: SupportedDriverType, + settings: PublicWriterSettings, + *, + eventloop: 
asyncio.AbstractEventLoop = None, + ): self._closed = False @@ -62,7 +74,9 @@ def __init__(self, async def create_async_writer(): return WriterAsyncIO(driver, settings) - self._async_writer = asyncio.run_coroutine_threadsafe(create_async_writer(), self._loop).result() + self._async_writer = asyncio.run_coroutine_threadsafe( + create_async_writer(), self._loop + ).result() def _call(self, coro, *args, **kwargs): if self._closed: @@ -82,7 +96,9 @@ def close(self): if self._closed: return self._closed = True - asyncio.run_coroutine_threadsafe(self._async_writer.close(), self._loop).result() + asyncio.run_coroutine_threadsafe( + self._async_writer.close(), self._loop + ).result() def async_flush(self) -> Future: if self._closed: @@ -98,19 +114,27 @@ def async_wait_init(self) -> Future[PublicWriterInitInfo]: def wait_init(self, timeout) -> PublicWriterInitInfo: return self._call_sync(self._async_writer.wait_init(), timeout) - def write(self, message: Union[PublicMessage, List[PublicMessage]], *args: Optional[PublicMessage], - timeout: Union[float, None] = None): + def write( + self, + message: Union[PublicMessage, List[PublicMessage]], + *args: Optional[PublicMessage], + timeout: Union[float, None] = None, + ): self._call_sync(self._async_writer.write(message, *args), timeout=timeout) - def async_write_with_ack(self, - messages: Union[Writer.MessageType, List[Writer.MessageType]], - *args: Optional[Writer.MessageType], - ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: + def async_write_with_ack( + self, + messages: Union[Writer.MessageType, List[Writer.MessageType]], + *args: Optional[Writer.MessageType], + ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: return self._call(self._async_writer.write_with_ack(messages, *args)) - def write_with_ack(self, - messages: Union[Writer.MessageType, List[Writer.MessageType]], - *args: Optional[Writer.MessageType], - timeout: Union[float, None] = None, - ) -> Union[PublicWriteResult, List[PublicWriteResult]]: - return self._call_sync(self._async_writer.write_with_ack(messages, *args), timeout=timeout) + def write_with_ack( + self, + messages: Union[Writer.MessageType, List[Writer.MessageType]], + *args: Optional[Writer.MessageType], + timeout: Union[float, None] = None, + ) -> Union[PublicWriteResult, List[PublicWriteResult]]: + return self._call_sync( + self._async_writer.write_with_ack(messages, *args), timeout=timeout + ) diff --git a/ydb/topic.py b/ydb/topic.py index e644a567..763a7baa 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -3,7 +3,8 @@ import ydb._topic_writer from ydb._topic_reader import ( - Reader as TopicReader, ReaderAsyncIO as TopicReaderAsyncIO, + Reader as TopicReader, + ReaderAsyncIO as TopicReaderAsyncIO, Selector as TopicSelector, ) @@ -23,43 +24,49 @@ class TopicClientAsyncIO: def __init__(self, driver: ydb.aio.Driver, settings: "TopicClientSettings" = None): self._driver = driver - def topic_reader(self, topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]], - consumer: str, - commit_batch_time: Union[float, None] = 0.1, - commit_batch_count: Union[int, None] = 1000, - buffer_size_bytes: int = 50 * 1024 * 1024, - sync_commit: bool = False, # reader.commit(...) 
will wait commit ack from server
-                     on_commit: Callable[["OnCommitEvent"], None] = None,
-                     on_get_partition_start_offset: Callable[
-                         ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"], "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse"] = None,
-                     on_init_partition: Callable[["StubEvent"], None] = None,
-                     on_shutdown_partition: Callable[["StubEvent"], None] = None,
-                     decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
-                     deserializer: Union[Callable[[bytes], Any], None] = None,
-                     one_attempt_connection_timeout: Union[float, None] = 1,
-                     connection_timeout: Union[float, None] = None,
-                     retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None,
-                     ) -> TopicReaderAsyncIO:
+    def topic_reader(
+        self,
+        topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]],
+        consumer: str,
+        commit_batch_time: Union[float, None] = 0.1,
+        commit_batch_count: Union[int, None] = 1000,
+        buffer_size_bytes: int = 50 * 1024 * 1024,
+        sync_commit: bool = False,  # reader.commit(...) will wait commit ack from server
+        on_commit: Callable[["OnCommitEvent"], None] = None,
+        on_get_partition_start_offset: Callable[
+            ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"],
+            "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse",
+        ] = None,
+        on_init_partition: Callable[["StubEvent"], None] = None,
+        on_shutdown_partition: Callable[["StubEvent"], None] = None,
+        decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+        deserializer: Union[Callable[[bytes], Any], None] = None,
+        one_attempt_connection_timeout: Union[float, None] = 1,
+        connection_timeout: Union[float, None] = None,
+        retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None,
+    ) -> TopicReaderAsyncIO:
         raise NotImplementedError()

-    def topic_writer(self, topic,
-                     *,
-                     producer_and_message_group_id: str,
-                     session_metadata: Mapping[str, str] = None,
-                     encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
-                     serializer: Union[Callable[[Any], bytes], None] = None,
-                     send_buffer_count: Union[int, None] = 10000,
-                     send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024,
-                     partition_id: Union[int, None] = None,
-                     codec: Union[int, None] = None,
-                     codec_autoselect: bool = True,
-                     auto_seqno: bool = True,
-                     auto_created_at: bool = True,
-                     get_last_seqno: bool = False,
-                     retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None,
-                     ) -> TopicWriterAsyncIO:
+    def topic_writer(
+        self,
+        topic,
+        *,
+        producer_and_message_group_id: str,
+        session_metadata: Mapping[str, str] = None,
+        encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+        serializer: Union[Callable[[Any], bytes], None] = None,
+        send_buffer_count: Union[int, None] = 10000,
+        send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024,
+        partition_id: Union[int, None] = None,
+        codec: Union[int, None] = None,
+        codec_autoselect: bool = True,
+        auto_seqno: bool = True,
+        auto_created_at: bool = True,
+        get_last_seqno: bool = False,
+        retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None,
+    ) -> TopicWriterAsyncIO:
         args = locals()
-        del args['self']
+        del args["self"]
         settings = TopicWriterSettings(**args)
         return TopicWriterAsyncIO(self._driver, settings)

@@ -68,40 +75,46 @@ class TopicClient:
     def __init__(self, driver, topic_client_settings: "TopicClientSettings" = None):
         pass

-    def topic_reader(self, topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]],
-                     consumer: str,
-                     commit_batch_time: Union[float, None] = 0.1,
-                     commit_batch_count: Union[int, None] = 1000,
-                     buffer_size_bytes: int = 50 * 1024 * 1024,
-                     sync_commit: bool = False,  # reader.commit(...) will wait commit ack from server
-                     on_commit: Callable[["OnCommitEvent"], None] = None,
-                     on_get_partition_start_offset: Callable[
-                         ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"], "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse"] = None,
-                     on_init_partition: Callable[["StubEvent"], None] = None,
-                     on_shutdown_partition: Callable[["StubEvent"], None] = None,
-                     decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
-                     deserializer: Union[Callable[[bytes], Any], None] = None,
-                     one_attempt_connection_timeout: Union[float, None] = 1,
-                     connection_timeout: Union[float, None] = None,
-                     retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None,
-                     ) -> TopicReader:
+    def topic_reader(
+        self,
+        topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]],
+        consumer: str,
+        commit_batch_time: Union[float, None] = 0.1,
+        commit_batch_count: Union[int, None] = 1000,
+        buffer_size_bytes: int = 50 * 1024 * 1024,
+        sync_commit: bool = False,  # reader.commit(...) will wait commit ack from server
+        on_commit: Callable[["OnCommitEvent"], None] = None,
+        on_get_partition_start_offset: Callable[
+            ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"],
+            "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse",
+        ] = None,
+        on_init_partition: Callable[["StubEvent"], None] = None,
+        on_shutdown_partition: Callable[["StubEvent"], None] = None,
+        decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+        deserializer: Union[Callable[[bytes], Any], None] = None,
+        one_attempt_connection_timeout: Union[float, None] = 1,
+        connection_timeout: Union[float, None] = None,
+        retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None,
+    ) -> TopicReader:
         raise NotImplementedError()

-    def topic_writer(self, topic,
-                     producer_and_message_group_id: str,
-                     session_metadata: Mapping[str, str] = None,
-                     encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
-                     serializer: Union[Callable[[Any], bytes], None] = None,
-                     send_buffer_count: Union[int, None] = 10000,
-                     send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024,
-                     partition_id: Union[int, None] = None,
-                     codec: Union[int, None] = None,
-                     codec_autoselect: bool = True,
-                     auto_seqno: bool = True,
-                     auto_created_at: bool = True,
-                     get_last_seqno: bool = False,
-                     retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None,
-                     ) -> TopicWriter:
+    def topic_writer(
+        self,
+        topic,
+        producer_and_message_group_id: str,
+        session_metadata: Mapping[str, str] = None,
+        encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+        serializer: Union[Callable[[Any], bytes], None] = None,
+        send_buffer_count: Union[int, None] = 10000,
+        send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024,
+        partition_id: Union[int, None] = None,
+        codec: Union[int, None] = None,
+        codec_autoselect: bool = True,
+        auto_seqno: bool = True,
+        auto_created_at: bool = True,
+        get_last_seqno: bool = False,
+        retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None,
+    ) -> TopicWriter:
         raise NotImplementedError()

From b3c4bb550a1ac0518770e0f0ea485a23982466e9 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Tue, 24 Jan 2023 14:55:10 +0300
Subject: [PATCH 007/147] fix flake8 linter

---
 examples/topic/reader_async_example.py    |  1 +
 examples/topic/reader_example.py          |  1 +
 examples/topic/writer_async_example.py    |  5 ++-
 examples/topic/writer_example.py          |  6 ++-
 ydb/_topic_reader/__init__.py             |  1 -
 ydb/_topic_reader/topic_reader.py         |  8 +++-
 ydb/_topic_wrapper/common.py              |  1 -
 ydb/_topic_wrapper/reader.py              | 11 ++---
 ydb/_topic_wrapper/writer.py              |  1 -
 ydb/_topic_writer/topic_writer.py         |  6 +--
 ydb/_topic_writer/topic_writer_asyncio.py | 25 ++++-------
 .../topic_writer_asyncio_test.py          |  1 -
 ydb/topic.py                              | 42 ++++++++++---------
 13 files changed, 51 insertions(+), 58 deletions(-)

diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py
index 540e780f..e702903f 100644
--- a/examples/topic/reader_async_example.py
+++ b/examples/topic/reader_async_example.py
@@ -13,6 +13,7 @@ async def connect():
     reader = ydb.TopicClientAsyncIO(db).topic_reader(
         "/local/topic", consumer="consumer"
     )
+    return reader


 async def create_reader_and_close_with_context_manager(db: ydb.aio.Driver):
diff --git a/examples/topic/reader_example.py b/examples/topic/reader_example.py
index 130679c1..7cea2a35 100644
--- a/examples/topic/reader_example.py
+++ b/examples/topic/reader_example.py
@@ -10,6 +10,7 @@ def connect():
         credentials=ydb.credentials.AnonymousCredentials(),
     )
     reader = ydb.TopicClient(db).topic_reader("/local/topic", consumer="consumer")
+    return reader


 def create_reader_and_close_with_context_manager(db: ydb.Driver):
diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py
index 1db7ce39..6dd37490 100644
--- a/examples/topic/writer_async_example.py
+++ b/examples/topic/writer_async_example.py
@@ -1,9 +1,10 @@
 import asyncio
 import json
 import time
-from typing import Dict, List, Set
+from typing import Dict, List

 import ydb
+from ydb import TopicWriterMessage


 async def create_writer(db: ydb.aio.Driver):
@@ -11,7 +12,7 @@ async def create_writer(db: ydb.aio.Driver):
         "/database/topic/path",
         producer_and_message_group_id="producer-id",
     ) as writer:
-        pass
+        await writer.write(TopicWriterMessage("asd"))


 async def connect_and_wait(db: ydb.aio.Driver):
diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py
index bb9e1bea..27387e11 100644
--- a/examples/topic/writer_example.py
+++ b/examples/topic/writer_example.py
@@ -5,6 +5,7 @@
 from concurrent.futures import Future, wait

 import ydb
+from ydb import TopicWriterMessage


 async def connect():
@@ -12,10 +13,11 @@ async def connect():
         connection_string="grpc://localhost:2135?database=/local",
         credentials=ydb.credentials.AnonymousCredentials(),
     )
-    reader = ydb.TopicClientAsyncIO(db).topic_writer(
+    writer = ydb.TopicClientAsyncIO(db).topic_writer(
         "/local/topic",
         producer_and_message_group_id="producer-id",
     )
+    await writer.write(TopicWriterMessage("asd"))


 def create_writer(db: ydb.Driver):
@@ -23,7 +25,7 @@ def create_writer(db: ydb.Driver):
         "/database/topic/path",
         producer_and_message_group_id="producer-id",
     ) as writer:
-        pass
+        writer.write(TopicWriterMessage("asd"))


 def connect_and_wait(db: ydb.Driver):
diff --git a/ydb/_topic_reader/__init__.py b/ydb/_topic_reader/__init__.py
index 3aab85c2..e69de29b 100644
--- a/ydb/_topic_reader/__init__.py
+++ b/ydb/_topic_reader/__init__.py
@@ -1 +0,0 @@
-from .topic_reader import *
diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py
index 4f65b2fc..bc2f6cb5 100644
--- a/ydb/_topic_reader/topic_reader.py
+++ b/ydb/_topic_reader/topic_reader.py
@@ -253,9 +253,9 @@ def __init__(
         *,
         consumer: str,
         buffer_size_bytes: int = 50 * 1024 * 1024,
-        on_commit: Callable[["OnCommitEvent"], None] = None,
+        on_commit: Callable[["Events.OnCommit"], None] = None,
         on_get_partition_start_offset: Callable[
-            ["OnPartitionGetStartOffsetRequest"], "OnPartitionGetStartOffsetResponse"
+            ["Events.OnPartitionGetStartOffsetRequest"], "Events.OnPartitionGetStartOffsetResponse"
         ] = None,
         on_partition_session_start: Callable[["StubEvent"], None] = None,
         on_partition_session_stop: Callable[["StubEvent"], None] = None,
@@ -390,3 +390,7 @@ class SessionStat:
 class OffsetRange:
     start: int
     end: int
+
+
+class StubEvent:
+    pass
diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py
index 50291c75..2f07fde8 100644
--- a/ydb/_topic_wrapper/common.py
+++ b/ydb/_topic_wrapper/common.py
@@ -1,6 +1,5 @@
 import abc
 import asyncio
-import queue
 import typing
 from dataclasses import dataclass
 from enum import Enum
diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py
index 9fb091bd..0a84d343 100644
--- a/ydb/_topic_wrapper/reader.py
+++ b/ydb/_topic_wrapper/reader.py
@@ -1,12 +1,7 @@
-import abc
 import datetime
-import typing
-from codecs import Codec
 from dataclasses import dataclass, field
 from typing import List, Union, Dict

-from google.protobuf.message import Message
-
 from ydb._topic_wrapper.common import OffsetsRange


@@ -53,7 +48,7 @@ class MessageData:

     @dataclass
     class Batch:
-        message_data: List["MessageData"]
+        message_data: List["StreamReadMessage.ReadResponse.MessageData"]
         producer_id: str
         write_session_meta: Dict[str, str]
         codec: int
@@ -62,7 +57,7 @@ class PartitionData:
     @dataclass
     class PartitionData:
         partition_session_id: int
-        batches: List["Batch"]
+        batches: List["StreamReadMessage.ReadResponse.Batch"]

     @dataclass
     class CommitOffsetRequest:
@@ -95,7 +90,7 @@ class PartitionSessionStatusResponse:

     @dataclass
     class StartPartitionSessionRequest:
-        partition_session: "PartitionSession"
+        partition_session: "StreamReadMessage.PartitionSession"
         committed_offset: int
         partition_offsets: OffsetsRange
diff --git a/ydb/_topic_wrapper/writer.py b/ydb/_topic_wrapper/writer.py
index 18f821fc..6710f544 100644
--- a/ydb/_topic_wrapper/writer.py
+++ b/ydb/_topic_wrapper/writer.py
@@ -1,4 +1,3 @@
-import asyncio
 import datetime
 import enum
 import typing
diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py
index 7f734173..da614bf2 100644
--- a/ydb/_topic_writer/topic_writer.py
+++ b/ydb/_topic_writer/topic_writer.py
@@ -1,8 +1,6 @@
-import asyncio
 import concurrent.futures
 import datetime
 import enum
-import time
 from dataclasses import dataclass
 from enum import Enum
 from typing import List, Union, TextIO, BinaryIO, Optional, Callable, Mapping, Any, Dict
@@ -17,7 +15,7 @@ class Writer:

     @property
     def last_seqno(self) -> int:
-        raise NotImplemented()
+        raise NotImplementedError()

     def __init__(self, db: ydb.Driver):
         pass
@@ -31,7 +29,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     def close(self):
         pass

-    MessageType = typing.Union["PublicMessage", "Message.SimpleMessageSourceType"]
+    MessageType = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"]

     def write(
         self,
diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
index f0292a81..b8463542 100644
--- a/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -1,35 +1,26 @@
 import asyncio
+import datetime
 import threading
 from collections import deque
-from typing import Dict, Awaitable, Deque, AsyncIterator
+from typing import Deque, AsyncIterator, Union, List, Optional, Callable

 import ydb
-from .topic_writer import *
+from .topic_writer import PublicWriterSettings, WriterSettings, Writer, PublicWriteResult, PublicMessage, \
+    PublicWriterInitInfo, InternalMessage, TopicWriterStopped, TopicWriterError, messages_to_proto_requests
 from .. import (
     _apis,
-    YDB_AUTH_TICKET_HEADER,
     issues,
     check_retriable_error,
     RetrySettings,
 )
 from .._topic_wrapper.common import (
     UpdateTokenResponse,
-    UpdateTokenRequest,
-    QueueToIteratorAsyncIO,
-    Codec,
     GrpcWrapperAsyncIO,
     IGrpcWrapperAsyncIO,
     SupportedDriverType,
 )
 from .._topic_wrapper.writer import StreamWriteMessage, WriterMessagesFromServerToClient

-# Workaround for good autocomplete in IDE and universal import at runtime
-if False:
-    from .._grpc.v4.protos import ydb_topic_pb2
-else:
-    # noinspection PyUnresolvedReferences
-    from .._grpc.common.protos import ydb_topic_pb2
-

 class WriterAsyncIO:
     _loop: asyncio.AbstractEventLoop
@@ -39,7 +30,7 @@ class WriterAsyncIO:

     @property
     def last_seqno(self) -> int:
-        raise NotImplemented()
+        raise NotImplementedError()

     def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
@@ -268,7 +259,7 @@ async def _connection_loop(self):
             attempt = 0  # todo calc and reset
             pending = []

-            async def on_stop():
+            async def on_stop(e):
                 for t in pending:
                     self._background_tasks.append(t)
                 pending.clear()
@@ -309,13 +300,13 @@ async def on_stop():

                 err_info = check_retriable_error(err, retry_settings, attempt)
                 if not err_info.is_retriable:
-                    await on_stop()
+                    await on_stop(err)
                     return

                 await asyncio.sleep(err_info.sleep_timeout_seconds)

             except Exception as e:
-                await on_stop()
+                await on_stop(e)
                 return
             finally:
                 if len(pending) > 0:
diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py
index 154e0fea..5d157f15 100644
--- a/ydb/_topic_writer/topic_writer_asyncio_test.py
+++ b/ydb/_topic_writer/topic_writer_asyncio_test.py
@@ -29,7 +29,6 @@
 from .topic_writer_asyncio import (
     WriterAsyncIOStream,
     WriterAsyncIOReconnector,
-    TokenGetter,
     WriterAsyncIO,
 )
diff --git a/ydb/topic.py b/ydb/topic.py
index 763a7baa..6de3a847 100644
--- a/ydb/topic.py
+++ b/ydb/topic.py
@@ -1,27 +1,31 @@
 from typing import List, Callable, Union, Mapping, Any

-import ydb._topic_writer
-
-from ydb._topic_reader import (
+from . import aio, Credentials
+from ._topic_reader.topic_reader import (
     Reader as TopicReader,
     ReaderAsyncIO as TopicReaderAsyncIO,
     Selector as TopicSelector,
+    Events as TopicReaderEvents,
+    RetryPolicy as TopicReaderRetryPolicy,
+    StubEvent as TopicReaderStubEvent,
 )
-
-from ydb._topic_writer.topic_writer import (
+
+from ._topic_writer.topic_writer import (  # noqa: F401
     Writer as TopicWriter,
     PublicWriterSettings as TopicWriterSettings,
     PublicMessage as TopicWriterMessage,
+    RetryPolicy as TopicWriterRetryPolicy,
 )
 from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO


 class TopicClientAsyncIO:
-    _driver: ydb.aio.Driver
-    _credentials: Union[ydb.Credentials, None]
+    _driver: aio.Driver
+    _credentials: Union[Credentials, None]

-    def __init__(self, driver: ydb.aio.Driver, settings: "TopicClientSettings" = None):
+    def __init__(self, driver: aio.Driver, settings: "TopicClientSettings" = None):
         self._driver = driver

     def topic_reader(
@@ -32,18 +36,18 @@ def topic_reader(
         commit_batch_count: Union[int, None] = 1000,
         buffer_size_bytes: int = 50 * 1024 * 1024,
         sync_commit: bool = False,  # reader.commit(...) will wait commit ack from server
-        on_commit: Callable[["OnCommitEvent"], None] = None,
+        on_commit: Callable[["TopicReaderEvents.OnCommit"], None] = None,
         on_get_partition_start_offset: Callable[
-            ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"],
-            "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse",
+            ["TopicReaderEvents.OnPartitionGetStartOffsetRequest"],
+            "TopicReaderEvents.OnPartitionGetStartOffsetResponse",
         ] = None,
-        on_init_partition: Callable[["StubEvent"], None] = None,
-        on_shutdown_partition: Callable[["StubEvent"], None] = None,
+        on_init_partition: Callable[["TopicReaderStubEvent"], None] = None,
+        on_shutdown_partition: Callable[["TopicReaderStubEvent"], None] = None,
         decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
         deserializer: Union[Callable[[bytes], Any], None] = None,
         one_attempt_connection_timeout: Union[float, None] = 1,
         connection_timeout: Union[float, None] = None,
-        retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None,
+        retry_policy: Union["TopicReaderRetryPolicy", None] = None,
     ) -> TopicReaderAsyncIO:
         raise NotImplementedError()

@@ -63,7 +67,7 @@ def topic_writer(
         auto_seqno: bool = True,
         auto_created_at: bool = True,
         get_last_seqno: bool = False,
-        retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None,
+        retry_policy: Union["TopicWriterRetryPolicy", None] = None,
     ) -> TopicWriterAsyncIO:
         args = locals()
         del args["self"]
@@ -83,10 +87,10 @@ def topic_reader(
         commit_batch_count: Union[int, None] = 1000,
         buffer_size_bytes: int = 50 * 1024 * 1024,
         sync_commit: bool = False,  # reader.commit(...) will wait commit ack from server
-        on_commit: Callable[["OnCommitEvent"], None] = None,
+        on_commit: Callable[["TopicReaderStubEvent"], None] = None,
         on_get_partition_start_offset: Callable[
-            ["ydb._topic_reader.Events.OnPartitionGetStartOffsetRequest"],
-            "ydb._topic_reader.Events.OnPartitionGetStartOffsetResponse",
+            ["TopicReaderEvents.OnPartitionGetStartOffsetRequest"],
+            "TopicReaderEvents.OnPartitionGetStartOffsetResponse",
         ] = None,
         on_init_partition: Callable[["StubEvent"], None] = None,
         on_shutdown_partition: Callable[["StubEvent"], None] = None,
         decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
         deserializer: Union[Callable[[bytes], Any], None] = None,
         one_attempt_connection_timeout: Union[float, None] = 1,
         connection_timeout: Union[float, None] = None,
-        retry_policy: Union["ydb._topic_reader.RetryPolicy", None] = None,
+        retry_policy: Union["TopicReaderRetryPolicy", None] = None,
     ) -> TopicReader:
         raise NotImplementedError()

@@ -113,7 +117,7 @@ def topic_writer(
         auto_seqno: bool = True,
         auto_created_at: bool = True,
         get_last_seqno: bool = False,
-        retry_policy: Union["ydb._topic_writer.RetryPolicy", None] = None,
+        retry_policy: Union["TopicWriterRetryPolicy", None] = None,
     ) -> TopicWriter:
         raise NotImplementedError()

From 96a749b5ba24afa1a883ad9c6182da49618c78c1 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Tue, 24 Jan 2023 15:01:27 +0300
Subject: [PATCH 008/147] apply black once more

---
 ydb/_topic_reader/topic_reader.py         |  3 ++-
 ydb/_topic_writer/topic_writer_asyncio.py | 14 ++++++++++++--
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py
index bc2f6cb5..ac961674 100644
--- a/ydb/_topic_reader/topic_reader.py
+++ b/ydb/_topic_reader/topic_reader.py
@@ -255,7 +255,8 @@ def __init__(
         buffer_size_bytes: int = 50 * 1024 * 1024,
         on_commit: Callable[["Events.OnCommit"], None] = None,
         on_get_partition_start_offset: Callable[
-            ["Events.OnPartitionGetStartOffsetRequest"], "Events.OnPartitionGetStartOffsetResponse"
+            ["Events.OnPartitionGetStartOffsetRequest"],
+            "Events.OnPartitionGetStartOffsetResponse",
         ] = None,
         on_partition_session_start: Callable[["StubEvent"], None] = None,
         on_partition_session_stop: Callable[["StubEvent"], None] = None,
diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
index b8463542..b4373b17 100644
--- a/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -5,8 +5,18 @@
 from typing import Deque, AsyncIterator, Union, List, Optional, Callable

 import ydb
-from .topic_writer import PublicWriterSettings, WriterSettings, Writer, PublicWriteResult, PublicMessage, \
-    PublicWriterInitInfo, InternalMessage, TopicWriterStopped, TopicWriterError, messages_to_proto_requests
+from .topic_writer import (
+    PublicWriterSettings,
+    WriterSettings,
+    Writer,
+    PublicWriteResult,
+    PublicMessage,
+    PublicWriterInitInfo,
+    InternalMessage,
+    TopicWriterStopped,
+    TopicWriterError,
+    messages_to_proto_requests,
+)
 from .. import (
     _apis,

From 2bee743a4111dae672ec81246747f82667b6eff0 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Tue, 24 Jan 2023 15:40:26 +0300
Subject: [PATCH 009/147] fix anext for python 3.8

---
 ydb/_topic_wrapper/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py
index 2f07fde8..1bcdfde8 100644
--- a/ydb/_topic_wrapper/common.py
+++ b/ydb/_topic_wrapper/common.py
@@ -118,7 +118,7 @@ def __init__(self, iterator: typing.AsyncIterator[typing.Any]):

     async def get(self) -> typing.Any:
         try:
-            return anext(self._iterator)
+            return await self._iterator.__anext__()
         except StopAsyncIteration:
             raise asyncio.QueueEmpty()

From 4bf0e15107e319ee05e44bdff626801cabb01e59 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Tue, 24 Jan 2023 15:52:10 +0300
Subject: [PATCH 010/147] Update CHANGELOG.md

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c6f0e7e..18986cbb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,5 @@
+* Add initial topic writer
+
 ## 3.0.1b1 ##
 * start 3.0 beta branch

From 24651e3836e7c6b0eba956be553842d6cb2eb07c Mon Sep 17 00:00:00 2001
From: robot
Date: Tue, 24 Jan 2023 12:53:29 +0000
Subject: [PATCH 011/147] Release: 3.0.1b2

---
 CHANGELOG.md | 1 +
 setup.py     | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 18986cbb..86e3a167 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,4 @@
+## 3.0.1b2 ##
 * Add initial topic writer

 ## 3.0.1b1 ##
diff --git a/setup.py b/setup.py
index 992b0ee9..726b351c 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setuptools.setup(
     name="ydb",
-    version="3.0.1b1",  # AUTOVERSION
+    version="3.0.1b2",  # AUTOVERSION
     description="YDB Python SDK",
     author="Yandex LLC",
     author_email="ydb@yandex-team.ru",

From 295fc6c00cae9f1e5865b1adfe42c0074af3f4e5 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Mon, 30 Jan 2023 19:03:42 +0300
Subject: [PATCH 012/147] move set cred env var from tox to fixture

---
 tests/conftest.py | 8 ++++++++
 tox.ini           | 1 -
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 18b18fa0..336ffdae 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,17 @@
 import os
+from unittest import mock
+
 import pytest

 import ydb
 import time


+@pytest.fixture(autouse=True, scope="session")
+def mock_settings_env_vars():
+    with mock.patch.dict(os.environ, {"YDB_ANONYMOUS_CREDENTIALS": "1"}):
+        yield
+
+
 @pytest.fixture(scope="module")
 def docker_compose_file(pytestconfig):
     return os.path.join(str(pytestconfig.rootdir), "docker-compose.yml")
diff --git a/tox.ini b/tox.ini
index 1246209c..b4d04f63 100644
--- a/tox.ini
+++ b/tox.ini
@@ -8,7 +8,6 @@ ignore_basepython_conflict = true
 usedevelop = True
 install_command = pip install {opts} {packages}
 setenv =
-    YDB_ANONYMOUS_CREDENTIALS = 1
     PYTHONPATH = {env:PYTHONPATH}{:}{toxinidir}
 deps =
     -r{toxinidir}/test-requirements.txt

From 9578c400914dc88ba352be3edf68e3c62d8aa092 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Tue, 24 Jan 2023 18:01:59 +0300
Subject: [PATCH 013/147] sync

---
 tests/conftest.py                       | 22 +++++++++++++++++++---
 tests/topics/test_topic_writer_async.py |  4 ++++
 2 files changed, 23 insertions(+), 3 deletions(-)
 create mode 100644 tests/topics/test_topic_writer_async.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 336ffdae..7a168717 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,6 +4,7 @@
 import pytest
 import ydb
 import time
+import subprocess


 @pytest.fixture(autouse=True, scope="session")
@@ -56,9 +57,9 @@ def secure_endpoint(pytestconfig, session_scoped_container_getter):
     assert os.path.exists(ca_path)
     os.environ["YDB_SSL_ROOT_CERTIFICATES_FILE"] = ca_path
     with ydb.Driver(
-        endpoint="grpcs://localhost:2135",
-        database="/local",
-        root_certificates=ydb.load_ydb_root_certificate(),
+            endpoint="grpcs://localhost:2135",
+            database="/local",
+            root_certificates=ydb.load_ydb_root_certificate(),
     ) as driver:
         wait_container_ready(driver)
         yield "localhost:2135"
@@ -96,3 +97,18 @@ async def driver(endpoint, database, event_loop):
     yield driver

     await driver.stop(timeout=10)
+
+
+@pytest.fixture()
+def topic_path() -> str:
+    subprocess.run(
+        """docker-compose exec ydb /ydb -e grpc://localhost:2136 -d /local topic drop /local/test-topic""",
+        shell=True,
+    )
+    res = subprocess.run(
+        """exec ydb /ydb -e grpc://localhost:2136 -d /local topic create /local/test-topic""",
+        shell=True,
+    )
+    assert res.returncode == 0
+
+    return "/local/test-topic"
diff --git a/tests/topics/test_topic_writer_async.py b/tests/topics/test_topic_writer_async.py
new file mode 100644
index 00000000..6fd7d329
--- /dev/null
+++ b/tests/topics/test_topic_writer_async.py
@@ -0,0 +1,4 @@
+
+
+def test_write_single_message(driver, topic_path):
+    print(topic_path)

From 28e87f58f1ee204086c3d207df1c07cca2167774 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Mon, 30 Jan 2023 17:44:54 +0300
Subject: [PATCH 014/147] sync

---
 tests/topics/test_topic_writer_async.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/topics/test_topic_writer_async.py b/tests/topics/test_topic_writer_async.py
index 6fd7d329..f823ffd4 100644
--- a/tests/topics/test_topic_writer_async.py
+++ b/tests/topics/test_topic_writer_async.py
@@ -1,4 +1,8 @@
+import pytest

-def test_write_single_message(driver, topic_path):
-    print(topic_path)
+
+@pytest.mark.asyncio
+class TesttopicWriter:
+    async def test_send_message(self, driver):
+        pass
+

From 994eef3b09d19be470de9cd433ab72c4c7dbdb0c Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Tue, 31 Jan 2023 14:14:53 +0300
Subject: [PATCH 015/147] use anonymous credentials by default

---
 tests/conftest.py                       | 17 +++++++------
 tests/topics/test_topic_writer.py       | 30 +++++++++++++++++++++++
 tests/topics/test_topic_writer_async.py |  8 ------
 ydb/_topic_writer/topic_writer.py       |  3 +++
 ydb/_topic_writer/topic_writer_asyncio.py | 12 ++++-----
 5 files changed, 49 insertions(+), 21 deletions(-)
 create mode 100644 tests/topics/test_topic_writer.py
 delete mode 100644 tests/topics/test_topic_writer_async.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 7a168717..2681037b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -57,9 +57,9 @@ def secure_endpoint(pytestconfig, session_scoped_container_getter):
     assert os.path.exists(ca_path)
     os.environ["YDB_SSL_ROOT_CERTIFICATES_FILE"] = ca_path
     with ydb.Driver(
-            endpoint="grpcs://localhost:2135",
-            database="/local",
-            root_certificates=ydb.load_ydb_root_certificate(),
+        endpoint="grpcs://localhost:2135",
+        database="/local",
+        root_certificates=ydb.load_ydb_root_certificate(),
     ) as driver:
         wait_container_ready(driver)
         yield "localhost:2135"
@@ -100,15 +100,18 @@ async def driver(endpoint, database, event_loop):


 @pytest.fixture()
-def topic_path() -> str:
+def topic_path(endpoint) -> str:
     subprocess.run(
-        """docker-compose exec ydb /ydb -e grpc://localhost:2136 -d /local topic drop /local/test-topic""",
+        """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic drop /local/test-topic"""
+        % endpoint,
         shell=True,
     )
     res = subprocess.run(
-        """exec ydb /ydb -e grpc://localhost:2136 -d /local topic create /local/test-topic""",
+        """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic create /local/test-topic"""
+        % endpoint,
         shell=True,
+        capture_output=True,
     )
-    assert res.returncode == 0
+    assert res.returncode == 0, res.stderr + res.stdout

     return "/local/test-topic"
diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py
new file mode 100644
index 00000000..3071c655
--- /dev/null
+++ b/tests/topics/test_topic_writer.py
@@ -0,0 +1,30 @@
+import pytest
+
+import ydb.aio
+
+
+@pytest.mark.asyncio
+class TestTopicWriterAsyncIO:
+    async def test_send_message(self, driver: ydb.aio.Driver, topic_path):
+        writer = driver.topic_client.topic_writer(
+            topic_path, producer_and_message_group_id="test"
+        )
+        writer.write(ydb.TopicWriterMessage(data="123".encode()))
+
+    async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path):
+        async with driver.topic_client.topic_writer(
+            topic_path,
+            producer_and_message_group_id="test",
+            auto_seqno=False,
+        ) as writer:
+            await writer.write_with_ack(
+                ydb.TopicWriterMessage(data="123".encode(), seqno=5)
+            )
+
+        async with driver.topic_client.topic_writer(
+            topic_path,
+            producer_and_message_group_id="test",
+            get_last_seqno=True,
+        ) as writer2:
+            init_info = await writer2.wait_init()
+            assert init_info.last_seqno == 5
diff --git a/tests/topics/test_topic_writer_async.py b/tests/topics/test_topic_writer_async.py
deleted file mode 100644
index f823ffd4..00000000
--- a/tests/topics/test_topic_writer_async.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import pytest
-
-
-@pytest.mark.asyncio
-class TesttopicWriter:
-    async def test_send_message(self, driver):
-        pass
-
diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py
index da614bf2..ecc20e10 100644
--- a/ydb/_topic_writer/topic_writer.py
+++ b/ydb/_topic_writer/topic_writer.py
@@ -165,6 +165,9 @@ class Skipped:
         pass


+PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped]
+
+
 class WriterSettings(PublicWriterSettings):
     def __init__(self, settings: PublicWriterSettings):
         self.__dict__ = settings.__dict__.copy()
diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
index b4373b17..61ab1e8c 100644
--- a/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -1,6 +1,5 @@
 import asyncio
 import datetime
-import threading
 from collections import deque
 from typing import Deque, AsyncIterator, Union, List, Optional, Callable

@@ -9,13 +8,13 @@
     PublicWriterSettings,
     WriterSettings,
     Writer,
-    PublicWriteResult,
     PublicMessage,
     PublicWriterInitInfo,
     InternalMessage,
     TopicWriterStopped,
     TopicWriterError,
     messages_to_proto_requests,
+    PublicWriteResultTypes,
 )
 from .. import (
     _apis,
@@ -35,7 +34,7 @@ class WriterAsyncIO:
     _loop: asyncio.AbstractEventLoop
     _reconnector: "WriterAsyncIOReconnector"
-    _lock: threading.Lock
+    _lock: asyncio.Lock
     _closed: bool

     @property
@@ -43,13 +42,14 @@ def last_seqno(self) -> int:
         raise NotImplementedError()

     def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
+        self._lock = asyncio.Lock()
         self._loop = asyncio.get_running_loop()
         self._closed = False
         self._reconnector = WriterAsyncIOReconnector(
             driver=driver, settings=WriterSettings(settings)
         )

-    async def __aenter__(self):
+    async def __aenter__(self) -> "WriterAsyncIO":
         return self

     async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -62,7 +62,7 @@ def __del__(self):
             self._loop.call_soon(self.close)

     async def close(self):
-        with self._lock:
+        async with self._lock:
             if self._closed:
                 return
             self._closed = True
@@ -73,7 +73,7 @@ async def write_with_ack(
         self,
         messages: Union[Writer.MessageType, List[Writer.MessageType]],
         *args: Optional[Writer.MessageType],
-    ) -> Union[PublicWriteResult, List[PublicWriteResult]]:
+    ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]:
         """
         IT IS A SLOW WAY. IT IS A BAD CHOICE IN MOST CASES.
         It is recommended to use write with optional flush, or write_with_ack_futures and receive acks by waiting on the futures.
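For orientation between patches: the writer pieces above compose into a small end-to-end flow. The sketch below is illustrative only, not part of the patch series. It assumes a reachable local YDB instance and an already created topic (the endpoint, database, topic path and producer id are placeholder values taken from the test fixtures above), and the driver wait/stop calls are assumed from the SDK; the topic calls themselves (`driver.topic_client.topic_writer`, `write_with_ack`, `TopicWriterMessage`) are the APIs introduced in the patches above.

import asyncio

import ydb
import ydb.aio


async def produce_one_message():
    # Assumption: a local YDB instance is reachable and /local/test-topic exists.
    driver = ydb.aio.Driver(
        connection_string="grpc://localhost:2136?database=/local",
        credentials=ydb.credentials.AnonymousCredentials(),
    )
    await driver.wait(timeout=15)  # assumed helper: wait until discovery succeeds
    try:
        # The async context manager closes the writer on exit,
        # mirroring test_wait_last_seqno above.
        async with driver.topic_client.topic_writer(
            "/local/test-topic", producer_and_message_group_id="example-producer"
        ) as writer:
            # write_with_ack returns only after the server acknowledges the message
            # (the slow, explicit path - see the docstring in the patch above).
            await writer.write_with_ack(ydb.TopicWriterMessage(data="hello".encode()))
    finally:
        await driver.stop(timeout=10)


asyncio.run(produce_one_message())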
From 7769f5dc5aeacb9d9833767644ff4908c9555e58 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Mon, 6 Feb 2023 19:24:17 +0300
Subject: [PATCH 016/147] fix check of retriable errors for idempotent
 operations

---
 CHANGELOG.md      | 2 ++
 ydb/_errors.py    | 7 ++++---
 ydb/table_test.py | 5 +++++
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 86e3a167..c2b1b247 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,5 @@
+* Fix check of retriable errors for idempotent operations (error existed since 2.12.1)
+
 ## 3.0.1b2 ##
 * Add initial topic writer

diff --git a/ydb/_errors.py b/ydb/_errors.py
index e8628586..ae3057b6 100644
--- a/ydb/_errors.py
+++ b/ydb/_errors.py
@@ -49,9 +49,10 @@ def check_retriable_error(err, retry_settings, attempt):

     if retry_settings.idempotent:
         for t in _errors_retriable_slow_backoff_idempotent_types:
-            return ErrorRetryInfo(
-                True, retry_settings.slow_backoff.calc_timeout(attempt)
-            )
+            if isinstance(err, t):
+                return ErrorRetryInfo(
+                    True, retry_settings.slow_backoff.calc_timeout(attempt)
+                )

     return ErrorRetryInfo(False, None)

diff --git a/ydb/table_test.py b/ydb/table_test.py
index 361719be..2cb2a6a0 100644
--- a/ydb/table_test.py
+++ b/ydb/table_test.py
@@ -132,4 +132,9 @@ def check_retriable_error(err_type, backoff):
     check_retriable_error(issues.Unavailable, retry_once_settings.fast_backoff)

     check_unretriable_error(issues.Error, True)
+    with mock.patch.object(retry_once_settings, "idempotent", True):
+        check_unretriable_error(issues.Error, True)
+
     check_unretriable_error(TestException, False)
+    with mock.patch.object(retry_once_settings, "idempotent", True):
+        check_unretriable_error(TestException, False)

From 6b63bcf7f0ab057128b7786444f047fc9b00c46c Mon Sep 17 00:00:00 2001
From: robot
Date: Mon, 6 Feb 2023 16:40:05 +0000
Subject: [PATCH 017/147] Release: 3.0.1b3

---
 CHANGELOG.md | 1 +
 setup.py     | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c2b1b247..ed909b29 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,4 @@
+## 3.0.1b3 ##
 * Fix check of retriable errors for idempotent operations (error existed since 2.12.1)

 ## 3.0.1b2 ##
diff --git a/setup.py b/setup.py
index 726b351c..6716c84a 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setuptools.setup(
     name="ydb",
-    version="3.0.1b2",  # AUTOVERSION
+    version="3.0.1b3",  # AUTOVERSION
     description="YDB Python SDK",
     author="Yandex LLC",
     author_email="ydb@yandex-team.ru",

From 848283309986f922461abad174b3c64e5101681c Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Tue, 31 Jan 2023 15:29:18 +0300
Subject: [PATCH 018/147] simple init stream reader

---
 ydb/_topic_reader/topic_reader_asyncio.py |  33 +++++++
 .../topic_reader_asyncio_test.py          |  46 +++++++++
 ydb/_topic_wrapper/common.py              |   3 +
 ydb/_topic_wrapper/reader.py              |  93 +++++++++++++++++--
 ydb/_topic_wrapper/test_helpers.py        |  22 +++++
 ydb/_topic_writer/topic_writer_asyncio.py |  15 ++-
 .../topic_writer_asyncio_test.py          |  22 +----
 7 files changed, 201 insertions(+), 33 deletions(-)
 create mode 100644 ydb/_topic_reader/topic_reader_asyncio.py
 create mode 100644 ydb/_topic_reader/topic_reader_asyncio_test.py
 create mode 100644 ydb/_topic_wrapper/test_helpers.py

diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py
new file mode 100644
index 00000000..e0176f4f
--- /dev/null
+++ b/ydb/_topic_reader/topic_reader_asyncio.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import asyncio
+from typing import Optional
+
+from ydb._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO
+from ydb._topic_wrapper.reader import StreamReadMessage
+
+
+class PublicAsyncIOReader:
+    pass
+
+
+class ReaderReconnector:
+    pass
+
+
+class ReaderStream:
+    _token_getter: Optional[TokenGetterFuncType]
+    _session_id: str
+    _init_completed: asyncio.Future[None]
+
+    def __init__(self, token_getter: Optional[TokenGetterFuncType]):
+        self._token_getter = token_getter
+        self._session_id = "not initialized"
+
+    async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest):
+        stream.write(StreamReadMessage.FromClient(client_message=init_message))
+        init_response = await stream.receive()  # type: StreamReadMessage.FromServer
+        if isinstance(init_response.server_message, StreamReadMessage.InitResponse):
+            self._session_id = init_response.server_message.session_id
+
+
diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py
new file mode 100644
index 00000000..7d7795d1
--- /dev/null
+++ b/ydb/_topic_reader/topic_reader_asyncio_test.py
@@ -0,0 +1,46 @@
+import asyncio
+from unittest import mock
+
+import pytest
+
+from ydb import aio
+from ydb._topic_reader.topic_reader_asyncio import ReaderStream
+from ydb._topic_wrapper.reader import StreamReadMessage
+from ydb._topic_wrapper.test_helpers import StreamMock
+
+
+def default_driver() -> aio.Driver:
+    driver = mock.Mock(spec=aio.Driver)
+    driver._credentials = mock.Mock()
+    return driver
+
+
+@pytest.mark.asyncio
+class TestReaderStream:
+    @pytest.fixture()
+    def stream(self):
+        return StreamMock()
+
+    async def test_init_reader(self, stream):
+        reader = ReaderStream(None)
+        init_message = StreamReadMessage.InitRequest(
+            consumer="test-consumer",
+            topics_read_settings=[StreamReadMessage.InitRequest.TopicReadSettings(
+                path="/local/test-topic",
+                partition_ids=[],
+                max_lag_seconds=None,
+                read_from=None,
+            )]
+        )
+        start_task = asyncio.create_task(reader.start(stream, init_message))
+
+        sent_message = await stream.from_client.get()
+        expected_sent_init_message = StreamReadMessage.FromClient(client_message=init_message)
+        assert sent_message == expected_sent_init_message
+
+        stream.from_server.put_nowait(StreamReadMessage.FromServer(
+            server_message=StreamReadMessage.InitResponse(session_id="test"))
+        )
+
+        await start_task
+        assert reader._session_id == "test"
diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py
index 1bcdfde8..15173df8 100644
--- a/ydb/_topic_wrapper/common.py
+++ b/ydb/_topic_wrapper/common.py
@@ -225,3 +225,6 @@ class UpdateTokenResponse(IFromProto):
     @staticmethod
     def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any:
         return UpdateTokenResponse()
+
+
+TokenGetterFuncType = typing.Optional[typing.Callable[[], str]]
diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py
index 0a84d343..435d3a28 100644
--- a/ydb/_topic_wrapper/reader.py
+++ b/ydb/_topic_wrapper/reader.py
@@ -1,8 +1,18 @@
 import datetime
+import typing
 from dataclasses import dataclass, field
 from typing import List, Union, Dict

-from ydb._topic_wrapper.common import OffsetsRange
+from google.protobuf.message import Message
+
+from ydb._topic_wrapper.common import OffsetsRange, IToProto, UpdateTokenRequest, UpdateTokenResponse, IFromProto
+from google.protobuf.duration_pb2 import Duration as ProtoDuration
+
+# Workaround for good autocomplete in IDE and universal import at runtime
+if False:
+    from ydb._grpc.v4.protos import ydb_topic_pb2
+else:
+    from ydb._grpc.common.protos import ydb_topic_pb2


 class StreamReadMessage:
@@ -13,21 +23,41 @@ class PartitionSession:
         partition_id: int

     @dataclass
-    class InitRequest:
-        topics_read_settings: List["TopicReadSettings"]
+    class InitRequest(IToProto):
+        topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"]
         consumer: str

+        def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest:
+            res = ydb_topic_pb2.StreamReadMessage.InitRequest()
+            res.consumer = self.consumer
+            for settings in self.topics_read_settings:
+                res.topics_read_settings.append(settings.to_proto())
+            return res
+
         @dataclass
-        class TopicReadSettings:
+        class TopicReadSettings(IToProto):
             path: str
             partition_ids: List[int] = field(default_factory=list)
-            max_lag_seconds: Union[float, None] = None
+            max_lag_seconds: Union[datetime.timedelta, None] = None
             read_from: Union[int, float, datetime.datetime, None] = None

+            def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings:
+                res = ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings()
+                res.path = self.path
+                res.partition_ids.extend(self.partition_ids)
+                if self.max_lag_seconds is not None:
+                    res.max_lag = ProtoDuration()
+                    res.max_lag.FromTimedelta(self.max_lag_seconds)
+                return res
+
     @dataclass
-    class InitResponse:
+    class InitResponse(IFromProto):
         session_id: str

+        @staticmethod
+        def from_proto(msg: ydb_topic_pb2.StreamReadMessage.InitResponse) -> "StreamReadMessage.InitResponse":
+            return StreamReadMessage.InitResponse(session_id=msg.session_id)
+
     @dataclass
     class ReadRequest:
         bytes_size: int
@@ -109,3 +139,54 @@ class StopPartitionSessionRequest:
     @dataclass
     class StopPartitionSessionResponse:
         partition_session_id: int
+
+    @dataclass
+    class FromClient(IToProto):
+        client_message: "ReaderMessagesFromClientToServer"
+
+        def __init__(self, client_message: "ReaderMessagesFromClientToServer"):
+            self.client_message = client_message
+
+        def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
+            res = ydb_topic_pb2.StreamReadMessage.FromClient()
+            if isinstance(self.client_message, StreamReadMessage.InitRequest):
+                res.init_request.CopyFrom(self.client_message.to_proto())
+            else:
+                raise NotImplementedError()
+            return res
+
+    @dataclass
+    class FromServer(IFromProto):
+        server_message: "ReaderMessagesFromServerToClient"
+
+        @staticmethod
+        def from_proto(msg: ydb_topic_pb2.StreamReadMessage.FromServer) -> "StreamReadMessage.FromServer":
+            mess_type = msg.WhichOneof("server_message")
+            if mess_type == "init_response":
+                return StreamReadMessage.FromServer(
+                    server_message=StreamReadMessage.InitResponse.from_proto(msg.init_response),
+                )
+
+            # todo replace exception to log
+            raise NotImplementedError()
+
+
+ReaderMessagesFromClientToServer = Union[
+    StreamReadMessage.InitRequest,
+    StreamReadMessage.ReadRequest,
+    StreamReadMessage.CommitOffsetRequest,
+    StreamReadMessage.PartitionSessionStatusRequest,
+    UpdateTokenRequest,
+    StreamReadMessage.StartPartitionSessionResponse,
+    StreamReadMessage.StopPartitionSessionResponse,
+]
+
+ReaderMessagesFromServerToClient = Union[
+    StreamReadMessage.InitResponse,
+    StreamReadMessage.ReadResponse,
+    StreamReadMessage.CommitOffsetResponse,
+    StreamReadMessage.PartitionSessionStatusResponse,
+    UpdateTokenResponse,
+    StreamReadMessage.StartPartitionSessionRequest,
+    StreamReadMessage.StopPartitionSessionRequest,
+]
diff --git a/ydb/_topic_wrapper/test_helpers.py b/ydb/_topic_wrapper/test_helpers.py
new file mode 100644
index 00000000..9a562188
--- /dev/null
+++ b/ydb/_topic_wrapper/test_helpers.py
@@ -0,0 +1,22 @@
+import asyncio
+import typing
+
+from .common import IGrpcWrapperAsyncIO, IToProto
+
+
+class StreamMock(IGrpcWrapperAsyncIO):
+    from_server: asyncio.Queue
+    from_client: asyncio.Queue
+
+    def __init__(self):
+        self.from_server = asyncio.Queue()
+        self.from_client = asyncio.Queue()
+
+    async def receive(self) -> typing.Any:
+        item = await self.from_server.get()
+        if isinstance(item, Exception):
+            raise item
+        return item
+
+    def write(self, wrap_message: IToProto):
+        self.from_client.put_nowait(wrap_message)
diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
index 61ab1e8c..01be0a8c 100644
--- a/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -26,7 +26,7 @@
     UpdateTokenResponse,
     GrpcWrapperAsyncIO,
     IGrpcWrapperAsyncIO,
-    SupportedDriverType,
+    SupportedDriverType, TokenGetterFuncType,
 )
 from .._topic_wrapper.writer import StreamWriteMessage, WriterMessagesFromServerToClient

@@ -156,7 +156,7 @@ class WriterAsyncIOReconnector:
     _credentials: Union[ydb.Credentials, None]
     _driver: ydb.aio.Driver
     _update_token_interval: int
-    _token_get_function: "TokenGetter"
+    _token_get_function: TokenGetterFuncType
     _init_message: StreamWriteMessage.InitRequest
     _new_messages: asyncio.Queue
     _init_info: asyncio.Future
@@ -385,13 +385,13 @@ class WriterAsyncIOStream:
     last_seqno: int

     _stream: IGrpcWrapperAsyncIO
-    _token_getter: "TokenGetter"
+    _token_getter: TokenGetterFuncType
     _requests: asyncio.Queue
     _responses: AsyncIterator

     def __init__(
         self,
-        token_getter: "TokenGetter",
+        token_getter: TokenGetterFuncType,
     ):
         self._token_getter = token_getter

@@ -399,7 +399,7 @@ def __init__(
     async def create(
         driver: SupportedDriverType,
         init_request: StreamWriteMessage.InitRequest,
-        token_getter: "TokenGetter",
+        token_getter: TokenGetterFuncType,
     ) -> "WriterAsyncIOStream":
         stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)

@@ -415,7 +415,7 @@ async def create(
     async def _create_stream_from_async(
         driver: ydb.aio.Driver,
         init_request: StreamWriteMessage.InitRequest,
-        token_getter: "TokenGetter",
+        token_getter: TokenGetterFuncType,
     ) -> "WriterAsyncIOStream":
         return GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)

@@ -423,7 +423,7 @@ async def _create_stream_from_async(
     async def _create_from_sync(
         driver: ydb.Driver,
         init_request: StreamWriteMessage.InitRequest,
-        token_getter: "TokenGetter",
+        token_getter: TokenGetterFuncType,
     ) -> "WriterAsyncIOStream":
         stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
         await stream.start(
@@ -472,4 +472,3 @@ def write(self, messages: List[InternalMessage]):
             self._stream.write(request)


-TokenGetter = Optional[Callable[[], str]]
diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py
index 5d157f15..38a0a2dd 100644
--- a/ydb/_topic_writer/topic_writer_asyncio_test.py
+++ b/ydb/_topic_writer/topic_writer_asyncio_test.py
@@ -24,6 +24,7 @@
     PublicWriteResult,
     TopicWriterError,
 )
+from .._topic_wrapper.test_helpers import StreamMock
 from .._topic_wrapper.writer import StreamWriteMessage

 from .topic_writer_asyncio import (
@@ -42,31 +43,14 @@ def default_driver() -> aio.Driver:

 @pytest.mark.asyncio
 class TestWriterAsyncIOStream:
-    class StreamMock(IGrpcWrapperAsyncIO):
-        from_server: asyncio.Queue
-        from_client: asyncio.Queue
-
-        def __init__(self):
-            self.from_server = asyncio.Queue()
-            self.from_client = asyncio.Queue()
-
-        async def receive(self) -> typing.Any:
-            item = await self.from_server.get()
-            if isinstance(item, Exception):
-                raise item
-            return item
-
-        def write(self, wrap_message: IToProto):
-            self.from_client.put_nowait(wrap_message)
-
     @dataclasses.dataclass
     class WriterWithMockedStream:
         writer: WriterAsyncIOStream
-        stream: "TestWriterAsyncIOStream.StreamMock"
+        stream: StreamMock

     @pytest.fixture
     def stream(self):
-        return TestWriterAsyncIOStream.StreamMock()
+        return StreamMock()

     @pytest.fixture
     async def writer_and_stream(self, stream) -> WriterWithMockedStream:

From 353a85717ad5c19bccacc3cbc1c674d5048fe292 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Wed, 1 Feb 2023 18:06:00 +0300
Subject: [PATCH 019/147] start/stop partitions init read request

---
 ydb/_topic_reader/topic_reader.py         |  50 +++--
 ydb/_topic_reader/topic_reader_asyncio.py | 152 +++++++++++++-
 .../topic_reader_asyncio_test.py          | 189 +++++++++++++++++-
 ydb/_topic_wrapper/reader.py              |   3 -
 ydb/_topic_wrapper/test_helpers.py        |  19 ++
 5 files changed, 365 insertions(+), 48 deletions(-)

diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py
index ac961674..d335e40d 100644
--- a/ydb/_topic_reader/topic_reader.py
+++ b/ydb/_topic_reader/topic_reader.py
@@ -3,6 +3,7 @@
 import enum
 import io
 import datetime
+from dataclasses import dataclass
 from typing import (
     Union,
     Optional,
@@ -15,6 +16,8 @@
     Any,
 )

+from ydb._topic_wrapper.common import OffsetsRange, TokenGetterFuncType
+

 class Selector:
     path: str
@@ -247,27 +250,25 @@ def close(self):
         raise NotImplementedError()


-class ReaderSettings:
-    def __init__(
-        self,
-        *,
-        consumer: str,
-        buffer_size_bytes: int = 50 * 1024 * 1024,
-        on_commit: Callable[["Events.OnCommit"], None] = None,
-        on_get_partition_start_offset: Callable[
-            ["Events.OnPartitionGetStartOffsetRequest"],
-            "Events.OnPartitionGetStartOffsetResponse",
-        ] = None,
-        on_partition_session_start: Callable[["StubEvent"], None] = None,
-        on_partition_session_stop: Callable[["StubEvent"], None] = None,
-        on_partition_session_close: Callable[["StubEvent"], None] = None,  # todo?
-        decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
-        deserializer: Union[Callable[[bytes], Any], None] = None,
-        one_attempt_connection_timeout: Union[float, None] = 1,
-        connection_timeout: Union[float, None] = None,
-        retry_policy: Union["RetryPolicy", None] = None,
-    ):
-        raise NotImplementedError()
+@dataclass
+class PublicReaderSettings:
+    consumer: str
+    topic: str
+    buffer_size_bytes: int = 50 * 1024 * 1024
+    _token_getter: Optional[TokenGetterFuncType] = None
+    # on_commit: Callable[["Events.OnCommit"], None] = None
+    # on_get_partition_start_offset: Callable[
+    #     ["Events.OnPartitionGetStartOffsetRequest"],
+    #     "Events.OnPartitionGetStartOffsetResponse",
+    # ] = None
+    # on_partition_session_start: Callable[["StubEvent"], None] = None
+    # on_partition_session_stop: Callable[["StubEvent"], None] = None
+    # on_partition_session_close: Callable[["StubEvent"], None] = None  # todo?
+    # decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None
+    # deserializer: Union[Callable[[bytes], Any], None] = None
+    # one_attempt_connection_timeout: Union[float, None] = 1
+    # connection_timeout: Union[float, None] = None
+    # retry_policy: Union["RetryPolicy", None] = None


 class ICommittable(abc.ABC):
@@ -382,16 +383,11 @@ class State(enum.Enum):
 class SessionStat:
     path: str
     partition_id: str
-    partition_offsets: "OffsetRange"
+    partition_offsets: OffsetsRange
     committed_offset: int
     write_time_high_watermark: datetime.datetime
    write_time_high_watermark_timestamp_nano: int


-class OffsetRange:
-    start: int
-    end: int
-
-
 class StubEvent:
     pass
diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py
index e0176f4f..74f8f02e 100644
--- a/ydb/_topic_reader/topic_reader_asyncio.py
+++ b/ydb/_topic_reader/topic_reader_asyncio.py
@@ -1,10 +1,18 @@
 from __future__ import annotations

 import asyncio
-from typing import Optional
+import enum
+from dataclasses import dataclass
+from typing import Optional, Set, Dict

-from ydb._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO
-from ydb._topic_wrapper.reader import StreamReadMessage
+import ydb
+from .topic_reader import PublicReaderSettings
+from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO
+from .._topic_wrapper.reader import StreamReadMessage
+
+
+class TopicReaderError(ydb.Error):
+    pass


 class PublicAsyncIOReader:
@@ -16,18 +24,144 @@ class ReaderReconnector:
 class ReaderStream:
     _token_getter: Optional[TokenGetterFuncType]
     _session_id: str
     _init_completed: asyncio.Future[None]
+    _stream: Optional[IGrpcWrapperAsyncIO]
+
+    _lock: asyncio.Lock
+    _started: bool
+    _closed: bool
+    _first_error: Optional[ydb.Error]
+    _background_tasks: Set[asyncio.Task]
+    _partition_sessions: Dict[int, PartitionSession]
+    _buffer_size_bytes: int  # use for init request, then for debug purposes only

-    def __init__(self, token_getter: Optional[TokenGetterFuncType]):
-        self._token_getter = token_getter
+    def __init__(self, settings: PublicReaderSettings):
+        self._token_getter = settings._token_getter
         self._session_id = "not initialized"
+        self._stream = None
+
+        self._lock = asyncio.Lock()
+        self._started = False
+        self._closed = False
+        self._first_error = None
+        self._background_tasks = set()
+        self._partition_sessions = dict()
+        self._buffer_size_bytes = settings.buffer_size_bytes

     async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest):
-        stream.write(StreamReadMessage.FromClient(client_message=init_message))
-        init_response = await stream.receive()  # type: StreamReadMessage.FromServer
-        if isinstance(init_response.server_message, StreamReadMessage.InitResponse):
-            self._session_id = init_response.server_message.session_id
+        async with self._lock:
+            if self._started:
+                raise TopicReaderError("Double start ReaderStream")
+            self._started = True
+            self._stream = stream
+
+            stream.write(StreamReadMessage.FromClient(client_message=init_message))
+            init_response = await stream.receive()  # type: StreamReadMessage.FromServer
+            if isinstance(init_response.server_message, StreamReadMessage.InitResponse):
+                self._session_id = init_response.server_message.session_id
+            else:
+                raise TopicReaderError("Unexpected message after InitRequest: %s", init_response)
+
+            read_messages_task = asyncio.create_task(self._read_messages(stream))
+            self._background_tasks.add(read_messages_task)
+
+    async def _read_messages(self, stream: IGrpcWrapperAsyncIO):
+        try:
+            self._stream.write(StreamReadMessage.FromClient(
+                client_message=StreamReadMessage.ReadRequest(
+                    bytes_size=self._buffer_size_bytes,
+                ),
+            ))
+            while True:
+                message = await stream.receive()  # type: StreamReadMessage.FromServer
+                if isinstance(message.server_message, StreamReadMessage.StartPartitionSessionRequest):
+                    await self._on_start_partition_session_start(message.server_message)
+                elif isinstance(message.server_message, StreamReadMessage.StopPartitionSessionRequest):
+                    await self._on_partition_session_stop(message.server_message)
+                else:
+                    raise NotImplementedError(
+                        "Unexpected type of StreamReadMessage.FromServer message: %s" % message.server_message
+                    )
+        except Exception as e:
+            await self._set_first_error(e)
+            raise e
+
+    async def _on_start_partition_session_start(self, message: StreamReadMessage.StartPartitionSessionRequest):
+        async with self._lock:
+            try:
+                if message.partition_session.partition_session_id in self._partition_sessions:
+                    raise TopicReaderError(
+                        "Double start partition session: %s" % message.partition_session.partition_session_id
+                    )
+
+                self._partition_sessions[message.partition_session.partition_session_id] = PartitionSession(
+                    id=message.partition_session.partition_session_id,
+                    state=PartitionSession.State.Active,
+                    topic_path=message.partition_session.path,
+                    partition_id=message.partition_session.partition_id,
+                )
+                self._stream.write(StreamReadMessage.FromClient(
+                    client_message=StreamReadMessage.StartPartitionSessionResponse(
+                        partition_session_id=message.partition_session.partition_session_id,
+                        read_offset=0,
+                        commit_offset=0,
+                    )),
+                )
+            except ydb.Error as err:
+                self._set_first_error_locked(err)
+
+    async def _on_partition_session_stop(self, message: StreamReadMessage.StopPartitionSessionRequest):
+        async with self._lock:
+            partition = self._partition_sessions.get(message.partition_session_id)
+            if partition is None:
+                # may if receive stop partition with graceful=false after response on stop partition
+                # with graceful=true and remove partition from internal dictionary
+                return
+
+            del self._partition_sessions[message.partition_session_id]
+            partition.stop()
+
+            if message.graceful:
+                self._stream.write(StreamReadMessage.FromClient(
+                    client_message=StreamReadMessage.StopPartitionSessionResponse(
+                        partition_session_id=message.partition_session_id,
+                    ))
+                )
+
+    async def _set_first_error(self, err):
+        async with self._lock:
+            self._set_first_error_locked(err)
+
+    def _set_first_error_locked(self, err):
+        if self._first_error is None:
+            self._first_error = err
+
+    async def close(self):
+        async with self._lock:
+            if self._closed:
+                raise TopicReaderError(message="Double closed ReaderStream")
+            self._closed = True
+            self._set_first_error_locked(TopicReaderError("Reader closed"))
+
+            for task in self._background_tasks:
+                task.cancel()
+
+            await asyncio.wait(self._background_tasks)
+
+
+@dataclass
+class PartitionSession:
+    id: int
+    state: "PartitionSession.State"
+    topic_path: str
+    partition_id: int
+
+    def stop(self):
+        self.state = PartitionSession.State.Stopped
+
+    class State(enum.Enum):
+        Active = 1
+        GracefulShutdown = 2
+        Stopped = 3
diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py
index 7d7795d1..85326352 100644
--- a/ydb/_topic_reader/topic_reader_asyncio_test.py
+++ b/ydb/_topic_reader/topic_reader_asyncio_test.py
@@ -4,25 +4,94 @@
 import pytest

 from ydb import aio
-from ydb._topic_reader.topic_reader_asyncio import ReaderStream
+from ydb._topic_reader.topic_reader import PublicReaderSettings
+from ydb._topic_reader.topic_reader_asyncio import ReaderStream, PartitionSession
+from ydb._topic_wrapper.common import OffsetsRange
 from ydb._topic_wrapper.reader import StreamReadMessage
-from ydb._topic_wrapper.test_helpers import StreamMock
+from ydb._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast


-def default_driver() -> aio.Driver:
-    driver = mock.Mock(spec=aio.Driver)
-    driver._credentials = mock.Mock()
-    return driver
+@pytest.fixture()
+def default_reader_settings():
+    return PublicReaderSettings(
+        consumer="test-consumer",
+        topic="test-topic",
+    )
+
+
+class StreamMockForReader(StreamMock):
+    async def receive(self) -> StreamReadMessage.FromServer:
+        return await super(self).receive()
+
+    def write(self, message: StreamReadMessage.FromClient):
+        return super().write(message)


 @pytest.mark.asyncio
 class TestReaderStream:
+
     @pytest.fixture()
     def stream(self):
         return StreamMock()

+    @pytest.fixture()
+    def partition_session(self, default_reader_settings):
+        return PartitionSession(
+            id=2,
+            topic_path=default_reader_settings.topic,
+            partition_id=4,
+            state=PartitionSession.State.Active,
+        )
+
+    @pytest.fixture()
+    async def stream_reader(self, stream, partition_session, default_reader_settings) -> ReaderStream:
+        reader = ReaderStream(default_reader_settings)
+        init_message = object()
+
+        # noinspection PyTypeChecker
+        start = asyncio.create_task(reader.start(stream, init_message))
+
+        stream.from_server.put_nowait(StreamReadMessage.FromServer(
+            StreamReadMessage.InitResponse(session_id="test-session")
+        ))
+
+        init_request = await wait_for_fast(stream.from_client.get())
+        assert init_request.client_message == init_message
+
+        read_request = await wait_for_fast(stream.from_client.get())
+        assert isinstance(read_request.client_message, StreamReadMessage.ReadRequest)
+
+        stream.from_server.put_nowait(
+            StreamReadMessage.FromServer(server_message=StreamReadMessage.StartPartitionSessionRequest(
+                partition_session=StreamReadMessage.PartitionSession(
+                    partition_session_id=partition_session.id,
+                    path=partition_session.topic_path,
+                    partition_id=partition_session.partition_id,
+                ),
+                committed_offset=0,
+                partition_offsets=OffsetsRange(
+                    start=0,
+                    end=0,
+                )
+            ))
+        )
+        await start
+
+        start_partition_resp = await wait_for_fast(stream.from_client.get())
+        assert isinstance(start_partition_resp.client_message, StreamReadMessage.StartPartitionSessionResponse)
+
+        await asyncio.sleep(0)
+        with pytest.raises(asyncio.QueueEmpty):
+            stream.from_client.get_nowait()
+
+        yield reader
+
+        assert reader._first_error is None
+
+        await reader.close()
+
-    async def test_init_reader(self, stream):
-        reader = ReaderStream(None)
+    async def test_init_reader(self, stream, default_reader_settings):
+        reader = ReaderStream(default_reader_settings)
         init_message = StreamReadMessage.InitRequest(
             consumer="test-consumer",
             topics_read_settings=[StreamReadMessage.InitRequest.TopicReadSettings(
                 path="/local/test-topic",
                 partition_ids=[],
                 max_lag_seconds=None,
                 read_from=None,
             )]
         )
         start_task = asyncio.create_task(reader.start(stream, init_message))

-        sent_message = await stream.from_client.get()
+        sent_message = await wait_for_fast(stream.from_client.get())
         expected_sent_init_message = StreamReadMessage.FromClient(client_message=init_message)
         assert sent_message == expected_sent_init_message

         stream.from_server.put_nowait(StreamReadMessage.FromServer(
             server_message=StreamReadMessage.InitResponse(session_id="test"))
         )

         await start_task
+
+        read_request = await wait_for_fast(stream.from_client.get())
+        assert read_request.client_message == StreamReadMessage.ReadRequest(
+            bytes_size=default_reader_settings.buffer_size_bytes,
+        )
+
         assert reader._session_id == "test"
+        await reader.close()
+
+    async def test_start_partition(self, stream_reader: ReaderStream, stream, default_reader_settings, partition_session):
+        def session_count():
+            return len(stream_reader._partition_sessions)
+
+        initial_session_count = session_count()
+
+        test_partition_id = partition_session.partition_id + 1
+        test_partition_session_id = partition_session.id + 1
+        test_topic_path = default_reader_settings.topic + "-asd"
+
+        stream.from_server.put_nowait(StreamReadMessage.FromServer(
+            server_message=StreamReadMessage.StartPartitionSessionRequest(
+                partition_session=StreamReadMessage.PartitionSession(
+                    partition_session_id=test_partition_session_id,
+                    path=test_topic_path,
+                    partition_id=test_partition_id,
+                ),
+                committed_offset=0,
+                partition_offsets=OffsetsRange(
+                    start=0,
+                    end=0,
+                ),
+            )),
+        )
+        response = await wait_for_fast(stream.from_client.get())
+        assert response == StreamReadMessage.FromClient(client_message=StreamReadMessage.StartPartitionSessionResponse(
+            partition_session_id=test_partition_session_id,
+            read_offset=0,
+            commit_offset=0,
+        ))
+
+        assert len(stream_reader._partition_sessions) == initial_session_count + 1
+        assert stream_reader._partition_sessions[test_partition_session_id] == PartitionSession(
+            id=test_partition_session_id,
+            state=PartitionSession.State.Active,
+            topic_path=test_topic_path,
+            partition_id=test_partition_id,
+        )
+
+    async def test_partition_stop_force(self, stream, stream_reader, partition_session):
+        def session_count():
+            return len(stream_reader._partition_sessions)
+
+        initial_session_count = session_count()
+
+        stream.from_server.put_nowait(StreamReadMessage.FromServer(
+            server_message=StreamReadMessage.StopPartitionSessionRequest(
+                partition_session_id=partition_session.id,
+                graceful=False,
+                committed_offset=0,
+            )
+        ))
+
+        await asyncio.sleep(0)  # wait next loop
+        with pytest.raises(asyncio.QueueEmpty):
+            stream.from_client.get_nowait()
+
+        assert session_count() == initial_session_count - 1
+        assert partition_session.id not in stream_reader._partition_sessions
+
+    async def test_partition_stop_graceful(self, stream, stream_reader, partition_session):
+        def session_count():
+            return len(stream_reader._partition_sessions)
+
+        initial_session_count = session_count()
+
+        stream.from_server.put_nowait(StreamReadMessage.FromServer(
+            server_message=StreamReadMessage.StopPartitionSessionRequest(
+                partition_session_id=partition_session.id,
+                graceful=True,
+                committed_offset=0,
+            )
+        ))
+
+        resp = await wait_for_fast(stream.from_client.get())  # type: StreamReadMessage.FromClient
+        assert session_count() == initial_session_count - 1
+        assert partition_session.id not in stream_reader._partition_sessions
+        assert resp.client_message == StreamReadMessage.StopPartitionSessionResponse(
+            partition_session_id=partition_session.id
+        )
+
+        stream.from_server.put_nowait(StreamReadMessage.FromServer(
+            server_message=StreamReadMessage.StopPartitionSessionRequest(
+                partition_session_id=partition_session.id,
+                graceful=False,
+                committed_offset=0,
+            )
+        ))
+
+        await asyncio.sleep(0)  # wait next loop
+        with pytest.raises(asyncio.QueueEmpty):
+            stream.from_client.get_nowait()
+
diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py
index 435d3a28..6cf492a6 100644
--- a/ydb/_topic_wrapper/reader.py
+++ b/ydb/_topic_wrapper/reader.py
@@ -1,10 +1,7 @@
 import datetime
-import typing
 from dataclasses
import dataclass, field from typing import List, Union, Dict -from google.protobuf.message import Message - from ydb._topic_wrapper.common import OffsetsRange, IToProto, UpdateTokenRequest, UpdateTokenResponse, IFromProto from google.protobuf.duration_pb2 import Duration as ProtoDuration diff --git a/ydb/_topic_wrapper/test_helpers.py b/ydb/_topic_wrapper/test_helpers.py index 9a562188..b0c75a03 100644 --- a/ydb/_topic_wrapper/test_helpers.py +++ b/ydb/_topic_wrapper/test_helpers.py @@ -1,6 +1,9 @@ import asyncio +import time import typing +import pytest + from .common import IGrpcWrapperAsyncIO, IToProto @@ -20,3 +23,19 @@ async def receive(self) -> typing.Any: def write(self, wrap_message: IToProto): self.from_client.put_nowait(wrap_message) + + +async def wait_condition(f: typing.Callable[[], bool], timeout=1): + start = time.monotonic() + counter = 0 + while (time.monotonic() - start < timeout) or counter < 1000: + counter += 1 + if f(): + return + await asyncio.sleep(0) + + raise Exception("Bad condition in test") + + +async def wait_for_fast(fut): + return await asyncio.wait_for(fut, 1) From 5faa304bb4be0c6393ed12f00cc8b92e654f92f2 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 1 Feb 2023 18:22:21 +0300 Subject: [PATCH 020/147] start implement read messages --- ydb/_topic_reader/topic_reader_asyncio.py | 8 +++- .../topic_reader_asyncio_test.py | 40 ++++++++++++++++++- ydb/_topic_wrapper/common.py | 4 +- ydb/_topic_wrapper/reader.py | 4 +- 4 files changed, 50 insertions(+), 6 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 74f8f02e..49fb6e7a 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -76,7 +76,9 @@ async def _read_messages(self, stream: IGrpcWrapperAsyncIO): )) while True: message = await stream.receive() # type: StreamReadMessage.FromServer - if isinstance(message.server_message, StreamReadMessage.StartPartitionSessionRequest): + if isinstance(message.server_message, StreamReadMessage.ReadResponse): + await self._on_read_response(message.server_message) + elif isinstance(message.server_message, StreamReadMessage.StartPartitionSessionRequest): await self._on_start_partition_session_start(message.server_message) elif isinstance(message.server_message, StreamReadMessage.StopPartitionSessionRequest): await self._on_partition_session_stop(message.server_message) @@ -130,6 +132,10 @@ async def _on_partition_session_stop(self, message: StreamReadMessage.StopPartit )) ) + async def _on_read_response(self, message: StreamReadMessage.ReadResponse): + async with self._lock: + pass + async def _set_first_error(self, err): async with self._lock: self._set_first_error_locked(err) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 85326352..59582cb7 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1,4 +1,5 @@ import asyncio +import datetime from unittest import mock import pytest @@ -6,7 +7,7 @@ from ydb import aio from ydb._topic_reader.topic_reader import PublicReaderSettings from ydb._topic_reader.topic_reader_asyncio import ReaderStream, PartitionSession -from ydb._topic_wrapper.common import OffsetsRange +from ydb._topic_wrapper.common import OffsetsRange, Codec from ydb._topic_wrapper.reader import StreamReadMessage from ydb._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast @@ -215,3 +216,40 @@ 
def session_count(): with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() + async def test_receive_one_raw_message_from_server(self, stream_reader, stream, partition_session): + bytes_size = 10 + created_at = datetime.datetime(2020, 1, 1, 18, 12) + written_at = datetime.datetime(2023, 2, 1, 18, 12) + producer_id = "test-producer-id" + data = "123".encode() + + message_group_id = "test-message-group-id" + + stream.from_server.put_nowait(StreamReadMessage.FromServer(server_message=StreamReadMessage.ReadResponse( + bytes_size=bytes_size, + partition_data=[ + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=1, + seq_no=2, + created_at=created_at, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id, + ) + ], + producer_id=producer_id, + write_session_meta={"a": "b"}, + codec=Codec.CODEC_RAW, + written_at=written_at, + ) + ] + ) + ] + ))), + + raise NotImplementedError() diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index 15173df8..e1d228f9 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -2,7 +2,7 @@ import asyncio import typing from dataclasses import dataclass -from enum import Enum +from enum import IntEnum from google.protobuf.message import Message @@ -24,7 +24,7 @@ ) -class Codec(Enum): +class Codec(IntEnum): CODEC_UNSPECIFIED = 0 CODEC_RAW = 1 CODEC_GZIP = 2 diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index 6cf492a6..6b84ec71 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -68,7 +68,7 @@ class ReadResponse: class MessageData: offset: int seq_no: int - created_at: float # unix timestamp + created_at: datetime.datetime data: bytes uncompresed_size: int message_group_id: str @@ -79,7 +79,7 @@ class Batch: producer_id: str write_session_meta: Dict[str, str] codec: int - written_at: float # unix timestamp + written_at: datetime.datetime @dataclass class PartitionData: From 132aa1e43bd1d4f7e8369c7d6e170a424dd99a2a Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Feb 2023 12:56:31 +0300 Subject: [PATCH 021/147] read messages to internal buffer --- ydb/_topic_reader/datatypes.py | 90 +++++++ ydb/_topic_reader/topic_reader.py | 89 +------ ydb/_topic_reader/topic_reader_asyncio.py | 68 +++-- .../topic_reader_asyncio_test.py | 239 +++++++++++++++++- ydb/_topic_wrapper/reader.py | 2 +- 5 files changed, 382 insertions(+), 106 deletions(-) create mode 100644 ydb/_topic_reader/datatypes.py diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py new file mode 100644 index 00000000..f7076c6c --- /dev/null +++ b/ydb/_topic_reader/datatypes.py @@ -0,0 +1,90 @@ +import abc +import enum +from dataclasses import dataclass +import datetime +from typing import Mapping, Union, Any, List + + +class ICommittable(abc.ABC): + @property + @abc.abstractmethod + def start_offset(self) -> int: + pass + + @property + @abc.abstractmethod + def end_offset(self) -> int: + pass + + +class ISessionAlive(abc.ABC): + @property + @abc.abstractmethod + def is_alive(self) -> bool: + pass + + +@dataclass +class PublicMessage(ICommittable, ISessionAlive): + seqno: int + created_at: datetime.datetime + message_group_id: str + session_metadata: Mapping[str, str] + offset: int + written_at: datetime.datetime + producer_id: str + data: Union[ + bytes, Any + ] # set as original 
decompressed bytes or deserialized object if deserializer set in reader + _partition_session: "PartitionSession" + + @property + def start_offset(self) -> int: + raise NotImplementedError() + + @property + def end_offset(self) -> int: + raise NotImplementedError() + + # ISessionAlive implementation + @property + def is_alive(self) -> bool: + raise NotImplementedError() + + +@dataclass +class PartitionSession: + id: int + state: "PartitionSession.State" + topic_path: str + partition_id: int + + def stop(self): + self.state = PartitionSession.State.Stopped + + class State(enum.Enum): + Active = 1 + GracefulShutdown = 2 + Stopped = 3 + + +@dataclass +class PublicBatch(ICommittable, ISessionAlive): + session_metadata: Mapping[str, str] + messages: List[PublicMessage] + _partition_session: PartitionSession + _bytes_size: int + + @property + def start_offset(self) -> int: + raise NotImplementedError() + + @property + def end_offset(self) -> int: + raise NotImplementedError() + + # ISessionAlive implementation + @property + def is_alive(self) -> bool: + state = self._partition_session.state + return state == PartitionSession.State.Active or state == PartitionSession.State.GracefulShutdown diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index d335e40d..ea548ac5 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -13,7 +13,7 @@ Iterable, AsyncIterable, AsyncContextManager, - Any, + Any, Dict, ) from ydb._topic_wrapper.common import OffsetsRange, TokenGetterFuncType @@ -47,7 +47,7 @@ async def sessions_stat(self) -> List["SessionStat"]: def messages( self, *, timeout: Union[float, None] = None - ) -> AsyncIterable["Message"]: + ) -> AsyncIterable["PublicMessage"]: """ Block until receive new message @@ -55,7 +55,7 @@ def messages( """ raise NotImplementedError() - async def receive_message(self) -> Union["Message", None]: + async def receive_message(self) -> Union["PublicMessage", None]: """ Block until receive new message @@ -69,7 +69,7 @@ def batches( max_messages: Union[int, None] = None, max_bytes: Union[int, None] = None, timeout: Union[float, None] = None, - ) -> AsyncIterable["Batch"]: + ) -> AsyncIterable["PublicBatch"]: """ Block until receive new batch. All messages in a batch from same partition. @@ -80,7 +80,7 @@ def batches( async def receive_batch( self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None] - ) -> Union["Batch", None]: + ) -> Union["PublicBatch", None]: """ Get one messages batch from reader. All messages in a batch from same partition. @@ -143,7 +143,7 @@ async def sessions_stat(self) -> List["SessionStat"]: """ raise NotImplementedError() - def messages(self, *, timeout: Union[float, None] = None) -> Iterable["Message"]: + def messages(self, *, timeout: Union[float, None] = None) -> Iterable["PublicMessage"]: """ todo? @@ -155,7 +155,7 @@ def messages(self, *, timeout: Union[float, None] = None) -> Iterable["Message"] """ raise NotImplementedError() - def receive_message(self, *, timeout: Union[float, None] = None) -> "Message": + def receive_message(self, *, timeout: Union[float, None] = None) -> "PublicMessage": """ Block until receive new message It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. 
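
Aside on the new `datatypes.py` introduced above: the public objects are plain dataclasses, and `PublicBatch.is_alive` is the only `ISessionAlive` member implemented so far; it simply mirrors the state of the owning partition session. A minimal, self-contained sketch of that contract (illustrative only, with invented field values; this snippet is not part of any commit in the series):

    import datetime

    from ydb._topic_reader.datatypes import PartitionSession, PublicBatch, PublicMessage

    session = PartitionSession(
        id=1,
        state=PartitionSession.State.Active,
        topic_path="/local/topic",
        partition_id=0,
    )
    message = PublicMessage(
        seqno=1,
        created_at=datetime.datetime(2023, 2, 3, 14, 0),
        message_group_id="group",
        session_metadata={},
        offset=10,
        written_at=datetime.datetime(2023, 2, 3, 14, 1),
        producer_id="producer",
        data=b"payload",
        _partition_session=session,
    )
    batch = PublicBatch(
        session_metadata={},
        messages=[message],
        _partition_session=session,
        _bytes_size=100,
    )

    assert batch.is_alive        # Active and GracefulShutdown both count as alive
    session.stop()               # what the stop-partition handler does internally
    assert not batch.is_alive    # a Stopped session invalidates its batches

Because each batch keeps a reference to its `PartitionSession`, no per-batch bookkeeping is needed when a partition is revoked.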
@@ -181,7 +181,7 @@ def batches( max_messages: Union[int, None] = None, max_bytes: Union[int, None] = None, timeout: Union[float, None] = None, - ) -> Iterable["Batch"]: + ) -> Iterable["PublicBatch"]: """ Block until receive new batch. It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. @@ -197,7 +197,7 @@ def receive_batch( max_messages: Union[int, None] = None, max_bytes: Union[int, None], timeout: Union[float, None] = None, - ) -> Union["Batch", None]: + ) -> Union["PublicBatch", None]: """ Get one messages batch from reader It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. @@ -271,77 +271,6 @@ class PublicReaderSettings: # retry_policy: Union["RetryPolicy", None] = None -class ICommittable(abc.ABC): - @property - @abc.abstractmethod - def start_offset(self) -> int: - pass - - @property - @abc.abstractmethod - def end_offset(self) -> int: - pass - - -class ISessionAlive(abc.ABC): - @property - @abc.abstractmethod - def is_alive(self) -> bool: - pass - - -class Message(ICommittable, ISessionAlive): - seqno: int - created_at_ns: int - message_group_id: str - session_metadata: Mapping[str, str] - offset: int - written_at_ns: int - producer_id: int - data: Union[ - bytes, Any - ] # set as original decompressed bytes or deserialized object if deserializer set in reader - - def __init__(self): - self.seqno = -1 - self.created_at_ns = -1 - self.data = io.BytesIO() - - @property - def start_offset(self) -> int: - raise NotImplementedError() - - @property - def end_offset(self) -> int: - raise NotImplementedError() - - # ISessionAlive implementation - @property - def is_alive(self) -> bool: - raise NotImplementedError() - - -class Batch(ICommittable, ISessionAlive): - session_metadata: Mapping[str, str] - messages: List[Message] - - def __init__(self): - pass - - @property - def start_offset(self) -> int: - raise NotImplementedError() - - @property - def end_offset(self) -> int: - raise NotImplementedError() - - # ISessionAlive implementation - @property - def is_alive(self) -> bool: - raise NotImplementedError() - - class Events: class OnCommit: topic: str diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 49fb6e7a..abd1bec1 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -2,10 +2,13 @@ import asyncio import enum +import typing +from collections import deque from dataclasses import dataclass from typing import Optional, Set, Dict import ydb +from .datatypes import PartitionSession, PublicMessage, PublicBatch from .topic_reader import PublicReaderSettings from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO from .._topic_wrapper.reader import StreamReadMessage @@ -36,6 +39,7 @@ class ReaderStream: _background_tasks: Set[asyncio.Task] _partition_sessions: Dict[int, PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only + _message_batches: typing.Deque def __init__(self, settings: PublicReaderSettings): self._token_getter = settings._token_getter @@ -49,6 +53,7 @@ def __init__(self, settings: PublicReaderSettings): self._background_tasks = set() self._partition_sessions = dict() self._buffer_size_bytes = settings.buffer_size_bytes + self._message_batches = deque() async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest): async with self._lock: @@ -133,8 +138,53 @@ 
async def _on_partition_session_stop(self, message: StreamReadMessage.StopPartit ) async def _on_read_response(self, message: StreamReadMessage.ReadResponse): + batches = await self._read_response_to_batches(message) + async with self._lock: - pass + self._message_batches.extend(batches) + self._buffer_size_bytes -= message.bytes_size + + async def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> typing.List[PublicBatch]: + batches = [] + + batch_count = 0 + for partition_data in message.partition_data: + batch_count += len(partition_data.batches) + + if batch_count == 0: + return [] + + bytes_per_batch = message.bytes_size // batch_count + additional_bytes_to_last_batch = message.bytes_size - bytes_per_batch * batch_count + + for partition_data in message.partition_data: + async with self._lock: + partition_session = self._partition_sessions[partition_data.partition_session_id] + for server_batch in partition_data.batches: + messages = [] + for message_data in server_batch.message_data: + mess = PublicMessage( + seqno=message_data.seq_no, + created_at=message_data.created_at, + message_group_id=message_data.message_group_id, + session_metadata=server_batch.write_session_meta, + offset=message_data.offset, + written_at=server_batch.written_at, + producer_id=server_batch.producer_id, + data=message_data.data, + _partition_session=partition_session, + ) + messages.append(mess) + batch = PublicBatch( + session_metadata=server_batch.write_session_meta, + messages=messages, + _partition_session=partition_session, + _bytes_size=bytes_per_batch, + ) + batches.append(batch) + + batches[-1]._bytes_size += additional_bytes_to_last_batch + return batches async def _set_first_error(self, err): async with self._lock: @@ -155,19 +205,3 @@ async def close(self): task.cancel() await asyncio.wait(self._background_tasks) - - -@dataclass -class PartitionSession: - id: int - state: "PartitionSession.State" - topic_path: str - partition_id: int - - def stop(self): - self.state = PartitionSession.State.Stopped - - class State(enum.Enum): - Active = 1 - GracefulShutdown = 2 - Stopped = 3 diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 59582cb7..8fd2c35c 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -5,6 +5,7 @@ import pytest from ydb import aio +from ydb._topic_reader.datatypes import PublicBatch, PublicMessage from ydb._topic_reader.topic_reader import PublicReaderSettings from ydb._topic_reader.topic_reader_asyncio import ReaderStream, PartitionSession from ydb._topic_wrapper.common import OffsetsRange, Codec @@ -45,7 +46,17 @@ def partition_session(self, default_reader_settings): ) @pytest.fixture() - async def stream_reader(self, stream, partition_session, default_reader_settings) -> ReaderStream: + def second_partition_session(self, default_reader_settings): + return PartitionSession( + id=12, + topic_path=default_reader_settings.topic, + partition_id=10, + state=PartitionSession.State.Active, + ) + + @pytest.fixture() + async def stream_reader(self, stream, default_reader_settings, partition_session, + second_partition_session) -> ReaderStream: reader = ReaderStream(default_reader_settings) init_message = object() @@ -81,6 +92,23 @@ async def stream_reader(self, stream, partition_session, default_reader_settings start_partition_resp = await wait_for_fast(stream.from_client.get()) assert isinstance(start_partition_resp.client_message, 
StreamReadMessage.StartPartitionSessionResponse) + stream.from_server.put_nowait( + StreamReadMessage.FromServer(server_message=StreamReadMessage.StartPartitionSessionRequest( + partition_session=StreamReadMessage.PartitionSession( + partition_session_id=second_partition_session.id, + path=second_partition_session.topic_path, + partition_id=second_partition_session.partition_id, + ), + committed_offset=0, + partition_offsets=OffsetsRange( + start=0, + end=0, + ) + )) + ) + start_partition_resp = await wait_for_fast(stream.from_client.get()) + assert isinstance(start_partition_resp.client_message, StreamReadMessage.StartPartitionSessionResponse) + await asyncio.sleep(0) with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() @@ -122,13 +150,18 @@ async def test_init_reader(self, stream, default_reader_settings): assert reader._session_id == "test" await reader.close() - async def test_start_partition(self, stream_reader: ReaderStream, stream, default_reader_settings, partition_session): + async def test_start_partition(self, + stream_reader: ReaderStream, + stream, + default_reader_settings, + partition_session, + ): def session_count(): return len(stream_reader._partition_sessions) initial_session_count = session_count() - test_partition_id = partition_session.partition_id+1 + test_partition_id = partition_session.partition_id + 1 test_partition_session_id = partition_session.id + 1 test_topic_path = default_reader_settings.topic + "-asd" @@ -197,7 +230,7 @@ def session_count(): )) resp = await wait_for_fast(stream.from_client.get()) # type: StreamReadMessage.FromClient - assert session_count() == initial_session_count-1 + assert session_count() == initial_session_count - 1 assert partition_session.id not in stream_reader._partition_sessions assert resp.client_message == StreamReadMessage.StopPartitionSessionResponse( partition_session_id=partition_session.id @@ -216,13 +249,20 @@ def session_count(): with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() - async def test_receive_one_raw_message_from_server(self, stream_reader, stream, partition_session): + async def test_receive_message_from_server(self, stream_reader, stream, partition_session, + second_partition_session): + def reader_batch_count(): + return len(stream_reader._message_batches) + + initial_buffer_size = stream_reader._buffer_size_bytes + initial_batch_count = reader_batch_count() + bytes_size = 10 created_at = datetime.datetime(2020, 1, 1, 18, 12) written_at = datetime.datetime(2023, 2, 1, 18, 12) producer_id = "test-producer-id" data = "123".encode() - + session_meta = {"a": "b"} message_group_id = "test-message-group-id" stream.from_server.put_nowait(StreamReadMessage.FromServer(server_message=StreamReadMessage.ReadResponse( @@ -243,7 +283,7 @@ async def test_receive_one_raw_message_from_server(self, stream_reader, stream, ) ], producer_id=producer_id, - write_session_meta={"a": "b"}, + write_session_meta=session_meta, codec=Codec.CODEC_RAW, written_at=written_at, ) @@ -252,4 +292,187 @@ async def test_receive_one_raw_message_from_server(self, stream_reader, stream, ] ))), - raise NotImplementedError() + await wait_condition(lambda: reader_batch_count() == initial_batch_count + 1) + + assert stream_reader._buffer_size_bytes == initial_buffer_size - bytes_size + + last_batch = stream_reader._message_batches[-1] + assert last_batch == PublicBatch( + session_metadata=session_meta, + messages=[ + PublicMessage( + seqno=2, + created_at=created_at, + message_group_id=message_group_id, + 
session_metadata=session_meta, + offset=1, + written_at=written_at, + producer_id=producer_id, + data=data, + _partition_session=partition_session, + ) + ], + _partition_session=partition_session, + _bytes_size=bytes_size, + ) + + async def test_read_batches(self, stream_reader, partition_session, second_partition_session): + created_at = datetime.datetime(2020, 2, 1, 18, 12) + created_at2 = datetime.datetime(2020, 2, 2, 18, 12) + created_at3 = datetime.datetime(2020, 2, 3, 18, 12) + created_at4 = datetime.datetime(2020, 2, 4, 18, 12) + written_at = datetime.datetime(2023, 3, 1, 18, 12) + written_at2 = datetime.datetime(2023, 3, 2, 18, 12) + producer_id = "test-producer-id" + producer_id2 = "test-producer-id" + data = "123".encode() + data2 = "1235".encode() + session_meta = {"a": "b"} + session_meta2 = {"b": "c"} + + message_group_id = "test-message-group-id" + message_group_id2 = "test-message-group-id-2" + + batches = await stream_reader._read_response_to_batches( + StreamReadMessage.ReadResponse( + bytes_size=3, + partition_data=[ + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=2, + seq_no=3, + created_at=created_at, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id, + ) + ], + producer_id=producer_id, + write_session_meta=session_meta, + codec=Codec.CODEC_RAW, + written_at=written_at, + ) + ] + ), + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=second_partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=1, + seq_no=2, + created_at=created_at2, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id, + ) + ], + producer_id=producer_id, + write_session_meta=session_meta, + codec=Codec.CODEC_RAW, + written_at=written_at2, + ), + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=2, + seq_no=3, + created_at=created_at3, + data=data2, + uncompresed_size=len(data2), + message_group_id=message_group_id, + ), + StreamReadMessage.ReadResponse.MessageData( + offset=4, + seq_no=5, + created_at=created_at4, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id2, + ) + ], + producer_id=producer_id2, + write_session_meta=session_meta2, + codec=Codec.CODEC_RAW, + written_at=written_at2, + ) + ] + ), + ] + ) + ) + + last0 = batches[0] + last1 = batches[1] + last2 = batches[2] + + assert last0 == PublicBatch( + session_metadata=session_meta, + messages=[ + PublicMessage( + seqno=3, + created_at=created_at, + message_group_id=message_group_id, + session_metadata=session_meta, + offset=2, + written_at=written_at, + producer_id=producer_id, + data=data, + _partition_session=partition_session, + ) + ], + _partition_session=partition_session, + _bytes_size=1, + ) + assert last1 == PublicBatch( + session_metadata=session_meta, + messages=[ + PublicMessage( + seqno=2, + created_at=created_at2, + message_group_id=message_group_id, + session_metadata=session_meta, + offset=1, + written_at=written_at2, + producer_id=producer_id, + data=data, + _partition_session=second_partition_session, + ) + ], + _partition_session=second_partition_session, + _bytes_size=1, + ) + assert last2 == PublicBatch( + session_metadata=session_meta2, + messages=[ + PublicMessage( + seqno=3, + created_at=created_at3, + 
message_group_id=message_group_id, + session_metadata=session_meta2, + offset=2, + written_at=written_at2, + producer_id=producer_id2, + data=data2, + _partition_session=second_partition_session, + ), + PublicMessage( + seqno=5, + created_at=created_at4, + message_group_id=message_group_id2, + session_metadata=session_meta2, + offset=4, + written_at=written_at2, + producer_id=producer_id, + data=data, + _partition_session=second_partition_session, + ) + ], + _partition_session=second_partition_session, + _bytes_size=1, + ) diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index 6b84ec71..62812789 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -61,7 +61,7 @@ class ReadRequest: @dataclass class ReadResponse: - partition_data: List["PartitionData"] + partition_data: List["StreamReadMessage.ReadResponse.PartitionData"] bytes_size: int @dataclass From 86d2877b6a23d2fa480397d3a0b52a4eb84ed867 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Feb 2023 15:12:49 +0300 Subject: [PATCH 022/147] receive_batch_nowait --- ydb/_topic_reader/datatypes.py | 4 +- ydb/_topic_reader/topic_reader_asyncio.py | 50 +++++++++-- .../topic_reader_asyncio_test.py | 87 +++++++++++++++++++ 3 files changed, 131 insertions(+), 10 deletions(-) diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index f7076c6c..1c26b272 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -2,7 +2,7 @@ import enum from dataclasses import dataclass import datetime -from typing import Mapping, Union, Any, List +from typing import Mapping, Union, Any, List, Dict class ICommittable(abc.ABC): @@ -29,7 +29,7 @@ class PublicMessage(ICommittable, ISessionAlive): seqno: int created_at: datetime.datetime message_group_id: str - session_metadata: Mapping[str, str] + session_metadata: Dict[str, str] offset: int written_at: datetime.datetime producer_id: str diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index abd1bec1..75d4ebc8 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -1,23 +1,26 @@ from __future__ import annotations import asyncio -import enum import typing from collections import deque -from dataclasses import dataclass from typing import Optional, Set, Dict -import ydb +from ..issues import Error as YdbError from .datatypes import PartitionSession, PublicMessage, PublicBatch from .topic_reader import PublicReaderSettings from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO from .._topic_wrapper.reader import StreamReadMessage -class TopicReaderError(ydb.Error): +class TopicReaderError(YdbError): pass +class TopicReaderStreamClosedError(TopicReaderError): + def __init__(self): + super().__init__("Topic reader is closed") + + class PublicAsyncIOReader: pass @@ -31,20 +34,22 @@ class ReaderStream: _session_id: str _init_completed: asyncio.Future[None] _stream: Optional[IGrpcWrapperAsyncIO] + _has_new_messages: asyncio.Event _lock: asyncio.Lock _started: bool _closed: bool - _first_error: Optional[ydb.Error] + _first_error: Optional[YdbError] _background_tasks: Set[asyncio.Task] _partition_sessions: Dict[int, PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only - _message_batches: typing.Deque + _message_batches: typing.Deque[PublicBatch] def __init__(self, settings: PublicReaderSettings): self._token_getter = settings._token_getter self._session_id = 
"not initialized" self._stream = None + self._has_new_messages = asyncio.Event() self._lock = asyncio.Lock() self._started = False @@ -59,6 +64,7 @@ async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessa async with self._lock: if self._started: raise TopicReaderError("Double start ReaderStream") + self._started = True self._stream = stream @@ -72,6 +78,25 @@ async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessa read_messages_task = asyncio.create_task(self._read_messages(stream)) self._background_tasks.add(read_messages_task) + async def wait_messages(self): + if self._closed: + raise TopicReaderStreamClosedError() + + while len(self._message_batches) == 0: + await self._has_new_messages.wait() + self._has_new_messages.clear() + + def receive_batch_nowait(self): + if self._closed: + raise TopicReaderStreamClosedError() + + try: + batch = self._message_batches.popleft() + self._buffer_release_bytes(batch._bytes_size) + return batch + except IndexError: + return None + async def _read_messages(self, stream: IGrpcWrapperAsyncIO): try: self._stream.write(StreamReadMessage.FromClient( @@ -116,7 +141,7 @@ async def _on_start_partition_session_start(self, message: StreamReadMessage.Sta commit_offset=0, )), ) - except ydb.Error as err: + except YdbError as err: self._set_first_error_locked(err) async def _on_partition_session_stop(self, message: StreamReadMessage.StopPartitionSessionRequest): @@ -142,7 +167,16 @@ async def _on_read_response(self, message: StreamReadMessage.ReadResponse): async with self._lock: self._message_batches.extend(batches) - self._buffer_size_bytes -= message.bytes_size + self._buffer_consume_bytes(message.bytes_size) + + def _buffer_consume_bytes(self, bytes_size): + self._buffer_size_bytes -= bytes_size + + def _buffer_release_bytes(self, bytes_size): + self._buffer_size_bytes += bytes_size + self._stream.write(StreamReadMessage.FromClient(client_message=StreamReadMessage.ReadRequest( + bytes_size=bytes_size, + ))) async def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> typing.List[PublicBatch]: batches = [] diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 8fd2c35c..0e1840a6 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -31,6 +31,7 @@ def write(self, message: StreamReadMessage.FromClient): @pytest.mark.asyncio class TestReaderStream: + default_batch_size = 1 @pytest.fixture() def stream(self): @@ -119,6 +120,53 @@ async def stream_reader(self, stream, default_reader_settings, partition_session await reader.close() + @staticmethod + def create_message(partition_session: PartitionSession, seqno: int): + return PublicMessage( + seqno=seqno, + created_at=datetime.datetime(2023, 2, 3, 14, 15), + message_group_id="test-message-group", + session_metadata={}, + offset=seqno+1, + written_at=datetime.datetime(2023, 2, 3, 14, 16), + producer_id="test-producer-id", + data=bytes(), + _partition_session=partition_session + ) + + async def send_message(self, stream_reader, message: PublicMessage): + def batch_count(): + return len(stream_reader._message_batches) + + initial_batches = batch_count() + + stream = stream_reader._stream # type: StreamMock + stream.from_server.put_nowait(StreamReadMessage.FromServer(server_message=StreamReadMessage.ReadResponse( + partition_data=[StreamReadMessage.ReadResponse.PartitionData( + 
partition_session_id=message._partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=message.offset, + seq_no=message.seqno, + created_at=message.created_at, + data=message.data, + uncompresed_size=len(message.data), + message_group_id=message.message_group_id, + ) + ], + producer_id=message.producer_id, + write_session_meta=message.session_metadata, + codec=Codec.CODEC_RAW, + written_at=message.written_at, + ) + ] + )], + bytes_size=self.default_batch_size, + ))) + await wait_condition(lambda: batch_count() > initial_batches) + async def test_init_reader(self, stream, default_reader_settings): reader = ReaderStream(default_reader_settings) init_message = StreamReadMessage.InitRequest( @@ -476,3 +524,42 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti _partition_session=second_partition_session, _bytes_size=1, ) + + async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): + assert stream_reader.receive_batch_nowait() is None + + mess1 = self.create_message(partition_session, 1) + await self.send_message(stream_reader, mess1) + + mess2 = self.create_message(partition_session, 2) + await self.send_message(stream_reader, mess2) + + initial_buffer_size = stream_reader._buffer_size_bytes + + received = stream_reader.receive_batch_nowait() + assert received == PublicBatch( + mess1.session_metadata, + messages=[ + mess1 + ], + _partition_session=mess1._partition_session, + _bytes_size=self.default_batch_size, + ) + + received = stream_reader.receive_batch_nowait() + assert received == PublicBatch( + mess2.session_metadata, + messages=[ + mess2 + ], + _partition_session=mess2._partition_session, + _bytes_size=self.default_batch_size, + ) + + assert stream_reader._buffer_size_bytes == initial_buffer_size + 2 * self.default_batch_size + + assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message + assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message + + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() From 2c68fed256658322064975f63baf3c00ee199402 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Feb 2023 15:35:23 +0300 Subject: [PATCH 023/147] switch from mutex to state_changed event --- ydb/_topic_reader/topic_reader_asyncio.py | 167 +++++++++--------- .../topic_reader_asyncio_test.py | 2 +- 2 files changed, 80 insertions(+), 89 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 75d4ebc8..27302663 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -32,59 +32,55 @@ class ReaderReconnector: class ReaderStream: _token_getter: Optional[TokenGetterFuncType] _session_id: str - _init_completed: asyncio.Future[None] _stream: Optional[IGrpcWrapperAsyncIO] - _has_new_messages: asyncio.Event - - _lock: asyncio.Lock _started: bool - _closed: bool - _first_error: Optional[YdbError] _background_tasks: Set[asyncio.Task] _partition_sessions: Dict[int, PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only + + _state_changed: asyncio.Event + _closed: bool + _first_error: Optional[YdbError] _message_batches: typing.Deque[PublicBatch] def __init__(self, settings: PublicReaderSettings): self._token_getter = settings._token_getter self._session_id = "not 
initialized" self._stream = None - self._has_new_messages = asyncio.Event() - - self._lock = asyncio.Lock() self._started = False - self._closed = False - self._first_error = None self._background_tasks = set() self._partition_sessions = dict() self._buffer_size_bytes = settings.buffer_size_bytes + + self._state_changed = asyncio.Event() + self._closed = False + self._first_error = None self._message_batches = deque() async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest): - async with self._lock: - if self._started: - raise TopicReaderError("Double start ReaderStream") + if self._started: + raise TopicReaderError("Double start ReaderStream") - self._started = True - self._stream = stream + self._started = True + self._stream = stream - stream.write(StreamReadMessage.FromClient(client_message=init_message)) - init_response = await stream.receive() # type: StreamReadMessage.FromServer - if isinstance(init_response.server_message, StreamReadMessage.InitResponse): - self._session_id = init_response.server_message.session_id - else: - raise TopicReaderError("Unexpected message after InitRequest: %s", init_response) + stream.write(StreamReadMessage.FromClient(client_message=init_message)) + init_response = await stream.receive() # type: StreamReadMessage.FromServer + if isinstance(init_response.server_message, StreamReadMessage.InitResponse): + self._session_id = init_response.server_message.session_id + else: + raise TopicReaderError("Unexpected message after InitRequest: %s", init_response) - read_messages_task = asyncio.create_task(self._read_messages(stream)) - self._background_tasks.add(read_messages_task) + read_messages_task = asyncio.create_task(self._read_messages(stream)) + self._background_tasks.add(read_messages_task) async def wait_messages(self): if self._closed: raise TopicReaderStreamClosedError() while len(self._message_batches) == 0: - await self._has_new_messages.wait() - self._has_new_messages.clear() + await self._state_changed.wait() + self._state_changed.clear() def receive_batch_nowait(self): if self._closed: @@ -107,67 +103,65 @@ async def _read_messages(self, stream: IGrpcWrapperAsyncIO): while True: message = await stream.receive() # type: StreamReadMessage.FromServer if isinstance(message.server_message, StreamReadMessage.ReadResponse): - await self._on_read_response(message.server_message) + self._on_read_response(message.server_message) elif isinstance(message.server_message, StreamReadMessage.StartPartitionSessionRequest): - await self._on_start_partition_session_start(message.server_message) + self._on_start_partition_session_start(message.server_message) elif isinstance(message.server_message, StreamReadMessage.StopPartitionSessionRequest): - await self._on_partition_session_stop(message.server_message) + self._on_partition_session_stop(message.server_message) else: raise NotImplementedError( "Unexpected type of StreamReadMessage.FromServer message: %s" % message.server_message ) + + self._state_changed.set() except Exception as e: - await self._set_first_error(e) + self._set_first_error(e) raise e - async def _on_start_partition_session_start(self, message: StreamReadMessage.StartPartitionSessionRequest): - async with self._lock: - try: - if message.partition_session.partition_session_id in self._partition_sessions: - raise TopicReaderError( - "Double start partition session: %s" % message.partition_session.partition_session_id - ) - - self._partition_sessions[message.partition_session.partition_session_id] = 
PartitionSession( - id=message.partition_session.partition_session_id, - state=PartitionSession.State.Active, - topic_path=message.partition_session.path, - partition_id=message.partition_session.partition_id, - ) - self._stream.write(StreamReadMessage.FromClient( - client_message=StreamReadMessage.StartPartitionSessionResponse( - partition_session_id=message.partition_session.partition_session_id, - read_offset=0, - commit_offset=0, - )), - ) - except YdbError as err: - self._set_first_error_locked(err) - - async def _on_partition_session_stop(self, message: StreamReadMessage.StopPartitionSessionRequest): - async with self._lock: - partition = self._partition_sessions.get(message.partition_session_id) - if partition is None: - # may if receive stop partition with graceful=false after response on stop partition - # with graceful=true and remove partition from internal dictionary - return - - del self._partition_sessions[message.partition_session_id] - partition.stop() - - if message.graceful: - self._stream.write(StreamReadMessage.FromClient( - client_message=StreamReadMessage.StopPartitionSessionResponse( - partition_session_id=message.partition_session_id, - )) + def _on_start_partition_session_start(self, message: StreamReadMessage.StartPartitionSessionRequest): + try: + if message.partition_session.partition_session_id in self._partition_sessions: + raise TopicReaderError( + "Double start partition session: %s" % message.partition_session.partition_session_id ) - async def _on_read_response(self, message: StreamReadMessage.ReadResponse): - batches = await self._read_response_to_batches(message) + self._partition_sessions[message.partition_session.partition_session_id] = PartitionSession( + id=message.partition_session.partition_session_id, + state=PartitionSession.State.Active, + topic_path=message.partition_session.path, + partition_id=message.partition_session.partition_id, + ) + self._stream.write(StreamReadMessage.FromClient( + client_message=StreamReadMessage.StartPartitionSessionResponse( + partition_session_id=message.partition_session.partition_session_id, + read_offset=0, + commit_offset=0, + )), + ) + except YdbError as err: + self._set_first_error(err) + + def _on_partition_session_stop(self, message: StreamReadMessage.StopPartitionSessionRequest): + partition = self._partition_sessions.get(message.partition_session_id) + if partition is None: + # may if receive stop partition with graceful=false after response on stop partition + # with graceful=true and remove partition from internal dictionary + return + + del self._partition_sessions[message.partition_session_id] + partition.stop() + + if message.graceful: + self._stream.write(StreamReadMessage.FromClient( + client_message=StreamReadMessage.StopPartitionSessionResponse( + partition_session_id=message.partition_session_id, + )) + ) - async with self._lock: - self._message_batches.extend(batches) - self._buffer_consume_bytes(message.bytes_size) + def _on_read_response(self, message: StreamReadMessage.ReadResponse): + batches = self._read_response_to_batches(message) + self._message_batches.extend(batches) + self._buffer_consume_bytes(message.bytes_size) def _buffer_consume_bytes(self, bytes_size): self._buffer_size_bytes -= bytes_size @@ -178,7 +172,7 @@ def _buffer_release_bytes(self, bytes_size): bytes_size=bytes_size, ))) - async def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> typing.List[PublicBatch]: + def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> 
typing.List[PublicBatch]: batches = [] batch_count = 0 @@ -192,8 +186,7 @@ async def _read_response_to_batches(self, message: StreamReadMessage.ReadRespons additional_bytes_to_last_batch = message.bytes_size - bytes_per_batch * batch_count for partition_data in message.partition_data: - async with self._lock: - partition_session = self._partition_sessions[partition_data.partition_session_id] + partition_session = self._partition_sessions[partition_data.partition_session_id] for server_batch in partition_data.batches: messages = [] for message_data in server_batch.message_data: @@ -220,20 +213,18 @@ async def _read_response_to_batches(self, message: StreamReadMessage.ReadRespons batches[-1]._bytes_size += additional_bytes_to_last_batch return batches - async def _set_first_error(self, err): - async with self._lock: - self._set_first_error_locked(err) - - def _set_first_error_locked(self, err): + def _set_first_error(self, err): if self._first_error is None: self._first_error = err + self._state_changed.set() async def close(self): - async with self._lock: - if self._closed: - raise TopicReaderError(message="Double closed ReaderStream") - self._closed = True - self._set_first_error_locked(TopicReaderError("Reader closed")) + if self._closed: + raise TopicReaderError(message="Double closed ReaderStream") + + self._closed = True + self._set_first_error(TopicReaderError("Reader closed")) + self._state_changed.set() for task in self._background_tasks: task.cancel() diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 0e1840a6..2e087431 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -381,7 +381,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti message_group_id = "test-message-group-id" message_group_id2 = "test-message-group-id-2" - batches = await stream_reader._read_response_to_batches( + batches = stream_reader._read_response_to_batches( StreamReadMessage.ReadResponse( bytes_size=3, partition_data=[ From 7ac9adc3d6870fd75922daa9ab3e8c82d0ecdd59 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Feb 2023 15:37:08 +0300 Subject: [PATCH 024/147] sync --- ydb/_topic_reader/topic_reader_asyncio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 27302663..b74d7c83 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -71,7 +71,7 @@ async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessa else: raise TopicReaderError("Unexpected message after InitRequest: %s", init_response) - read_messages_task = asyncio.create_task(self._read_messages(stream)) + read_messages_task = asyncio.create_task(self._read_messages_loop(stream)) self._background_tasks.add(read_messages_task) async def wait_messages(self): @@ -93,7 +93,7 @@ def receive_batch_nowait(self): except IndexError: return None - async def _read_messages(self, stream: IGrpcWrapperAsyncIO): + async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): try: self._stream.write(StreamReadMessage.FromClient( client_message=StreamReadMessage.ReadRequest( From 9cd569533dc2e6edfa640b790f613d33e966e621 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Feb 2023 17:05:49 +0300 Subject: [PATCH 025/147] sync --- ydb/_topic_reader/topic_reader.py | 11 ++ 
ydb/_topic_reader/topic_reader_asyncio.py | 67 +++++++++++- .../topic_reader_asyncio_test.py | 4 +- ydb/_topic_wrapper/common.py | 9 +- ydb/_topic_wrapper/reader.py | 103 ++++++++++++++++-- 5 files changed, 178 insertions(+), 16 deletions(-) diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index ea548ac5..8d818217 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -17,6 +17,7 @@ ) from ydb._topic_wrapper.common import OffsetsRange, TokenGetterFuncType +from ydb._topic_wrapper.reader import StreamReadMessage class Selector: @@ -270,6 +271,16 @@ class PublicReaderSettings: # connection_timeout: Union[float, None] = None # retry_policy: Union["RetryPolicy", None] = None + def _init_message(self) -> StreamReadMessage.InitRequest: + return StreamReadMessage.InitRequest( + topics_read_settings=[ + StreamReadMessage.InitRequest.TopicReadSettings( + path=self.topic, + ) + ], + consumer=self.consumer, + ) + class Events: class OnCommit: diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index b74d7c83..b96d5637 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -2,13 +2,16 @@ import asyncio import typing +from asyncio import Task from collections import deque from typing import Optional, Set, Dict +from .. import _apis +from ..aio import Driver from ..issues import Error as YdbError from .datatypes import PartitionSession, PublicMessage, PublicBatch from .topic_reader import PublicReaderSettings -from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO +from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO, SupportedDriverType, GrpcWrapperAsyncIO from .._topic_wrapper.reader import StreamReadMessage @@ -22,11 +25,52 @@ def __init__(self): class PublicAsyncIOReader: - pass + _loop: asyncio.AbstractEventLoop + _reconnector: ReaderReconnector + + def __init__(self, driver: Driver, settings: PublicReaderSettings): + self._loop = asyncio.get_running_loop() + self._reconnector = ReaderReconnector(driver, settings) class ReaderReconnector: - pass + _settings: PublicReaderSettings + _driver: Driver + _background_tasks: Set[Task] + + _state_changed: asyncio.Event + _stream_reader: Optional["ReaderStream"] + + def __init__(self, driver: Driver, settings: PublicReaderSettings): + self._settings = settings + self._driver = driver + self._background_tasks = set() + + self._state_changed = asyncio.Event() + self._stream_reader = None + self._background_tasks.add(asyncio.create_task(self.start())) + + async def start(self): + self._stream_reader = await ReaderStream.create(self._driver, self._settings) + self._state_changed.set() + + async def wait_message(self): + while True: + if self._stream_reader is not None: + await self._stream_reader.wait_messages() + + await self._state_changed.wait() + self._state_changed.clear() + + def receive_batch_nowait(self): + return self._stream_reader.receive_batch_nowait() + + async def close(self): + await self._stream_reader.close() + for task in self._background_tasks: + task.cancel() + + await asyncio.wait(self._background_tasks) class ReaderStream: @@ -57,7 +101,22 @@ def __init__(self, settings: PublicReaderSettings): self._first_error = None self._message_batches = deque() - async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest): + @staticmethod + async def create( + driver: SupportedDriverType, + settings: 
PublicReaderSettings, + ) -> "ReaderStream": + stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) + + await stream.start( + driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead + ) + + reader = ReaderStream(settings) + await reader._start(stream, settings._init_message()) + return reader + + async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest): if self._started: raise TopicReaderError("Double start ReaderStream") diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 2e087431..58e332f7 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -62,7 +62,7 @@ async def stream_reader(self, stream, default_reader_settings, partition_session init_message = object() # noinspection PyTypeChecker - start = asyncio.create_task(reader.start(stream, init_message)) + start = asyncio.create_task(reader._start(stream, init_message)) stream.from_server.put_nowait(StreamReadMessage.FromServer( StreamReadMessage.InitResponse(session_id="test-session") @@ -178,7 +178,7 @@ async def test_init_reader(self, stream, default_reader_settings): read_from=None, )] ) - start_task = asyncio.create_task(reader.start(stream, init_message)) + start_task = asyncio.create_task(reader._start(stream, init_message)) sent_message = await wait_for_fast(stream.from_client.get()) expected_sent_init_message = StreamReadMessage.FromClient(client_message=init_message) diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index e1d228f9..30f39cb2 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -33,10 +33,17 @@ class Codec(IntEnum): @dataclass -class OffsetsRange: +class OffsetsRange(IFromProto): start: int end: int + @staticmethod + def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange": + return OffsetsRange( + start=msg.start, + end=msg.end, + ) + class IToProto(abc.ABC): @abc.abstractmethod diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index 62812789..ee4de232 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -1,7 +1,10 @@ import datetime +import typing from dataclasses import dataclass, field from typing import List, Union, Dict +from google.protobuf.message import Message + from ydb._topic_wrapper.common import OffsetsRange, IToProto, UpdateTokenRequest, UpdateTokenResponse, IFromProto from google.protobuf.duration_pb2 import Duration as ProtoDuration @@ -14,11 +17,19 @@ class StreamReadMessage: @dataclass - class PartitionSession: + class PartitionSession(IFromProto): partition_session_id: int path: str partition_id: int + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamReadMessage.PartitionSession) -> "StreamReadMessage.PartitionSession": + return StreamReadMessage.PartitionSession( + partition_session_id=msg.partition_session_id, + path=msg.path, + partition_id=msg.partition_id, + ) + @dataclass class InitRequest(IToProto): topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"] @@ -56,16 +67,31 @@ def from_proto(msg: ydb_topic_pb2.StreamReadMessage.InitResponse) -> "StreamRead return StreamReadMessage.InitResponse(session_id=msg.session_id) @dataclass - class ReadRequest: + class ReadRequest(IToProto): bytes_size: int + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.ReadRequest: + res = ydb_topic_pb2.StreamReadMessage.ReadRequest() + res.bytes_size = self.bytes_size + return res + 
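
Aside on the conversion layer this commit fills in: each wrapper dataclass owns its own protobuf mapping, so client-bound messages go through `to_proto()` and server-sent ones come back through `from_proto()`. A small sketch of the client-bound direction (illustrative only, not part of the commit; it relies on `ReadRequest.to_proto` from the hunk just above and on the `InitRequest` routing inside `FromClient.to_proto`, the only branch wired up at this point in the series):

    from ydb._topic_wrapper.reader import StreamReadMessage

    # Flow-control message: to_proto() copies the single bytes_size field
    # into the generated protobuf message.
    read_request = StreamReadMessage.ReadRequest(bytes_size=1024)
    assert read_request.to_proto().bytes_size == 1024

    # The FromClient envelope picks the protobuf oneof branch based on the
    # type of the wrapped message; ReadRequest only gains its own branch in
    # a later commit of this series.
    init_request = StreamReadMessage.InitRequest(
        topics_read_settings=[
            StreamReadMessage.InitRequest.TopicReadSettings(path="/local/topic"),
        ],
        consumer="test-consumer",
    )
    envelope = StreamReadMessage.FromClient(client_message=init_request)
    assert envelope.to_proto().init_request.consumer == "test-consumer"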
@dataclass - class ReadResponse: + class ReadResponse(IFromProto): partition_data: List["StreamReadMessage.ReadResponse.PartitionData"] bytes_size: int + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse) -> "StreamReadMessage.ReadResponse": + partition_data = [] + for proto_partition_data in msg.partition_data: + partition_data.append(StreamReadMessage.ReadResponse.PartitionData.from_proto(proto_partition_data)) + return StreamReadMessage.ReadResponse( + partition_data=partition_data, + bytes_size=msg.bytes_size, + ) + @dataclass - class MessageData: + class MessageData(IFromProto): offset: int seq_no: int created_at: datetime.datetime @@ -73,19 +99,58 @@ class MessageData: uncompresed_size: int message_group_id: str + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData) ->\ + "StreamReadMessage.ReadResponse.MessageData": + return StreamReadMessage.ReadResponse.MessageData( + offset=msg.offset, + seq_no=msg.seq_no, + created_at=msg.created_at.ToDatetime(), + data=msg.data, + uncompresed_size=msg.uncompressed_size, + message_group_id=msg.message_group_id + ) + @dataclass - class Batch: + class Batch(IFromProto): message_data: List["StreamReadMessage.ReadResponse.MessageData"] producer_id: str write_session_meta: Dict[str, str] codec: int written_at: datetime.datetime + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch) -> \ + "StreamReadMessage.ReadResponse.Batch": + message_data = [] + for message in msg.message_data: + message_data.append(StreamReadMessage.ReadResponse.MessageData.from_proto(message)) + return StreamReadMessage.ReadResponse.Batch( + message_data=message_data, + producer_id=msg.producer_id, + write_session_meta=dict(msg.write_session_meta), + codec=msg.codec, + written_at=msg.written_at.ToDatetime(), + ) + + @dataclass - class PartitionData: + class PartitionData(IFromProto): partition_session_id: int batches: List["StreamReadMessage.ReadResponse.Batch"] + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData) ->\ + "StreamReadMessage.ReadResponse.PartitionData": + batches = [] + for proto_batch in msg.batches: + batches.append(StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch)) + return StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=msg.partition_session_id, + batches=batches, + ) + + @dataclass class CommitOffsetRequest: commit_offsets: List["PartitionCommitOffset"] @@ -116,17 +181,33 @@ class PartitionSessionStatusResponse: write_time_high_watermark: float @dataclass - class StartPartitionSessionRequest: + class StartPartitionSessionRequest(IFromProto): partition_session: "StreamReadMessage.PartitionSession" committed_offset: int partition_offsets: OffsetsRange + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest) -> \ + "StreamReadMessage.StartPartitionSessionRequest": + return StreamReadMessage.StartPartitionSessionRequest( + partition_session=StreamReadMessage.PartitionSession.from_proto(msg.partition_session), + committed_offset=msg.committed_offset, + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets) + ) + @dataclass - class StartPartitionSessionResponse: + class StartPartitionSessionResponse(IToProto): partition_session_id: int read_offset: int commit_offset: int + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: + res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() + 
res.partition_session_id = self.partition_session_id
+            res.read_offset = self.read_offset
+            res.commit_offset = self.commit_offset
+            return res
+
     @dataclass
     class StopPartitionSessionRequest:
         partition_session_id: int
@@ -159,7 +240,11 @@ class FromServer(IFromProto):
         @staticmethod
         def from_proto(msg: ydb_topic_pb2.StreamReadMessage.FromServer) -> "StreamReadMessage.FromServer":
             mess_type = msg.WhichOneof("server_message")
-            if mess_type == "init_response":
+            if mess_type == "read_response":
+                return StreamReadMessage.FromServer(
+                    server_message=StreamReadMessage.ReadResponse.from_proto(msg.init_response)
+                )
+            elif mess_type == "init_response":
                 return StreamReadMessage.FromServer(
                     server_message=StreamReadMessage.InitResponse.from_proto(msg.init_response),
                 )

From 89ebf8df14f0a675395d56f7d363eece0b26dd3e Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Fri, 3 Feb 2023 17:10:58 +0300
Subject: [PATCH 026/147] fix type order

---
 ydb/_topic_wrapper/common.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py
index 30f39cb2..9111bc2c 100644
--- a/ydb/_topic_wrapper/common.py
+++ b/ydb/_topic_wrapper/common.py
@@ -32,19 +32,6 @@ class Codec(IntEnum):
     CODEC_ZSTD = 4
 
 
-@dataclass
-class OffsetsRange(IFromProto):
-    start: int
-    end: int
-
-    @staticmethod
-    def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
-        return OffsetsRange(
-            start=msg.start,
-            end=msg.end,
-        )
-
-
 class IToProto(abc.ABC):
     @abc.abstractmethod
     def to_proto(self) -> Message:
@@ -62,6 +49,19 @@ def from_proto(msg: Message) -> typing.Any:
         pass
 
 
+@dataclass
+class OffsetsRange(IFromProto):
+    start: int
+    end: int
+
+    @staticmethod
+    def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
+        return OffsetsRange(
+            start=msg.start,
+            end=msg.end,
+        )
+
+
 class QueueToIteratorAsyncIO:
     __slots__ = ("_queue",)
 

From dfe7620e61a8b6c3e48ed13c0963d54cbec8a40e Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Fri, 3 Feb 2023 18:40:52 +0300
Subject: [PATCH 027/147] read message from real server

---
 tests/conftest.py                         | 28 ++++++++++++++++++++++-
 tests/topics/test_topic_reader.py         | 16 +++++++++++++
 ydb/_topic_reader/topic_reader_asyncio.py |  7 ++++++
 ydb/_topic_wrapper/common.py              |  9 +++++---
 ydb/_topic_wrapper/reader.py              | 12 ++++++++--
 ydb/topic.py                              |  2 ++
 6 files changed, 68 insertions(+), 6 deletions(-)
 create mode 100644 tests/topics/test_topic_reader.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 2681037b..8cf8ab0d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -100,11 +100,17 @@ async def driver(endpoint, database, event_loop):
 
 
 @pytest.fixture()
-def topic_path(endpoint) -> str:
+def topic_consumer():
+    return "fixture-consumer"
+
+
+@pytest.fixture()
+def topic_path(endpoint, topic_consumer) -> str:
     subprocess.run(
         """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic drop /local/test-topic"""
         % endpoint,
         shell=True,
+        capture_output=True,
     )
     res = subprocess.run(
         """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic create /local/test-topic"""
@@ -114,4 +120,24 @@ def topic_path(endpoint) -> str:
     )
     assert res.returncode == 0, res.stderr + res.stdout
 
+    res = subprocess.run(
+        """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic consumer add --consumer %s /local/test-topic"""
+        % (endpoint, topic_consumer),
+        shell=True,
+        capture_output=True,
+    )
+    assert res.returncode == 0, res.stderr + res.stdout
+
     return "/local/test-topic"
+
+
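The fixture added next seeds the topic through the asyncio writer introduced in the earlier writer patches. A condensed sketch of that write path (a sketch only, assuming an already-connected driver as in the fixtures; "example-producer" and the payload are placeholders):

    import ydb

    async def seed_topic(driver, topic_path: str):
        writer = driver.topic_client.topic_writer(
            topic_path, producer_and_message_group_id="example-producer"
        )
        # write_with_ack resolves only after the server acknowledges the write
        await writer.write_with_ack(ydb.TopicWriterMessage(data="hello".encode()))
        await writer.close()

+@pytest.fixture()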
+@pytest.mark.asyncio() +async def topic_with_messages(driver, topic_path): + pass + writer = driver.topic_client.topic_writer(topic_path, producer_and_message_group_id="fixture-producer-id") + await writer.write_with_ack( + ydb.TopicWriterMessage(data="123".encode()), + ydb.TopicWriterMessage(data="456".encode()), + ) + await writer.close() diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py new file mode 100644 index 00000000..d775db32 --- /dev/null +++ b/tests/topics/test_topic_reader.py @@ -0,0 +1,16 @@ +import pytest + +from ydb._topic_reader.topic_reader_asyncio import PublicAsyncIOReader +from ydb import TopicReaderSettings + + +@pytest.mark.asyncio +class TestTopicWriterAsyncIO: + async def test_read_message(self, driver, topic_path, topic_with_messages, topic_consumer): + reader = PublicAsyncIOReader(driver, TopicReaderSettings( + consumer=topic_consumer, + topic=topic_path, + )) + await reader.wait_messages() + + assert reader.receive_batch() is not None diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index b96d5637..0ad3ff2f 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -32,6 +32,12 @@ def __init__(self, driver: Driver, settings: PublicReaderSettings): self._loop = asyncio.get_running_loop() self._reconnector = ReaderReconnector(driver, settings) + async def wait_messages(self): + await self._reconnector.wait_message() + + def receive_batch(self): + return self._reconnector.receive_batch_nowait() + class ReaderReconnector: _settings: PublicReaderSettings @@ -58,6 +64,7 @@ async def wait_message(self): while True: if self._stream_reader is not None: await self._stream_reader.wait_messages() + return await self._state_changed.wait() self._state_changed.clear() diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index 9111bc2c..d45b629c 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -179,11 +179,14 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): async def receive(self) -> typing.Any: # todo handle grpc exceptions and convert it to internal exceptions - grpc_item = await self.from_server_grpc.__anext__() - return self.convert_server_grpc_to_wrapper(grpc_item) + grpc_message = await self.from_server_grpc.__anext__() + # print("rekby, grpc, received", grpc_message) + return self.convert_server_grpc_to_wrapper(grpc_message) def write(self, wrap_message: IToProto): - self.from_client_grpc.put_nowait(wrap_message.to_proto()) + grpc_message=wrap_message.to_proto() + # print("rekby, grpc, send", grpc_message) + self.from_client_grpc.put_nowait(grpc_message) @dataclass(init=False) diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index ee4de232..8d152efe 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -227,8 +227,12 @@ def __init__(self, client_message: "ReaderMessagesFromClientToServer"): def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: res = ydb_topic_pb2.StreamReadMessage.FromClient() - if isinstance(self.client_message, StreamReadMessage.InitRequest): + if isinstance(self.client_message, StreamReadMessage.ReadRequest): + res.read_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.InitRequest): res.init_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.StartPartitionSessionResponse): + 
res.start_partition_session_response.CopyFrom(self.client_message.to_proto()) else: raise NotImplementedError() return res @@ -242,12 +246,16 @@ def from_proto(msg: ydb_topic_pb2.StreamReadMessage.FromServer) -> "StreamReadMe mess_type = msg.WhichOneof("server_message") if mess_type == "read_response": return StreamReadMessage.FromServer( - server_message=StreamReadMessage.ReadResponse.from_proto(msg.init_response) + server_message=StreamReadMessage.ReadResponse.from_proto(msg.read_response) ) elif mess_type == "init_response": return StreamReadMessage.FromServer( server_message=StreamReadMessage.InitResponse.from_proto(msg.init_response), ) + elif mess_type == "start_partition_session_request": + return StreamReadMessage.FromServer( + server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto(msg.start_partition_session_request) + ) # todo replace exception to log raise NotImplementedError() diff --git a/ydb/topic.py b/ydb/topic.py index 6de3a847..165763c4 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -2,6 +2,7 @@ from . import aio, Credentials from ._topic_reader.topic_reader import ( + PublicReaderSettings as TopicReaderSettings, Reader as TopicReader, ReaderAsyncIO as TopicReaderAsyncIO, Selector as TopicSelector, @@ -18,6 +19,7 @@ RetryPolicy as TopicWriterRetryPolicy, ) + from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO From cae886a5c370e3f918ad22e49ce1005d0233e185 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Feb 2023 11:27:11 +0300 Subject: [PATCH 028/147] sync --- ydb/_topic_reader/topic_reader_asyncio.py | 7 ++++ .../topic_reader_asyncio_test.py | 35 ++++++++++++++---- ydb/_topic_wrapper/common.py | 10 ++++++ ydb/_topic_wrapper/common_test.py | 36 +++++++++++++++++++ 4 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 ydb/_topic_wrapper/common_test.py diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 0ad3ff2f..fca01987 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -6,6 +6,8 @@ from collections import deque from typing import Optional, Set, Dict +import grpc + from .. 
import _apis from ..aio import Driver from ..issues import Error as YdbError @@ -145,6 +147,9 @@ async def wait_messages(self): raise TopicReaderStreamClosedError() while len(self._message_batches) == 0: + if self._first_error is not None: + raise self._first_error + await self._state_changed.wait() self._state_changed.clear() @@ -180,6 +185,8 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): ) self._state_changed.set() + except grpc.RpcError as e: + except Exception as e: self._set_first_error(e) raise e diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 58e332f7..9ffbbe25 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -2,15 +2,17 @@ import datetime from unittest import mock +import grpc import pytest from ydb import aio -from ydb._topic_reader.datatypes import PublicBatch, PublicMessage -from ydb._topic_reader.topic_reader import PublicReaderSettings -from ydb._topic_reader.topic_reader_asyncio import ReaderStream, PartitionSession -from ydb._topic_wrapper.common import OffsetsRange, Codec -from ydb._topic_wrapper.reader import StreamReadMessage -from ydb._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast +from .datatypes import PublicBatch, PublicMessage +from .topic_reader import PublicReaderSettings +from .topic_reader_asyncio import ReaderStream, PartitionSession +from .._topic_wrapper.common import OffsetsRange, Codec +from .._topic_wrapper.reader import StreamReadMessage +from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast +from ..issues import Unavailable @pytest.fixture() @@ -167,6 +169,21 @@ def batch_count(): ))) await wait_condition(lambda: batch_count() > initial_batches) + async def test_convert_errors_to_ydb(self, stream, stream_reader): + class TestError(grpc.RpcError): + _code: grpc.StatusCode + + def __init__(self, code: grpc.StatusCode): + self._code = code + + def code(self): + return self._code + + stream.from_server.put_nowait(TestError(grpc.StatusCode.UNAVAILABLE)) + + with pytest.raises(Unavailable): + await wait_for_fast(stream_reader.wait_messages()) + async def test_init_reader(self, stream, default_reader_settings): reader = ReaderStream(default_reader_settings) init_message = StreamReadMessage.InitRequest( @@ -563,3 +580,9 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() + +@pytest.mark.asyncio +class TestReaderReconnector: + async def test_start(self): + pass + diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index d45b629c..9834dd50 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -238,3 +238,13 @@ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] + + +def callback_from_asyncio(callback: typing.Union[typing.Callable, typing.Coroutine]) -> [asyncio.Future, asyncio.Task]: + loop = asyncio.get_running_loop() + + if asyncio.iscoroutinefunction(callback): + return loop.create_task(callback()) + else: + return loop.run_in_executor(None, callback) + diff --git a/ydb/_topic_wrapper/common_test.py b/ydb/_topic_wrapper/common_test.py new file mode 100644 index 00000000..40543c76 --- /dev/null +++ b/ydb/_topic_wrapper/common_test.py @@ -0,0 +1,36 @@ +import asyncio + +import pytest + +from .common import 
callback_from_asyncio + + +@pytest.mark.asyncio +class Test: + async def test_callback_from_asyncio(self): + class TestError(Exception): + pass + + def sync_success(): + return 1 + + assert await callback_from_asyncio(sync_success) == 1 + + def sync_failed(): + raise TestError() + + with pytest.raises(TestError): + await callback_from_asyncio(sync_failed) + + async def async_success(): + await asyncio.sleep(0) + return 1 + + assert await callback_from_asyncio(async_success) == 1 + + async def async_failed(): + await asyncio.sleep(0) + raise TestError() + + with pytest.raises(TestError): + await callback_from_asyncio(async_failed) From 8a0bda11efa3b44ec8c482cc425274de5b224cc6 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Feb 2023 16:17:39 +0300 Subject: [PATCH 029/147] handle grpc errors and errors in server status --- ydb/_topic_reader/topic_reader_asyncio.py | 18 ++--- .../topic_reader_asyncio_test.py | 78 ++++++++++++++----- ydb/_topic_wrapper/common.py | 44 ++++++++--- ydb/_topic_wrapper/common_test.py | 58 +++++++++++++- ydb/_topic_wrapper/reader.py | 6 +- 5 files changed, 161 insertions(+), 43 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index fca01987..2abf70d0 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -8,7 +8,7 @@ import grpc -from .. import _apis +from .. import _apis, issues from ..aio import Driver from ..issues import Error as YdbError from .datatypes import PartitionSession, PublicMessage, PublicBatch @@ -143,19 +143,19 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess self._background_tasks.add(read_messages_task) async def wait_messages(self): - if self._closed: - raise TopicReaderStreamClosedError() - - while len(self._message_batches) == 0: + while True: if self._first_error is not None: raise self._first_error + if len(self._message_batches) > 0: + return + await self._state_changed.wait() self._state_changed.clear() def receive_batch_nowait(self): - if self._closed: - raise TopicReaderStreamClosedError() + if self._first_error is not None: + raise self._first_error try: batch = self._message_batches.popleft() @@ -185,8 +185,6 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): ) self._state_changed.set() - except grpc.RpcError as e: - except Exception as e: self._set_first_error(e) raise e @@ -296,7 +294,7 @@ async def close(self): raise TopicReaderError(message="Double closed ReaderStream") self._closed = True - self._set_first_error(TopicReaderError("Reader closed")) + self._set_first_error(TopicReaderStreamClosedError()) self._state_changed.set() for task in self._background_tasks: diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 9ffbbe25..a8a98dfa 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -5,15 +5,23 @@ import grpc import pytest +import ydb from ydb import aio from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings from .topic_reader_asyncio import ReaderStream, PartitionSession -from .._topic_wrapper.common import OffsetsRange, Codec +from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus from .._topic_wrapper.reader import StreamReadMessage from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast from ..issues import Unavailable +# Workaround for good 
autocomplete in IDE and universal import at runtime +# noinspection PyUnreachableCode +if False: + from .._grpc.v4.protos import ydb_status_codes_pb2 +else: + from .._grpc.common.protos import ydb_status_codes_pb2 + @pytest.fixture() def default_reader_settings(): @@ -58,7 +66,7 @@ def second_partition_session(self, default_reader_settings): ) @pytest.fixture() - async def stream_reader(self, stream, default_reader_settings, partition_session, + async def stream_reader_started(self, stream, default_reader_settings, partition_session, second_partition_session) -> ReaderStream: reader = ReaderStream(default_reader_settings) init_message = object() @@ -67,7 +75,8 @@ async def stream_reader(self, stream, default_reader_settings, partition_session start = asyncio.create_task(reader._start(stream, init_message)) stream.from_server.put_nowait(StreamReadMessage.FromServer( - StreamReadMessage.InitResponse(session_id="test-session") + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.InitResponse(session_id="test-session"), )) init_request = await wait_for_fast(stream.from_client.get()) @@ -77,7 +86,9 @@ async def stream_reader(self, stream, default_reader_settings, partition_session assert isinstance(read_request.client_message, StreamReadMessage.ReadRequest) stream.from_server.put_nowait( - StreamReadMessage.FromServer(server_message=StreamReadMessage.StartPartitionSessionRequest( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StartPartitionSessionRequest( partition_session=StreamReadMessage.PartitionSession( partition_session_id=partition_session.id, path=partition_session.topic_path, @@ -96,7 +107,9 @@ async def stream_reader(self, stream, default_reader_settings, partition_session assert isinstance(start_partition_resp.client_message, StreamReadMessage.StartPartitionSessionResponse) stream.from_server.put_nowait( - StreamReadMessage.FromServer(server_message=StreamReadMessage.StartPartitionSessionRequest( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StartPartitionSessionRequest( partition_session=StreamReadMessage.PartitionSession( partition_session_id=second_partition_session.id, path=second_partition_session.topic_path, @@ -116,11 +129,22 @@ async def stream_reader(self, stream, default_reader_settings, partition_session with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() - yield reader + return reader - assert reader._first_error is None + @pytest.fixture() + async def stream_reader(self, stream_reader_started: ReaderStream): + yield stream_reader_started + + assert stream_reader_started._first_error is None + await stream_reader_started.close() + + @pytest.fixture() + async def stream_reader_finish_with_error(self, stream_reader_started: ReaderStream): + yield stream_reader_started + + assert stream_reader_started._first_error is not None + await stream_reader_started.close() - await reader.close() @staticmethod def create_message(partition_session: PartitionSession, seqno: int): @@ -143,7 +167,9 @@ def batch_count(): initial_batches = batch_count() stream = stream_reader._stream # type: StreamMock - stream.from_server.put_nowait(StreamReadMessage.FromServer(server_message=StreamReadMessage.ReadResponse( + stream.from_server.put_nowait(StreamReadMessage.FromServer( + 
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.ReadResponse( partition_data=[StreamReadMessage.ReadResponse.PartitionData( partition_session_id=message._partition_session.id, batches=[ @@ -169,20 +195,25 @@ def batch_count(): ))) await wait_condition(lambda: batch_count() > initial_batches) - async def test_convert_errors_to_ydb(self, stream, stream_reader): - class TestError(grpc.RpcError): - _code: grpc.StatusCode - - def __init__(self, code: grpc.StatusCode): - self._code = code + async def test_first_error(self, stream, stream_reader_finish_with_error): + class TestError(grpc.RpcError, grpc.Call): + def __init__(self): + pass def code(self): - return self._code + return grpc.StatusCode.UNAUTHENTICATED - stream.from_server.put_nowait(TestError(grpc.StatusCode.UNAVAILABLE)) + def details(self): + return "test error" - with pytest.raises(Unavailable): - await wait_for_fast(stream_reader.wait_messages()) + test_err = TestError() + stream.from_server.put_nowait(test_err) + + with pytest.raises(TestError): + await wait_for_fast(stream_reader_finish_with_error.wait_messages()) + + with pytest.raises(TestError): + stream_reader_finish_with_error.receive_batch_nowait() async def test_init_reader(self, stream, default_reader_settings): reader = ReaderStream(default_reader_settings) @@ -202,6 +233,7 @@ async def test_init_reader(self, stream, default_reader_settings): assert sent_message == expected_sent_init_message stream.from_server.put_nowait(StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), server_message=StreamReadMessage.InitResponse(session_id="test")) ) @@ -231,6 +263,7 @@ def session_count(): test_topic_path = default_reader_settings.topic + "-asd" stream.from_server.put_nowait(StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), server_message=StreamReadMessage.StartPartitionSessionRequest( partition_session=StreamReadMessage.PartitionSession( partition_session_id=test_partition_session_id, @@ -266,6 +299,7 @@ def session_count(): initial_session_count = session_count() stream.from_server.put_nowait(StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), server_message=StreamReadMessage.StopPartitionSessionRequest( partition_session_id=partition_session.id, graceful=False, @@ -287,6 +321,7 @@ def session_count(): initial_session_count = session_count() stream.from_server.put_nowait(StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), server_message=StreamReadMessage.StopPartitionSessionRequest( partition_session_id=partition_session.id, graceful=True, @@ -303,6 +338,7 @@ def session_count(): ) stream.from_server.put_nowait(StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), server_message=StreamReadMessage.StopPartitionSessionRequest( partition_session_id=partition_session.id, graceful=False, @@ -330,7 +366,9 @@ def reader_batch_count(): session_meta = {"a": "b"} message_group_id = "test-message-group-id" - stream.from_server.put_nowait(StreamReadMessage.FromServer(server_message=StreamReadMessage.ReadResponse( + stream.from_server.put_nowait(StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.ReadResponse( bytes_size=bytes_size, partition_data=[ 
StreamReadMessage.ReadResponse.PartitionData( diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index 9834dd50..ae148cb9 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -4,11 +4,15 @@ from dataclasses import dataclass from enum import IntEnum +import grpc from google.protobuf.message import Message import ydb.aio +from .. import issues, connection + # Workaround for good autocomplete in IDE and universal import at runtime +# noinspection PyUnreachableCode if False: from ydb._grpc.v4.protos import ( ydb_status_codes_pb2, @@ -147,16 +151,19 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): from_client_grpc: asyncio.Queue from_server_grpc: typing.AsyncIterator convert_server_grpc_to_wrapper: typing.Callable[[typing.Any], typing.Any] + _connection_state: str def __init__(self, convert_server_grpc_to_wrapper): self.from_client_grpc = asyncio.Queue() self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper + self._connection_state = "new" async def start(self, driver: SupportedDriverType, stub, method): if asyncio.iscoroutinefunction(driver.__call__): await self._start_asyncio_driver(driver, stub, method) else: await self._start_sync_driver(driver, stub, method) + self._connection_state = "started" async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) @@ -179,37 +186,49 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): async def receive(self) -> typing.Any: # todo handle grpc exceptions and convert it to internal exceptions - grpc_message = await self.from_server_grpc.__anext__() + try: + grpc_message = await self.from_server_grpc.__anext__() + except grpc.RpcError as e: + raise connection._rpc_error_handler(self._connection_state, e) + + issues._process_response(grpc_message) + + if self._connection_state != "has_received_messages": + self._connection_state = "has_received_messages" + # print("rekby, grpc, received", grpc_message) return self.convert_server_grpc_to_wrapper(grpc_message) def write(self, wrap_message: IToProto): - grpc_message=wrap_message.to_proto() + grpc_message = wrap_message.to_proto() # print("rekby, grpc, send", grpc_message) self.from_client_grpc.put_nowait(grpc_message) @dataclass(init=False) class ServerStatus(IFromProto): - __slots__ = ("status", "_issues") + __slots__ = ("_grpc_status_code", "_issues") def __init__( - self, - status: ydb_status_codes_pb2.StatusIds.StatusCode, - issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage], + self, + status_code: ydb_status_codes_pb2.StatusIds.StatusCode, + grpc_issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage], ): - self.status = status - self._issues = issues + self._grpc_status_code = status_code + self._issues = grpc_issues def __str__(self): return self.__repr__() @staticmethod - def from_proto(msg: Message) -> "ServerStatus": - return ServerStatus(msg.status) + def from_proto(msg: typing.Union[ + ydb_topic_pb2.StreamReadMessage.FromServer, + ydb_topic_pb2.StreamWriteMessage.FromServer, + ]) -> "ServerStatus": + return ServerStatus(msg.status, msg.issues) def is_success(self) -> bool: - return self.status == ydb_status_codes_pb2.StatusIds.SUCCESS + return self._grpc_status_code == ydb_status_codes_pb2.StatusIds.SUCCESS @classmethod def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): @@ -248,3 +267,6 @@ def callback_from_asyncio(callback: typing.Union[typing.Callable, typing.Corouti else: return 
loop.run_in_executor(None, callback) + +def ensure_success_or_raise_error(server_status: ServerStatus): + error = issues._process_response(server_status._grpc_status_code, server_status._issues) diff --git a/ydb/_topic_wrapper/common_test.py b/ydb/_topic_wrapper/common_test.py index 40543c76..b176e314 100644 --- a/ydb/_topic_wrapper/common_test.py +++ b/ydb/_topic_wrapper/common_test.py @@ -1,9 +1,26 @@ import asyncio +import grpc import pytest -from .common import callback_from_asyncio +from .common import callback_from_asyncio, GrpcWrapperAsyncIO +from .. import issues +# Workaround for good autocomplete in IDE and universal import at runtime +# noinspection PyUnreachableCode +if False: + from ydb._grpc.v4.protos import ( + ydb_status_codes_pb2, + ydb_issue_message_pb2, + ydb_topic_pb2, + ) +else: + # noinspection PyUnresolvedReferences + from ydb._grpc.common.protos import ( + ydb_status_codes_pb2, + ydb_issue_message_pb2, + ydb_topic_pb2, + ) @pytest.mark.asyncio class Test: @@ -34,3 +51,42 @@ async def async_failed(): with pytest.raises(TestError): await callback_from_asyncio(async_failed) + + +@pytest.mark.asyncio +class TestGrpcWrapperAsyncIO: + async def test_convert_grpc_errors_to_ydb(self): + class TestError(grpc.RpcError, grpc.Call): + def __init__(self): + pass + + def code(self): + return grpc.StatusCode.UNAUTHENTICATED + + def details(self): + return "test error" + + class FromServerMock: + async def __anext__(self): + raise TestError() + + wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper.from_server_grpc = FromServerMock() + + with pytest.raises(issues.Unauthenticated): + await wrapper.receive() + + async def convert_status_code_to_ydb_error(self): + class FromServerMock: + async def __anext__(self): + return ydb_topic_pb2.StreamReadMessage.FromServer( + status=ydb_status_codes_pb2.StatusIds.OVERLOADED, + issues=[], + ) + + wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper.from_server_grpc = FromServerMock() + + with pytest.raises(issues.Overloaded): + await wrapper.receive() + diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index 8d152efe..85538b8b 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -5,10 +5,12 @@ from google.protobuf.message import Message -from ydb._topic_wrapper.common import OffsetsRange, IToProto, UpdateTokenRequest, UpdateTokenResponse, IFromProto +from ydb._topic_wrapper.common import OffsetsRange, IToProto, UpdateTokenRequest, UpdateTokenResponse, IFromProto, \ + ServerStatus from google.protobuf.duration_pb2 import Duration as ProtoDuration # Workaround for good autocomplete in IDE and universal import at runtime +# noinspection PyUnreachableCode if False: from ydb._grpc.v4.protos import ydb_topic_pb2 else: @@ -240,10 +242,12 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: @dataclass class FromServer(IFromProto): server_message: "ReaderMessagesFromServerToClient" + server_status: ServerStatus @staticmethod def from_proto(msg: ydb_topic_pb2.StreamReadMessage.FromServer) -> "StreamReadMessage.FromServer": mess_type = msg.WhichOneof("server_message") + server_status = ServerStatus.from_proto(ms) if mess_type == "read_response": return StreamReadMessage.FromServer( server_message=StreamReadMessage.ReadResponse.from_proto(msg.read_response) From 0cff44650b15f5bcb57ad1a817972ff86ed1317c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Feb 2023 19:11:05 +0300 Subject: [PATCH 030/147] sync --- ydb/_topic_reader/topic_reader.py | 4 ++ 
ydb/_topic_reader/topic_reader_asyncio.py | 70 +++++++++++++++---- .../topic_reader_asyncio_test.py | 52 +++++++++----- ydb/_topic_wrapper/common.py | 10 +-- ydb/_topic_wrapper/common_test.py | 21 +++++- ydb/_topic_wrapper/reader.py | 7 +- ydb/issues.py | 2 + 7 files changed, 127 insertions(+), 39 deletions(-) diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 8d818217..41a17add 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -16,6 +16,7 @@ Any, Dict, ) +from ydb import RetrySettings from ydb._topic_wrapper.common import OffsetsRange, TokenGetterFuncType from ydb._topic_wrapper.reader import StreamReadMessage @@ -281,6 +282,9 @@ def _init_message(self) -> StreamReadMessage.InitRequest: consumer=self.consumer, ) + def _retry_settings(self)->RetrySettings: + return RetrySettings(idempotent=True) + class Events: class OnCommit: diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 2abf70d0..6421625e 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -8,13 +8,18 @@ import grpc -from .. import _apis, issues +import ydb +from .. import _apis, issues, RetrySettings from ..aio import Driver -from ..issues import Error as YdbError +from ..issues import ( + Error as YdbError, + _process_response +) from .datatypes import PartitionSession, PublicMessage, PublicBatch from .topic_reader import PublicReaderSettings from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO, SupportedDriverType, GrpcWrapperAsyncIO from .._topic_wrapper.reader import StreamReadMessage +from .._errors import check_retriable_error class TopicReaderError(YdbError): @@ -48,22 +53,42 @@ class ReaderReconnector: _state_changed: asyncio.Event _stream_reader: Optional["ReaderStream"] + _first_error: asyncio.Future[ydb.Error] def __init__(self, driver: Driver, settings: PublicReaderSettings): self._settings = settings self._driver = driver self._background_tasks = set() + self._retry_settins = RetrySettings(idempotent=True) # get from settings self._state_changed = asyncio.Event() self._stream_reader = None - self._background_tasks.add(asyncio.create_task(self.start())) + self._background_tasks.add(asyncio.create_task(self._connection_loop())) + self._first_error = asyncio.get_running_loop().create_future() - async def start(self): - self._stream_reader = await ReaderStream.create(self._driver, self._settings) - self._state_changed.set() + async def _connection_loop(self): + attempt = 0 + while True: + try: + self._stream_reader = await ReaderStream.create(self._driver, self._settings) + self._state_changed.set() + self._stream_reader._state_changed.wait() + except Exception as err: + # todo reset attempts when connection established + + retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt) + if not retry_info.is_retriable: + self._set_first_error(err) + return + await asyncio.sleep(retry_info.sleep_timeout_seconds) + + attempt += 1 async def wait_message(self): while True: + if self._first_error.done(): + raise self._first_error.result() + if self._stream_reader is not None: await self._stream_reader.wait_messages() return @@ -81,6 +106,13 @@ async def close(self): await asyncio.wait(self._background_tasks) + def _set_first_error(self, err: issues.Error): + try: + self._first_error.set_result(err) + self._state_changed.set() + except asyncio.InvalidStateError: + # skip if already has result + pass 
class ReaderStream: _token_getter: Optional[TokenGetterFuncType] @@ -93,8 +125,8 @@ class ReaderStream: _state_changed: asyncio.Event _closed: bool - _first_error: Optional[YdbError] _message_batches: typing.Deque[PublicBatch] + first_error: asyncio.Future[YdbError] def __init__(self, settings: PublicReaderSettings): self._token_getter = settings._token_getter @@ -107,7 +139,7 @@ def __init__(self, settings: PublicReaderSettings): self._state_changed = asyncio.Event() self._closed = False - self._first_error = None + self.first_error = asyncio.get_running_loop().create_future() self._message_batches = deque() @staticmethod @@ -144,8 +176,8 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess async def wait_messages(self): while True: - if self._first_error is not None: - raise self._first_error + if self._get_first_error() is not None: + raise self._get_first_error() if len(self._message_batches) > 0: return @@ -154,8 +186,8 @@ async def wait_messages(self): self._state_changed.clear() def receive_batch_nowait(self): - if self._first_error is not None: - raise self._first_error + if self._get_first_error() is not None: + raise self._get_first_error() try: batch = self._message_batches.popleft() @@ -173,6 +205,7 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): )) while True: message = await stream.receive() # type: StreamReadMessage.FromServer + _process_response(message.server_status) if isinstance(message.server_message, StreamReadMessage.ReadResponse): self._on_read_response(message.server_message) elif isinstance(message.server_message, StreamReadMessage.StartPartitionSessionRequest): @@ -285,9 +318,18 @@ def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> return batches def _set_first_error(self, err): - if self._first_error is None: - self._first_error = err + try: + self.first_error.set_result(err) self._state_changed.set() + except asyncio.InvalidStateError: + # skip later set errors + pass + + def _get_first_error(self): + if self.first_error.done(): + return self.first_error.result() + else: + return None async def close(self): if self._closed: diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index a8a98dfa..ba73f322 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -6,11 +6,11 @@ import pytest import ydb -from ydb import aio +from ydb import aio, issues from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings -from .topic_reader_asyncio import ReaderStream, PartitionSession -from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus +from .topic_reader_asyncio import ReaderStream, PartitionSession, ReaderReconnector +from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus, UpdateTokenResponse from .._topic_wrapper.reader import StreamReadMessage from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast from ..issues import Unavailable @@ -135,14 +135,14 @@ async def stream_reader_started(self, stream, default_reader_settings, partition async def stream_reader(self, stream_reader_started: ReaderStream): yield stream_reader_started - assert stream_reader_started._first_error is None + assert stream_reader_started._get_first_error() is None await stream_reader_started.close() @pytest.fixture() async def stream_reader_finish_with_error(self, stream_reader_started: ReaderStream): yield 
stream_reader_started - assert stream_reader_started._first_error is not None + assert stream_reader_started._get_first_error() is not None await stream_reader_started.close() @@ -195,16 +195,9 @@ def batch_count(): ))) await wait_condition(lambda: batch_count() > initial_batches) - async def test_first_error(self, stream, stream_reader_finish_with_error): - class TestError(grpc.RpcError, grpc.Call): - def __init__(self): - pass - - def code(self): - return grpc.StatusCode.UNAUTHENTICATED - - def details(self): - return "test error" + async def test_unknown_error(self, stream, stream_reader_finish_with_error): + class TestError(Exception): + pass test_err = TestError() stream.from_server.put_nowait(test_err) @@ -215,6 +208,24 @@ def details(self): with pytest.raises(TestError): stream_reader_finish_with_error.receive_batch_nowait() + async def test_error_from_status_code(self, stream, stream_reader_finish_with_error): + # noinspection PyTypeChecker + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus( + status=issues.StatusCode.OVERLOADED, + issues=[], + ), + server_message=None, + ) + ) + + with pytest.raises(issues.Overloaded): + await wait_for_fast(stream_reader_finish_with_error.wait_messages()) + + with pytest.raises(issues.Overloaded): + stream_reader_finish_with_error.receive_batch_nowait() + async def test_init_reader(self, stream, default_reader_settings): reader = ReaderStream(default_reader_settings) init_message = StreamReadMessage.InitRequest( @@ -619,8 +630,15 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() + @pytest.mark.asyncio class TestReaderReconnector: - async def test_start(self): - pass + async def test_reconnect_on_repeatable_error(self, monkeypatch): + def stream_create(): + pass + + with mock.patch.object(ReaderStream, "create", stream_create): + reconnector = ReaderReconnector(None, PublicReaderSettings("", "")) + await reconnector.wait_message() + raise NotImplementedError() diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index ae148cb9..2955461a 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -211,11 +211,11 @@ class ServerStatus(IFromProto): def __init__( self, - status_code: ydb_status_codes_pb2.StatusIds.StatusCode, - grpc_issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage], + status: issues.StatusCode, + issues: typing.Iterable[typing.Any], ): - self._grpc_status_code = status_code - self._issues = grpc_issues + self.status = status + self.issues = issues def __str__(self): return self.__repr__() @@ -228,7 +228,7 @@ def from_proto(msg: typing.Union[ return ServerStatus(msg.status, msg.issues) def is_success(self) -> bool: - return self._grpc_status_code == ydb_status_codes_pb2.StatusIds.SUCCESS + return self.status == issues.StatusCode.SUCCESS @classmethod def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): diff --git a/ydb/_topic_wrapper/common_test.py b/ydb/_topic_wrapper/common_test.py index b176e314..f7c7493e 100644 --- a/ydb/_topic_wrapper/common_test.py +++ b/ydb/_topic_wrapper/common_test.py @@ -3,7 +3,7 @@ import grpc import pytest -from .common import callback_from_asyncio, GrpcWrapperAsyncIO +from .common import callback_from_asyncio, GrpcWrapperAsyncIO, ServerStatus from .. 
import issues # Workaround for good autocomplete in IDE and universal import at runtime @@ -90,3 +90,22 @@ async def __anext__(self): with pytest.raises(issues.Overloaded): await wrapper.receive() + +class TestServerStatus: + def test_success(self): + status = ServerStatus( + status=ydb_status_codes_pb2.StatusIds.SUCCESS, + issues=[], + ) + assert status.is_success() + assert issues._process_response(status) is None + + def test_failed(self): + status = ServerStatus( + status=ydb_status_codes_pb2.StatusIds.OVERLOADED, + issues=[], + ) + assert not status.is_success() + with pytest.raises(issues.Overloaded): + issues._process_response(status) + diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index 85538b8b..baa023c9 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -247,17 +247,20 @@ class FromServer(IFromProto): @staticmethod def from_proto(msg: ydb_topic_pb2.StreamReadMessage.FromServer) -> "StreamReadMessage.FromServer": mess_type = msg.WhichOneof("server_message") - server_status = ServerStatus.from_proto(ms) + server_status = ServerStatus.from_proto(msg) if mess_type == "read_response": return StreamReadMessage.FromServer( - server_message=StreamReadMessage.ReadResponse.from_proto(msg.read_response) + server_status=server_status, + server_message=StreamReadMessage.ReadResponse.from_proto(msg.read_response), ) elif mess_type == "init_response": return StreamReadMessage.FromServer( + server_status=server_status, server_message=StreamReadMessage.InitResponse.from_proto(msg.init_response), ) elif mess_type == "start_partition_session_request": return StreamReadMessage.FromServer( + server_status=server_status, server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto(msg.start_partition_session_request) ) diff --git a/ydb/issues.py b/ydb/issues.py index 727aff1b..6df634ea 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import abc + from google.protobuf import text_format import enum from six.moves import queue From dde12dbc3e84abdfa004157320dbce8e4e57e5dc Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Feb 2023 19:49:48 +0300 Subject: [PATCH 031/147] fix beta release --- .github/workflows/python-publish.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index f3762395..0d0ee5c7 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -103,9 +103,9 @@ jobs: Full Changelog: [$LAST_TAG...$TAG](https://github.com/ydb-platform/ydb-go-sdk/compare/$LAST_TAG...$TAG)" if [ "$WITH_BETA" = true ] then - gh release create -d $TAG -t "$TAG" --notes "$CHANGELOG" + gh release create --prerelease $TAG --title "$TAG" --notes "$CHANGELOG" else - gh release create $TAG -t "$TAG" --notes "$CHANGELOG" + gh release create $TAG --title "$TAG" --notes "$CHANGELOG" fi; - name: Publish package From 01dd5bf305d2f98c3bba21ca7830f541997c95c4 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Feb 2023 15:21:15 +0300 Subject: [PATCH 032/147] sync --- ydb/_topic_reader/topic_reader_asyncio.py | 36 ++++++++------- .../topic_reader_asyncio_test.py | 44 ++++++++++++++++--- 2 files changed, 59 insertions(+), 21 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 6421625e..5a9f3dc5 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -6,9 +6,7 @@ 
from collections import deque from typing import Optional, Set, Dict -import grpc -import ydb from .. import _apis, issues, RetrySettings from ..aio import Driver from ..issues import ( @@ -53,7 +51,7 @@ class ReaderReconnector: _state_changed: asyncio.Event _stream_reader: Optional["ReaderStream"] - _first_error: asyncio.Future[ydb.Error] + _first_error: asyncio.Future[YdbError] def __init__(self, driver: Driver, settings: PublicReaderSettings): self._settings = settings @@ -71,11 +69,10 @@ async def _connection_loop(self): while True: try: self._stream_reader = await ReaderStream.create(self._driver, self._settings) + attempt = 0 self._state_changed.set() - self._stream_reader._state_changed.wait() - except Exception as err: - # todo reset attempts when connection established - + await self._stream_reader.wait_error() + except issues.Error as err: retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt) if not retry_info.is_retriable: self._set_first_error(err) @@ -90,8 +87,11 @@ async def wait_message(self): raise self._first_error.result() if self._stream_reader is not None: - await self._stream_reader.wait_messages() - return + try: + await self._stream_reader.wait_messages() + return + except YdbError: + pass # handle errors in reconnection loop await self._state_changed.wait() self._state_changed.clear() @@ -114,6 +114,7 @@ def _set_first_error(self, err: issues.Error): # skip if already has result pass + class ReaderStream: _token_getter: Optional[TokenGetterFuncType] _session_id: str @@ -126,7 +127,7 @@ class ReaderStream: _state_changed: asyncio.Event _closed: bool _message_batches: typing.Deque[PublicBatch] - first_error: asyncio.Future[YdbError] + _first_error: asyncio.Future[YdbError] def __init__(self, settings: PublicReaderSettings): self._token_getter = settings._token_getter @@ -139,7 +140,7 @@ def __init__(self, settings: PublicReaderSettings): self._state_changed = asyncio.Event() self._closed = False - self.first_error = asyncio.get_running_loop().create_future() + self._first_error = asyncio.get_running_loop().create_future() self._message_batches = deque() @staticmethod @@ -174,6 +175,9 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess read_messages_task = asyncio.create_task(self._read_messages_loop(stream)) self._background_tasks.add(read_messages_task) + async def wait_error(self): + raise await self._first_error + async def wait_messages(self): while True: if self._get_first_error() is not None: @@ -317,17 +321,17 @@ def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> batches[-1]._bytes_size += additional_bytes_to_last_batch return batches - def _set_first_error(self, err): + def _set_first_error(self, err: ydb.Error): try: - self.first_error.set_result(err) + self._first_error.set_result(err) self._state_changed.set() except asyncio.InvalidStateError: # skip later set errors pass - def _get_first_error(self): - if self.first_error.done(): - return self.first_error.result() + def _get_first_error(self) -> Optional[ydb.Error]: + if self._first_error.done(): + return self._first_error.result() else: return None diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index ba73f322..4bc78b7a 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -10,7 +10,7 @@ from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings from 
.topic_reader_asyncio import ReaderStream, PartitionSession, ReaderReconnector -from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus, UpdateTokenResponse +from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus, UpdateTokenResponse, SupportedDriverType from .._topic_wrapper.reader import StreamReadMessage from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast from ..issues import Unavailable @@ -23,6 +23,13 @@ from .._grpc.common.protos import ydb_status_codes_pb2 +@pytest.fixture(autouse=True) +def handle_exceptions(event_loop): + def handler(loop, context): + print(context) + event_loop.set_exception_handler(handler) + + @pytest.fixture() def default_reader_settings(): return PublicReaderSettings( @@ -634,11 +641,38 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi @pytest.mark.asyncio class TestReaderReconnector: async def test_reconnect_on_repeatable_error(self, monkeypatch): - def stream_create(): - pass + test_error = issues.Overloaded("test error") + + async def wait_error(): + raise test_error + + reader_stream_mock_with_error = mock.Mock(ReaderStream) + reader_stream_mock_with_error.wait_error = mock.AsyncMock(side_effect=wait_error) + + async def wait_messages(): + raise test_error + + reader_stream_mock_with_error.wait_messages = mock.AsyncMock(side_effect=wait_messages) + + reader_stream_with_messages = mock.Mock(ReaderStream) + reader_stream_with_messages.wait_error.return_value = asyncio.Future() + reader_stream_with_messages.wait_messages.return_value = None + + stream_index = 0 + + async def stream_create(driver: SupportedDriverType, settings: PublicReaderSettings,): + nonlocal stream_index + stream_index += 1 + if stream_index == 1: + return reader_stream_mock_with_error + elif stream_index == 2: + return reader_stream_with_messages + else: + raise Exception("unexpected create stream") with mock.patch.object(ReaderStream, "create", stream_create): - reconnector = ReaderReconnector(None, PublicReaderSettings("", "")) + reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", "")) await reconnector.wait_message() - raise NotImplementedError() + reader_stream_mock_with_error.wait_error.assert_any_await() + reader_stream_mock_with_error.wait_messages.assert_any_await() From 538d6c7d789cd096118fde92aa40b012b4054fdf Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Feb 2023 15:32:37 +0300 Subject: [PATCH 033/147] sync with declared public api --- tests/topics/test_topic_reader.py | 3 +- ydb/_topic_reader/topic_reader_asyncio.py | 97 ++++++++++++++++++++++- 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index d775db32..d2c94c71 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -11,6 +11,5 @@ async def test_read_message(self, driver, topic_path, topic_with_messages, topic consumer=topic_consumer, topic=topic_path, )) - await reader.wait_messages() - assert reader.receive_batch() is not None + assert await reader.receive_batch() is not None diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 5a9f3dc5..8250df22 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -37,12 +37,103 @@ def __init__(self, driver: Driver, settings: PublicReaderSettings): self._loop = asyncio.get_running_loop() self._reconnector = ReaderReconnector(driver, settings) 
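
The public asyncio reader surface declared below is driven like this in practice; a minimal consumer sketch (assuming an already-connected ydb.aio.Driver as in the test fixtures, with topic_reader exposed on driver.topic_client the way the following patches wire it up):

    async def read_one_batch(driver, consumer: str, topic: str):
        reader = driver.topic_client.topic_reader(consumer, topic)
        batch = await reader.receive_batch()  # waits until a batch is available
        await reader.close()
        return batch
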
- async def wait_messages(self): + async def __aenter__(self): + raise NotImplementedError() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError() + + async def sessions_stat(self) -> typing.List["SessionStat"]: + """ + Receive stat from the server + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + def messages( + self, *, timeout: typing.Union[float, None] = None + ) -> typing.AsyncIterable["PublicMessage"]: + """ + Block until receive new message + + if no new messages in timeout seconds: stop iteration by raise StopAsyncIteration + """ + raise NotImplementedError() + + async def receive_message(self) -> typing.Union["PublicMessage", None]: + """ + Block until receive new message + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + def batches( + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: typing.Union[float, None] = None, + ) -> typing.AsyncIterable["PublicBatch"]: + """ + Block until receive new batch. + All messages in a batch from same partition. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + """ + raise NotImplementedError() + + async def receive_batch( + self, *, max_messages: typing.Union[int, None] = None, max_bytes: typing.Union[int, None] = None + ) -> typing.Union["PublicBatch", None]: + """ + Get one messages batch from reader. + All messages in a batch from same partition. + + use asyncio.wait_for for wait with timeout. + """ await self._reconnector.wait_message() - - def receive_batch(self): return self._reconnector.receive_batch_nowait() + async def commit_on_exit(self, mess: "ICommittable") -> typing.AsyncContextManager: + """ + commit the mess match/message if exit from context manager without exceptions + + reader will close if exit from context manager with exception + """ + raise NotImplementedError() + + def commit(self, mess: "ICommittable"): + """ + Write commit message to a buffer. + + For the method no way check the commit result + (for example if lost connection - commits will not re-send and committed messages will receive again) + """ + raise NotImplementedError() + + async def commit_with_ack( + self, mess: "ICommittable" + ) -> typing.Union["CommitResult", typing.List["CommitResult"]]: + """ + write commit message to a buffer and wait ack from the server. + + use asyncio.wait_for for wait with timeout. + """ + raise NotImplementedError() + + async def flush(self): + """ + force send all commit messages from internal buffers to server and wait acks for all of them. + + use asyncio.wait_for for wait with timeout. 
+ """ + raise NotImplementedError() + + async def close(self): + raise NotImplementedError() + class ReaderReconnector: _settings: PublicReaderSettings From 62b8033da707ec243dd4725c301c18b0ea9242bd Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Feb 2023 15:50:45 +0300 Subject: [PATCH 034/147] create topic reader from topic client --- tests/topics/test_topic_reader.py | 5 +- ydb/_topic_reader/topic_reader.py | 98 ----------------------- ydb/_topic_reader/topic_reader_asyncio.py | 19 ++++- ydb/topic.py | 42 +++++----- 4 files changed, 41 insertions(+), 123 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index d2c94c71..e2f9f725 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -7,9 +7,6 @@ @pytest.mark.asyncio class TestTopicWriterAsyncIO: async def test_read_message(self, driver, topic_path, topic_with_messages, topic_consumer): - reader = PublicAsyncIOReader(driver, TopicReaderSettings( - consumer=topic_consumer, - topic=topic_path, - )) + reader = driver.topic_client.topic_reader(topic_consumer, topic_path) assert await reader.receive_batch() is not None diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 41a17add..7995acc3 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -32,104 +32,6 @@ def __init__(self, path, *, partitions: Union[None, int, List[int]] = None): self.partitions = partitions -class ReaderAsyncIO(object): - async def __aenter__(self): - raise NotImplementedError() - - async def __aexit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError() - - async def sessions_stat(self) -> List["SessionStat"]: - """ - Receive stat from the server - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - - def messages( - self, *, timeout: Union[float, None] = None - ) -> AsyncIterable["PublicMessage"]: - """ - Block until receive new message - - if no new messages in timeout seconds: stop iteration by raise StopAsyncIteration - """ - raise NotImplementedError() - - async def receive_message(self) -> Union["PublicMessage", None]: - """ - Block until receive new message - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - - def batches( - self, - *, - max_messages: Union[int, None] = None, - max_bytes: Union[int, None] = None, - timeout: Union[float, None] = None, - ) -> AsyncIterable["PublicBatch"]: - """ - Block until receive new batch. - All messages in a batch from same partition. - - if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - """ - raise NotImplementedError() - - async def receive_batch( - self, *, max_messages: Union[int, None] = None, max_bytes: Union[int, None] - ) -> Union["PublicBatch", None]: - """ - Get one messages batch from reader. - All messages in a batch from same partition. - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - - async def commit_on_exit(self, mess: "ICommittable") -> AsyncContextManager: - """ - commit the mess match/message if exit from context manager without exceptions - - reader will close if exit from context manager with exception - """ - raise NotImplementedError() - - def commit(self, mess: "ICommittable"): - """ - Write commit message to a buffer. 
- - For the method no way check the commit result - (for example if lost connection - commits will not re-send and committed messages will receive again) - """ - raise NotImplementedError() - - async def commit_with_ack( - self, mess: "ICommittable" - ) -> Union["CommitResult", List["CommitResult"]]: - """ - write commit message to a buffer and wait ack from the server. - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - - async def flush(self): - """ - force send all commit messages from internal buffers to server and wait acks for all of them. - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - - async def close(self): - raise NotImplementedError() - - class Reader(object): def async_sessions_stat(self) -> concurrent.futures.Future: """ diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 8250df22..b9e53960 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -26,15 +26,22 @@ class TopicReaderError(YdbError): class TopicReaderStreamClosedError(TopicReaderError): def __init__(self): - super().__init__("Topic reader is closed") + super().__init__("Topic reader stream is closed") + + +class TopicReaderClosedError(TopicReaderError): + def __init__(self): + super().__init__("Topic reader is closed already") class PublicAsyncIOReader: _loop: asyncio.AbstractEventLoop + _closed: bool _reconnector: ReaderReconnector def __init__(self, driver: Driver, settings: PublicReaderSettings): self._loop = asyncio.get_running_loop() + self._closed = False self._reconnector = ReaderReconnector(driver, settings) async def __aenter__(self): @@ -43,6 +50,10 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): raise NotImplementedError() + def __del__(self): + if not self._closed: + self._loop.create_task(self.close(), name="close reader") + async def sessions_stat(self) -> typing.List["SessionStat"]: """ Receive stat from the server @@ -132,7 +143,11 @@ async def flush(self): raise NotImplementedError() async def close(self): - raise NotImplementedError() + if self._closed: + raise TopicReaderClosedError() + + self._closed = True + await self._reconnector.close() class ReaderReconnector: diff --git a/ydb/topic.py b/ydb/topic.py index 165763c4..880424ea 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,16 +1,19 @@ -from typing import List, Callable, Union, Mapping, Any +from typing import List, Callable, Union, Mapping, Any, Optional from . import aio, Credentials from ._topic_reader.topic_reader import ( PublicReaderSettings as TopicReaderSettings, Reader as TopicReader, - ReaderAsyncIO as TopicReaderAsyncIO, Selector as TopicSelector, Events as TopicReaderEvents, RetryPolicy as TopicReaderRetryPolicy, StubEvent as TopicReaderStubEvent, ) +from ._topic_reader.topic_reader_asyncio import ( + PublicAsyncIOReader as TopicReaderAsyncIO +) +from ._topic_wrapper.common import TokenGetterFuncType from ._topic_writer.topic_writer import ( # noqa: F401 Writer as TopicWriter, @@ -32,26 +35,27 @@ def __init__(self, driver: aio.Driver, settings: "TopicClientSettings" = None): def topic_reader( self, - topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]], consumer: str, - commit_batch_time: Union[float, None] = 0.1, - commit_batch_count: Union[int, None] = 1000, + topic: str, buffer_size_bytes: int = 50 * 1024 * 1024, - sync_commit: bool = False, # reader.commit(...) 
will wait commit ack from server - on_commit: Callable[["TopicReaderEvents.OnCommit"], None] = None, - on_get_partition_start_offset: Callable[ - ["TopicReaderEvents.OnPartitionGetStartOffsetRequest"], - "TopicReaderEvents.OnPartitionGetStartOffsetResponse", - ] = None, - on_init_partition: Callable[["TopicReaderStubEvent"], None] = None, - on_shutdown_partition: Callable[["TopicReaderStubEvent"], None] = None, - decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, - deserializer: Union[Callable[[bytes], Any], None] = None, - one_attempt_connection_timeout: Union[float, None] = 1, - connection_timeout: Union[float, None] = None, - retry_policy: Union["TopicReaderRetryPolicy", None] = None, + # on_commit: Callable[["Events.OnCommit"], None] = None + # on_get_partition_start_offset: Callable[ + # ["Events.OnPartitionGetStartOffsetRequest"], + # "Events.OnPartitionGetStartOffsetResponse", + # ] = None + # on_partition_session_start: Callable[["StubEvent"], None] = None + # on_partition_session_stop: Callable[["StubEvent"], None] = None + # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? + # decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None + # deserializer: Union[Callable[[bytes], Any], None] = None + # one_attempt_connection_timeout: Union[float, None] = 1 + # connection_timeout: Union[float, None] = None + # retry_policy: Union["RetryPolicy", None] = None ) -> TopicReaderAsyncIO: - raise NotImplementedError() + args = locals() + del args["self"] + settings = TopicReaderSettings(**args) + return TopicReaderAsyncIO(self._driver, settings) def topic_writer( self, From abac878f40e68642d04c15f62c895b22c5457283 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Feb 2023 18:23:42 +0300 Subject: [PATCH 035/147] fix linter --- CHANGELOG.md | 2 + tests/conftest.py | 4 +- tests/topics/test_topic_reader.py | 7 +- ydb/_topic_reader/datatypes.py | 5 +- ydb/_topic_reader/topic_reader.py | 32 +- ydb/_topic_reader/topic_reader_asyncio.py | 138 ++++-- .../topic_reader_asyncio_test.py | 406 ++++++++++-------- ydb/_topic_wrapper/common.py | 26 +- ydb/_topic_wrapper/common_test.py | 4 +- ydb/_topic_wrapper/reader.py | 101 +++-- ydb/_topic_wrapper/test_helpers.py | 2 - ydb/_topic_writer/topic_writer_asyncio.py | 7 +- .../topic_writer_asyncio_test.py | 2 +- ydb/issues.py | 2 - ydb/topic.py | 5 +- 15 files changed, 443 insertions(+), 300 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed909b29..04d9d60b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Initial implementation of topic reader + ## 3.0.1b3 ## * Fix error of check retriable error for idempotent operations (error exist since 2.12.1) diff --git a/tests/conftest.py b/tests/conftest.py index 8cf8ab0d..17d2801a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -135,7 +135,9 @@ def topic_path(endpoint, topic_consumer) -> str: @pytest.mark.asyncio() async def topic_with_messages(driver, topic_path): pass - writer = driver.topic_client.topic_writer(topic_path, producer_and_message_group_id="fixture-producer-id") + writer = driver.topic_client.topic_writer( + topic_path, producer_and_message_group_id="fixture-producer-id" + ) await writer.write_with_ack( ydb.TopicWriterMessage(data="123".encode()), ydb.TopicWriterMessage(data="456".encode()), diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index e2f9f725..6d87fc0b 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -1,12 +1,11 @@ import 
pytest -from ydb._topic_reader.topic_reader_asyncio import PublicAsyncIOReader -from ydb import TopicReaderSettings - @pytest.mark.asyncio class TestTopicWriterAsyncIO: - async def test_read_message(self, driver, topic_path, topic_with_messages, topic_consumer): + async def test_read_message( + self, driver, topic_path, topic_with_messages, topic_consumer + ): reader = driver.topic_client.topic_reader(topic_consumer, topic_path) assert await reader.receive_batch() is not None diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 1c26b272..9b2ab31a 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -87,4 +87,7 @@ def end_offset(self) -> int: @property def is_alive(self) -> bool: state = self._partition_session.state - return state == PartitionSession.State.Active or state == PartitionSession.State.GracefulShutdown + return ( + state == PartitionSession.State.Active + or state == PartitionSession.State.GracefulShutdown + ) diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 7995acc3..322df7e8 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -1,24 +1,18 @@ -import abc import concurrent.futures import enum -import io import datetime from dataclasses import dataclass from typing import ( Union, Optional, List, - Mapping, - Callable, Iterable, - AsyncIterable, - AsyncContextManager, - Any, Dict, ) -from ydb import RetrySettings -from ydb._topic_wrapper.common import OffsetsRange, TokenGetterFuncType -from ydb._topic_wrapper.reader import StreamReadMessage +from ..table import RetrySettings +from .datatypes import ICommittable, PublicBatch, PublicMessage +from .._topic_wrapper.common import OffsetsRange, TokenGetterFuncType +from .._topic_wrapper.reader import StreamReadMessage class Selector: @@ -47,7 +41,9 @@ async def sessions_stat(self) -> List["SessionStat"]: """ raise NotImplementedError() - def messages(self, *, timeout: Union[float, None] = None) -> Iterable["PublicMessage"]: + def messages( + self, *, timeout: Union[float, None] = None + ) -> Iterable[PublicMessage]: """ todo? @@ -59,7 +55,7 @@ def messages(self, *, timeout: Union[float, None] = None) -> Iterable["PublicMes """ raise NotImplementedError() - def receive_message(self, *, timeout: Union[float, None] = None) -> "PublicMessage": + def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage: """ Block until receive new message It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. @@ -85,7 +81,7 @@ def batches( max_messages: Union[int, None] = None, max_bytes: Union[int, None] = None, timeout: Union[float, None] = None, - ) -> Iterable["PublicBatch"]: + ) -> Iterable[PublicBatch]: """ Block until receive new batch. It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. @@ -101,7 +97,7 @@ def receive_batch( max_messages: Union[int, None] = None, max_bytes: Union[int, None], timeout: Union[float, None] = None, - ) -> Union["PublicBatch", None]: + ) -> Union[PublicBatch, None]: """ Get one messages batch from reader It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. @@ -111,7 +107,7 @@ def receive_batch( """ raise NotImplementedError() - def commit(self, mess: "ICommittable"): + def commit(self, mess: ICommittable): """ Put commit message to internal buffer. 
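A minimal consumption sketch for the reader API documented above, using the async
variant created via the topic client in this patch. The connection string, consumer
name and topic path are illustrative values in the spirit of
examples/topic/reader_async_example.py, and commit() here follows the buffered,
no-ack contract described in the docstrings:

    import asyncio
    import ydb

    async def read_loop():
        driver = ydb.aio.Driver(
            connection_string="grpc://localhost:2135?database=/local",
            credentials=ydb.credentials.AnonymousCredentials(),
        )
        reader = driver.topic_client.topic_reader("consumer", "/local/topic")
        while True:
            batch = await reader.receive_batch()
            for message in batch.messages:
                print(message.data)
            # commit() only puts offsets into the internal buffer;
            # it gives no acknowledgment from the server
            reader.commit(batch)

    asyncio.run(read_loop())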
@@ -121,7 +117,7 @@ def commit(self, mess: "ICommittable"): raise NotImplementedError() def commit_with_ack( - self, mess: "ICommittable" + self, mess: ICommittable ) -> Union["CommitResult", List["CommitResult"]]: """ write commit message to a buffer and wait ack from the server. @@ -131,7 +127,7 @@ def commit_with_ack( raise NotImplementedError() def async_commit_with_ack( - self, mess: "ICommittable" + self, mess: ICommittable ) -> Union["CommitResult", List["CommitResult"]]: """ write commit message to a buffer and return Future for wait result. @@ -184,7 +180,7 @@ def _init_message(self) -> StreamReadMessage.InitRequest: consumer=self.consumer, ) - def _retry_settings(self)->RetrySettings: + def _retry_settings(self) -> RetrySettings: return RetrySettings(idempotent=True) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index b9e53960..50e1a331 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -9,13 +9,15 @@ from .. import _apis, issues, RetrySettings from ..aio import Driver -from ..issues import ( - Error as YdbError, - _process_response +from ..issues import Error as YdbError, _process_response +from .datatypes import PartitionSession, PublicMessage, PublicBatch, ICommittable +from .topic_reader import PublicReaderSettings, CommitResult, SessionStat +from .._topic_wrapper.common import ( + TokenGetterFuncType, + IGrpcWrapperAsyncIO, + SupportedDriverType, + GrpcWrapperAsyncIO, ) -from .datatypes import PartitionSession, PublicMessage, PublicBatch -from .topic_reader import PublicReaderSettings -from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO, SupportedDriverType, GrpcWrapperAsyncIO from .._topic_wrapper.reader import StreamReadMessage from .._errors import check_retriable_error @@ -96,7 +98,10 @@ def batches( raise NotImplementedError() async def receive_batch( - self, *, max_messages: typing.Union[int, None] = None, max_bytes: typing.Union[int, None] = None + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, ) -> typing.Union["PublicBatch", None]: """ Get one messages batch from reader. @@ -107,7 +112,7 @@ async def receive_batch( await self._reconnector.wait_message() return self._reconnector.receive_batch_nowait() - async def commit_on_exit(self, mess: "ICommittable") -> typing.AsyncContextManager: + async def commit_on_exit(self, mess: ICommittable) -> typing.AsyncContextManager: """ commit the mess match/message if exit from context manager without exceptions @@ -115,7 +120,7 @@ async def commit_on_exit(self, mess: "ICommittable") -> typing.AsyncContextManag """ raise NotImplementedError() - def commit(self, mess: "ICommittable"): + def commit(self, mess: ICommittable): """ Write commit message to a buffer. @@ -125,8 +130,8 @@ def commit(self, mess: "ICommittable"): raise NotImplementedError() async def commit_with_ack( - self, mess: "ICommittable" - ) -> typing.Union["CommitResult", typing.List["CommitResult"]]: + self, mess: ICommittable + ) -> typing.Union[CommitResult, typing.List[CommitResult]]: """ write commit message to a buffer and wait ack from the server. 
@@ -174,12 +179,16 @@ async def _connection_loop(self): attempt = 0 while True: try: - self._stream_reader = await ReaderStream.create(self._driver, self._settings) + self._stream_reader = await ReaderStream.create( + self._driver, self._settings + ) attempt = 0 self._state_changed.set() await self._stream_reader.wait_error() except issues.Error as err: - retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt) + retry_info = check_retriable_error( + err, self._settings._retry_settings(), attempt + ) if not retry_info.is_retriable: self._set_first_error(err) return @@ -264,7 +273,9 @@ async def create( await reader._start(stream, settings._init_message()) return reader - async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest): + async def _start( + self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest + ): if self._started: raise TopicReaderError("Double start ReaderStream") @@ -276,7 +287,9 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess if isinstance(init_response.server_message, StreamReadMessage.InitResponse): self._session_id = init_response.server_message.session_id else: - raise TopicReaderError("Unexpected message after InitRequest: %s", init_response) + raise TopicReaderError( + "Unexpected message after InitRequest: %s", init_response + ) read_messages_task = asyncio.create_task(self._read_messages_loop(stream)) self._background_tasks.add(read_messages_task) @@ -308,23 +321,32 @@ def receive_batch_nowait(self): async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): try: - self._stream.write(StreamReadMessage.FromClient( - client_message=StreamReadMessage.ReadRequest( - bytes_size=self._buffer_size_bytes, - ), - )) + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.ReadRequest( + bytes_size=self._buffer_size_bytes, + ), + ) + ) while True: message = await stream.receive() # type: StreamReadMessage.FromServer _process_response(message.server_status) if isinstance(message.server_message, StreamReadMessage.ReadResponse): self._on_read_response(message.server_message) - elif isinstance(message.server_message, StreamReadMessage.StartPartitionSessionRequest): + elif isinstance( + message.server_message, + StreamReadMessage.StartPartitionSessionRequest, + ): self._on_start_partition_session_start(message.server_message) - elif isinstance(message.server_message, StreamReadMessage.StopPartitionSessionRequest): + elif isinstance( + message.server_message, + StreamReadMessage.StopPartitionSessionRequest, + ): self._on_partition_session_stop(message.server_message) else: raise NotImplementedError( - "Unexpected type of StreamReadMessage.FromServer message: %s" % message.server_message + "Unexpected type of StreamReadMessage.FromServer message: %s" + % message.server_message ) self._state_changed.set() @@ -332,30 +354,42 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): self._set_first_error(e) raise e - def _on_start_partition_session_start(self, message: StreamReadMessage.StartPartitionSessionRequest): + def _on_start_partition_session_start( + self, message: StreamReadMessage.StartPartitionSessionRequest + ): try: - if message.partition_session.partition_session_id in self._partition_sessions: + if ( + message.partition_session.partition_session_id + in self._partition_sessions + ): raise TopicReaderError( - "Double start partition session: %s" % message.partition_session.partition_session_id + 
"Double start partition session: %s" + % message.partition_session.partition_session_id ) - self._partition_sessions[message.partition_session.partition_session_id] = PartitionSession( + self._partition_sessions[ + message.partition_session.partition_session_id + ] = PartitionSession( id=message.partition_session.partition_session_id, state=PartitionSession.State.Active, topic_path=message.partition_session.path, partition_id=message.partition_session.partition_id, ) - self._stream.write(StreamReadMessage.FromClient( - client_message=StreamReadMessage.StartPartitionSessionResponse( - partition_session_id=message.partition_session.partition_session_id, - read_offset=0, - commit_offset=0, - )), + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.StartPartitionSessionResponse( + partition_session_id=message.partition_session.partition_session_id, + read_offset=0, + commit_offset=0, + ) + ), ) except YdbError as err: self._set_first_error(err) - def _on_partition_session_stop(self, message: StreamReadMessage.StopPartitionSessionRequest): + def _on_partition_session_stop( + self, message: StreamReadMessage.StopPartitionSessionRequest + ): partition = self._partition_sessions.get(message.partition_session_id) if partition is None: # may if receive stop partition with graceful=false after response on stop partition @@ -366,10 +400,12 @@ def _on_partition_session_stop(self, message: StreamReadMessage.StopPartitionSes partition.stop() if message.graceful: - self._stream.write(StreamReadMessage.FromClient( - client_message=StreamReadMessage.StopPartitionSessionResponse( - partition_session_id=message.partition_session_id, - )) + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.StopPartitionSessionResponse( + partition_session_id=message.partition_session_id, + ) + ) ) def _on_read_response(self, message: StreamReadMessage.ReadResponse): @@ -382,11 +418,17 @@ def _buffer_consume_bytes(self, bytes_size): def _buffer_release_bytes(self, bytes_size): self._buffer_size_bytes += bytes_size - self._stream.write(StreamReadMessage.FromClient(client_message=StreamReadMessage.ReadRequest( - bytes_size=bytes_size, - ))) + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.ReadRequest( + bytes_size=bytes_size, + ) + ) + ) - def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> typing.List[PublicBatch]: + def _read_response_to_batches( + self, message: StreamReadMessage.ReadResponse + ) -> typing.List[PublicBatch]: batches = [] batch_count = 0 @@ -397,10 +439,14 @@ def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> return [] bytes_per_batch = message.bytes_size // batch_count - additional_bytes_to_last_batch = message.bytes_size - bytes_per_batch * batch_count + additional_bytes_to_last_batch = ( + message.bytes_size - bytes_per_batch * batch_count + ) for partition_data in message.partition_data: - partition_session = self._partition_sessions[partition_data.partition_session_id] + partition_session = self._partition_sessions[ + partition_data.partition_session_id + ] for server_batch in partition_data.batches: messages = [] for message_data in server_batch.message_data: @@ -427,7 +473,7 @@ def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> batches[-1]._bytes_size += additional_bytes_to_last_batch return batches - def _set_first_error(self, err: ydb.Error): + def _set_first_error(self, err: YdbError): try: 
self._first_error.set_result(err) self._state_changed.set() @@ -435,7 +481,7 @@ def _set_first_error(self, err: ydb.Error): # skip later set errors pass - def _get_first_error(self) -> Optional[ydb.Error]: + def _get_first_error(self) -> Optional[YdbError]: if self._first_error.done(): return self._first_error.result() else: diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 4bc78b7a..7a002a0b 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -2,18 +2,20 @@ import datetime from unittest import mock -import grpc import pytest -import ydb -from ydb import aio, issues +from ydb import issues from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings from .topic_reader_asyncio import ReaderStream, PartitionSession, ReaderReconnector -from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus, UpdateTokenResponse, SupportedDriverType +from .._topic_wrapper.common import ( + OffsetsRange, + Codec, + ServerStatus, + SupportedDriverType, +) from .._topic_wrapper.reader import StreamReadMessage from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast -from ..issues import Unavailable # Workaround for good autocomplete in IDE and universal import at runtime # noinspection PyUnreachableCode @@ -27,6 +29,7 @@ def handle_exceptions(event_loop): def handler(loop, context): print(context) + event_loop.set_exception_handler(handler) @@ -73,18 +76,27 @@ def second_partition_session(self, default_reader_settings): ) @pytest.fixture() - async def stream_reader_started(self, stream, default_reader_settings, partition_session, - second_partition_session) -> ReaderStream: + async def stream_reader_started( + self, + stream, + default_reader_settings, + partition_session, + second_partition_session, + ) -> ReaderStream: reader = ReaderStream(default_reader_settings) init_message = object() # noinspection PyTypeChecker start = asyncio.create_task(reader._start(stream, init_message)) - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.InitResponse(session_id="test-session"), - )) + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.InitResponse( + session_id="test-session" + ), + ) + ) init_request = await wait_for_fast(stream.from_client.get()) assert init_request.client_message == init_message @@ -96,41 +108,49 @@ async def stream_reader_started(self, stream, default_reader_settings, partition StreamReadMessage.FromServer( server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), server_message=StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession( - partition_session_id=partition_session.id, - path=partition_session.topic_path, - partition_id=partition_session.partition_id, + partition_session=StreamReadMessage.PartitionSession( + partition_session_id=partition_session.id, + path=partition_session.topic_path, + partition_id=partition_session.partition_id, + ), + committed_offset=0, + partition_offsets=OffsetsRange( + start=0, + end=0, + ), ), - committed_offset=0, - partition_offsets=OffsetsRange( - start=0, - end=0, - ) - )) + ) ) await start start_partition_resp = await wait_for_fast(stream.from_client.get()) 
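#  The fixture above follows the request/response pattern used throughout these
#  tests: every message the emulated server should send is enqueued on
#  stream.from_server, and every message the reader writes is taken back off
#  stream.from_client and asserted. A condensed sketch of that handshake,
#  assuming the StreamMock and wait_for_fast helpers from test_helpers.py
#  (init_message and server_reply are placeholders):
#
#      stream = StreamMock()
#      start = asyncio.create_task(reader._start(stream, init_message))
#      stream.from_server.put_nowait(server_reply)            # emulated server side
#      client_msg = await wait_for_fast(stream.from_client.get())
#      await start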
- assert isinstance(start_partition_resp.client_message, StreamReadMessage.StartPartitionSessionResponse) + assert isinstance( + start_partition_resp.client_message, + StreamReadMessage.StartPartitionSessionResponse, + ) stream.from_server.put_nowait( StreamReadMessage.FromServer( server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), server_message=StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession( - partition_session_id=second_partition_session.id, - path=second_partition_session.topic_path, - partition_id=second_partition_session.partition_id, + partition_session=StreamReadMessage.PartitionSession( + partition_session_id=second_partition_session.id, + path=second_partition_session.topic_path, + partition_id=second_partition_session.partition_id, + ), + committed_offset=0, + partition_offsets=OffsetsRange( + start=0, + end=0, + ), ), - committed_offset=0, - partition_offsets=OffsetsRange( - start=0, - end=0, - ) - )) + ) ) start_partition_resp = await wait_for_fast(stream.from_client.get()) - assert isinstance(start_partition_resp.client_message, StreamReadMessage.StartPartitionSessionResponse) + assert isinstance( + start_partition_resp.client_message, + StreamReadMessage.StartPartitionSessionResponse, + ) await asyncio.sleep(0) with pytest.raises(asyncio.QueueEmpty): @@ -146,13 +166,14 @@ async def stream_reader(self, stream_reader_started: ReaderStream): await stream_reader_started.close() @pytest.fixture() - async def stream_reader_finish_with_error(self, stream_reader_started: ReaderStream): + async def stream_reader_finish_with_error( + self, stream_reader_started: ReaderStream + ): yield stream_reader_started assert stream_reader_started._get_first_error() is not None await stream_reader_started.close() - @staticmethod def create_message(partition_session: PartitionSession, seqno: int): return PublicMessage( @@ -160,11 +181,11 @@ def create_message(partition_session: PartitionSession, seqno: int): created_at=datetime.datetime(2023, 2, 3, 14, 15), message_group_id="test-message-group", session_metadata={}, - offset=seqno+1, + offset=seqno + 1, written_at=datetime.datetime(2023, 2, 3, 14, 16), producer_id="test-producer-id", data=bytes(), - _partition_session=partition_session + _partition_session=partition_session, ) async def send_message(self, stream_reader, message: PublicMessage): @@ -174,32 +195,37 @@ def batch_count(): initial_batches = batch_count() stream = stream_reader._stream # type: StreamMock - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.ReadResponse( - partition_data=[StreamReadMessage.ReadResponse.PartitionData( - partition_session_id=message._partition_session.id, - batches=[ - StreamReadMessage.ReadResponse.Batch( - message_data=[ - StreamReadMessage.ReadResponse.MessageData( - offset=message.offset, - seq_no=message.seqno, - created_at=message.created_at, - data=message.data, - uncompresed_size=len(message.data), - message_group_id=message.message_group_id, - ) - ], - producer_id=message.producer_id, - write_session_meta=message.session_metadata, - codec=Codec.CODEC_RAW, - written_at=message.written_at, - ) - ] - )], - bytes_size=self.default_batch_size, - ))) + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.ReadResponse( + partition_data=[ + 
StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=message._partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=message.offset, + seq_no=message.seqno, + created_at=message.created_at, + data=message.data, + uncompresed_size=len(message.data), + message_group_id=message.message_group_id, + ) + ], + producer_id=message.producer_id, + write_session_meta=message.session_metadata, + codec=Codec.CODEC_RAW, + written_at=message.written_at, + ) + ], + ) + ], + bytes_size=self.default_batch_size, + ), + ) + ) await wait_condition(lambda: batch_count() > initial_batches) async def test_unknown_error(self, stream, stream_reader_finish_with_error): @@ -215,7 +241,9 @@ class TestError(Exception): with pytest.raises(TestError): stream_reader_finish_with_error.receive_batch_nowait() - async def test_error_from_status_code(self, stream, stream_reader_finish_with_error): + async def test_error_from_status_code( + self, stream, stream_reader_finish_with_error + ): # noinspection PyTypeChecker stream.from_server.put_nowait( StreamReadMessage.FromServer( @@ -237,22 +265,28 @@ async def test_init_reader(self, stream, default_reader_settings): reader = ReaderStream(default_reader_settings) init_message = StreamReadMessage.InitRequest( consumer="test-consumer", - topics_read_settings=[StreamReadMessage.InitRequest.TopicReadSettings( - path="/local/test-topic", - partition_ids=[], - max_lag_seconds=None, - read_from=None, - )] + topics_read_settings=[ + StreamReadMessage.InitRequest.TopicReadSettings( + path="/local/test-topic", + partition_ids=[], + max_lag_seconds=None, + read_from=None, + ) + ], ) start_task = asyncio.create_task(reader._start(stream, init_message)) sent_message = await wait_for_fast(stream.from_client.get()) - expected_sent_init_message = StreamReadMessage.FromClient(client_message=init_message) + expected_sent_init_message = StreamReadMessage.FromClient( + client_message=init_message + ) assert sent_message == expected_sent_init_message - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.InitResponse(session_id="test")) + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.InitResponse(session_id="test"), + ) ) await start_task @@ -265,12 +299,13 @@ async def test_init_reader(self, stream, default_reader_settings): assert reader._session_id == "test" await reader.close() - async def test_start_partition(self, - stream_reader: ReaderStream, - stream, - default_reader_settings, - partition_session, - ): + async def test_start_partition( + self, + stream_reader: ReaderStream, + stream, + default_reader_settings, + partition_session, + ): def session_count(): return len(stream_reader._partition_sessions) @@ -280,30 +315,36 @@ def session_count(): test_partition_session_id = partition_session.id + 1 test_topic_path = default_reader_settings.topic + "-asd" - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession( - partition_session_id=test_partition_session_id, - path=test_topic_path, - partition_id=test_partition_id, - ), - committed_offset=0, - 
partition_offsets=OffsetsRange( - start=0, - end=0, + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StartPartitionSessionRequest( + partition_session=StreamReadMessage.PartitionSession( + partition_session_id=test_partition_session_id, + path=test_topic_path, + partition_id=test_partition_id, + ), + committed_offset=0, + partition_offsets=OffsetsRange( + start=0, + end=0, + ), ), - )), + ), ) response = await wait_for_fast(stream.from_client.get()) - assert response == StreamReadMessage.FromClient(client_message=StreamReadMessage.StartPartitionSessionResponse( - partition_session_id=test_partition_session_id, - read_offset=0, - commit_offset=0, - )) + assert response == StreamReadMessage.FromClient( + client_message=StreamReadMessage.StartPartitionSessionResponse( + partition_session_id=test_partition_session_id, + read_offset=0, + commit_offset=0, + ) + ) assert len(stream_reader._partition_sessions) == initial_session_count + 1 - assert stream_reader._partition_sessions[test_partition_session_id] == PartitionSession( + assert stream_reader._partition_sessions[ + test_partition_session_id + ] == PartitionSession( id=test_partition_session_id, state=PartitionSession.State.Active, topic_path=test_topic_path, @@ -316,14 +357,16 @@ def session_count(): initial_session_count = session_count() - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StopPartitionSessionRequest( - partition_session_id=partition_session.id, - graceful=False, - committed_offset=0, + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=partition_session.id, + graceful=False, + committed_offset=0, + ), ) - )) + ) await asyncio.sleep(0) # wait next loop with pytest.raises(asyncio.QueueEmpty): @@ -332,44 +375,52 @@ def session_count(): assert session_count() == initial_session_count - 1 assert partition_session.id not in stream_reader._partition_sessions - async def test_partition_stop_graceful(self, stream, stream_reader, partition_session): + async def test_partition_stop_graceful( + self, stream, stream_reader, partition_session + ): def session_count(): return len(stream_reader._partition_sessions) initial_session_count = session_count() - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StopPartitionSessionRequest( - partition_session_id=partition_session.id, - graceful=True, - committed_offset=0, + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=partition_session.id, + graceful=True, + committed_offset=0, + ), ) - )) + ) - resp = await wait_for_fast(stream.from_client.get()) # type: StreamReadMessage.FromClient + resp = await wait_for_fast( + stream.from_client.get() + ) # type: StreamReadMessage.FromClient assert session_count() == initial_session_count - 1 assert partition_session.id not in stream_reader._partition_sessions assert resp.client_message == StreamReadMessage.StopPartitionSessionResponse( 
partition_session_id=partition_session.id - ) - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StopPartitionSessionRequest( - partition_session_id=partition_session.id, - graceful=False, - committed_offset=0, + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=partition_session.id, + graceful=False, + committed_offset=0, + ), ) - )) + ) await asyncio.sleep(0) # wait next loop with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() - async def test_receive_message_from_server(self, stream_reader, stream, partition_session, - second_partition_session): + async def test_receive_message_from_server( + self, stream_reader, stream, partition_session, second_partition_session + ): def reader_batch_count(): return len(stream_reader._message_batches) @@ -384,34 +435,37 @@ def reader_batch_count(): session_meta = {"a": "b"} message_group_id = "test-message-group-id" - stream.from_server.put_nowait(StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.ReadResponse( - bytes_size=bytes_size, - partition_data=[ - StreamReadMessage.ReadResponse.PartitionData( - partition_session_id=partition_session.id, - batches=[ - StreamReadMessage.ReadResponse.Batch( - message_data=[ - StreamReadMessage.ReadResponse.MessageData( - offset=1, - seq_no=2, - created_at=created_at, - data=data, - uncompresed_size=len(data), - message_group_id=message_group_id, + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.ReadResponse( + bytes_size=bytes_size, + partition_data=[ + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=1, + seq_no=2, + created_at=created_at, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id, + ) + ], + producer_id=producer_id, + write_session_meta=session_meta, + codec=Codec.CODEC_RAW, + written_at=written_at, ) ], - producer_id=producer_id, - write_session_meta=session_meta, - codec=Codec.CODEC_RAW, - written_at=written_at, ) - ] - ) - ] - ))), + ], + ), + ) + ), await wait_condition(lambda: reader_batch_count() == initial_batch_count + 1) @@ -437,7 +491,9 @@ def reader_batch_count(): _bytes_size=bytes_size, ) - async def test_read_batches(self, stream_reader, partition_session, second_partition_session): + async def test_read_batches( + self, stream_reader, partition_session, second_partition_session + ): created_at = datetime.datetime(2020, 2, 1, 18, 12) created_at2 = datetime.datetime(2020, 2, 2, 18, 12) created_at3 = datetime.datetime(2020, 2, 3, 18, 12) @@ -477,7 +533,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti codec=Codec.CODEC_RAW, written_at=written_at, ) - ] + ], ), StreamReadMessage.ReadResponse.PartitionData( partition_session_id=second_partition_session.id, @@ -515,16 +571,16 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti data=data, uncompresed_size=len(data), message_group_id=message_group_id2, - ) + ), ], 
producer_id=producer_id2, write_session_meta=session_meta2, codec=Codec.CODEC_RAW, written_at=written_at2, - ) - ] + ), + ], ), - ] + ], ) ) @@ -592,7 +648,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti producer_id=producer_id, data=data, _partition_session=second_partition_session, - ) + ), ], _partition_session=second_partition_session, _bytes_size=1, @@ -612,9 +668,7 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi received = stream_reader.receive_batch_nowait() assert received == PublicBatch( mess1.session_metadata, - messages=[ - mess1 - ], + messages=[mess1], _partition_session=mess1._partition_session, _bytes_size=self.default_batch_size, ) @@ -622,17 +676,24 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi received = stream_reader.receive_batch_nowait() assert received == PublicBatch( mess2.session_metadata, - messages=[ - mess2 - ], + messages=[mess2], _partition_session=mess2._partition_session, _bytes_size=self.default_batch_size, ) - assert stream_reader._buffer_size_bytes == initial_buffer_size + 2 * self.default_batch_size + assert ( + stream_reader._buffer_size_bytes + == initial_buffer_size + 2 * self.default_batch_size + ) - assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message - assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message + assert ( + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message + ) + assert ( + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message + ) with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() @@ -647,12 +708,16 @@ async def wait_error(): raise test_error reader_stream_mock_with_error = mock.Mock(ReaderStream) - reader_stream_mock_with_error.wait_error = mock.AsyncMock(side_effect=wait_error) + reader_stream_mock_with_error.wait_error = mock.AsyncMock( + side_effect=wait_error + ) async def wait_messages(): raise test_error - reader_stream_mock_with_error.wait_messages = mock.AsyncMock(side_effect=wait_messages) + reader_stream_mock_with_error.wait_messages = mock.AsyncMock( + side_effect=wait_messages + ) reader_stream_with_messages = mock.Mock(ReaderStream) reader_stream_with_messages.wait_error.return_value = asyncio.Future() @@ -660,7 +725,10 @@ async def wait_messages(): stream_index = 0 - async def stream_create(driver: SupportedDriverType, settings: PublicReaderSettings,): + async def stream_create( + driver: SupportedDriverType, + settings: PublicReaderSettings, + ): nonlocal stream_index stream_index += 1 if stream_index == 1: diff --git a/ydb/_topic_wrapper/common.py b/ydb/_topic_wrapper/common.py index 2955461a..e666dc2f 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_topic_wrapper/common.py @@ -15,14 +15,12 @@ # noinspection PyUnreachableCode if False: from ydb._grpc.v4.protos import ( - ydb_status_codes_pb2, ydb_issue_message_pb2, ydb_topic_pb2, ) else: # noinspection PyUnresolvedReferences from ydb._grpc.common.protos import ( - ydb_status_codes_pb2, ydb_issue_message_pb2, ydb_topic_pb2, ) @@ -210,9 +208,9 @@ class ServerStatus(IFromProto): __slots__ = ("_grpc_status_code", "_issues") def __init__( - self, - status: issues.StatusCode, - issues: typing.Iterable[typing.Any], + self, + status: issues.StatusCode, + issues: typing.Iterable[typing.Any], ): self.status = status 
self.issues = issues @@ -221,10 +219,12 @@ def __str__(self): return self.__repr__() @staticmethod - def from_proto(msg: typing.Union[ - ydb_topic_pb2.StreamReadMessage.FromServer, - ydb_topic_pb2.StreamWriteMessage.FromServer, - ]) -> "ServerStatus": + def from_proto( + msg: typing.Union[ + ydb_topic_pb2.StreamReadMessage.FromServer, + ydb_topic_pb2.StreamWriteMessage.FromServer, + ] + ) -> "ServerStatus": return ServerStatus(msg.status, msg.issues) def is_success(self) -> bool: @@ -259,14 +259,12 @@ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] -def callback_from_asyncio(callback: typing.Union[typing.Callable, typing.Coroutine]) -> [asyncio.Future, asyncio.Task]: +def callback_from_asyncio( + callback: typing.Union[typing.Callable, typing.Coroutine] +) -> [asyncio.Future, asyncio.Task]: loop = asyncio.get_running_loop() if asyncio.iscoroutinefunction(callback): return loop.create_task(callback()) else: return loop.run_in_executor(None, callback) - - -def ensure_success_or_raise_error(server_status: ServerStatus): - error = issues._process_response(server_status._grpc_status_code, server_status._issues) diff --git a/ydb/_topic_wrapper/common_test.py b/ydb/_topic_wrapper/common_test.py index f7c7493e..d490c5ec 100644 --- a/ydb/_topic_wrapper/common_test.py +++ b/ydb/_topic_wrapper/common_test.py @@ -11,17 +11,16 @@ if False: from ydb._grpc.v4.protos import ( ydb_status_codes_pb2, - ydb_issue_message_pb2, ydb_topic_pb2, ) else: # noinspection PyUnresolvedReferences from ydb._grpc.common.protos import ( ydb_status_codes_pb2, - ydb_issue_message_pb2, ydb_topic_pb2, ) + @pytest.mark.asyncio class Test: async def test_callback_from_asyncio(self): @@ -108,4 +107,3 @@ def test_failed(self): assert not status.is_success() with pytest.raises(issues.Overloaded): issues._process_response(status) - diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py index baa023c9..88a12778 100644 --- a/ydb/_topic_wrapper/reader.py +++ b/ydb/_topic_wrapper/reader.py @@ -1,12 +1,15 @@ import datetime -import typing from dataclasses import dataclass, field from typing import List, Union, Dict -from google.protobuf.message import Message - -from ydb._topic_wrapper.common import OffsetsRange, IToProto, UpdateTokenRequest, UpdateTokenResponse, IFromProto, \ - ServerStatus +from ydb._topic_wrapper.common import ( + OffsetsRange, + IToProto, + UpdateTokenRequest, + UpdateTokenResponse, + IFromProto, + ServerStatus, +) from google.protobuf.duration_pb2 import Duration as ProtoDuration # Workaround for good autocomplete in IDE and universal import at runtime @@ -25,7 +28,9 @@ class PartitionSession(IFromProto): partition_id: int @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamReadMessage.PartitionSession) -> "StreamReadMessage.PartitionSession": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.PartitionSession, + ) -> "StreamReadMessage.PartitionSession": return StreamReadMessage.PartitionSession( partition_session_id=msg.partition_session_id, path=msg.path, @@ -51,7 +56,9 @@ class TopicReadSettings(IToProto): max_lag_seconds: Union[datetime.timedelta, None] = None read_from: Union[int, float, datetime.datetime, None] = None - def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings: + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings: res = ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings() res.path = self.path 
res.partition_ids.extend(self.partition_ids) @@ -65,7 +72,9 @@ class InitResponse(IFromProto): session_id: str @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamReadMessage.InitResponse) -> "StreamReadMessage.InitResponse": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.InitResponse, + ) -> "StreamReadMessage.InitResponse": return StreamReadMessage.InitResponse(session_id=msg.session_id) @dataclass @@ -83,10 +92,16 @@ class ReadResponse(IFromProto): bytes_size: int @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse) -> "StreamReadMessage.ReadResponse": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse, + ) -> "StreamReadMessage.ReadResponse": partition_data = [] for proto_partition_data in msg.partition_data: - partition_data.append(StreamReadMessage.ReadResponse.PartitionData.from_proto(proto_partition_data)) + partition_data.append( + StreamReadMessage.ReadResponse.PartitionData.from_proto( + proto_partition_data + ) + ) return StreamReadMessage.ReadResponse( partition_data=partition_data, bytes_size=msg.bytes_size, @@ -102,15 +117,16 @@ class MessageData(IFromProto): message_group_id: str @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData) ->\ - "StreamReadMessage.ReadResponse.MessageData": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData, + ) -> "StreamReadMessage.ReadResponse.MessageData": return StreamReadMessage.ReadResponse.MessageData( offset=msg.offset, seq_no=msg.seq_no, created_at=msg.created_at.ToDatetime(), data=msg.data, uncompresed_size=msg.uncompressed_size, - message_group_id=msg.message_group_id + message_group_id=msg.message_group_id, ) @dataclass @@ -122,11 +138,14 @@ class Batch(IFromProto): written_at: datetime.datetime @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch) -> \ - "StreamReadMessage.ReadResponse.Batch": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch, + ) -> "StreamReadMessage.ReadResponse.Batch": message_data = [] for message in msg.message_data: - message_data.append(StreamReadMessage.ReadResponse.MessageData.from_proto(message)) + message_data.append( + StreamReadMessage.ReadResponse.MessageData.from_proto(message) + ) return StreamReadMessage.ReadResponse.Batch( message_data=message_data, producer_id=msg.producer_id, @@ -135,24 +154,25 @@ def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch) -> \ written_at=msg.written_at.ToDatetime(), ) - @dataclass class PartitionData(IFromProto): partition_session_id: int batches: List["StreamReadMessage.ReadResponse.Batch"] @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData) ->\ - "StreamReadMessage.ReadResponse.PartitionData": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData, + ) -> "StreamReadMessage.ReadResponse.PartitionData": batches = [] for proto_batch in msg.batches: - batches.append(StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch)) + batches.append( + StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch) + ) return StreamReadMessage.ReadResponse.PartitionData( partition_session_id=msg.partition_session_id, batches=batches, ) - @dataclass class CommitOffsetRequest: commit_offsets: List["PartitionCommitOffset"] @@ -189,12 +209,15 @@ class StartPartitionSessionRequest(IFromProto): partition_offsets: OffsetsRange @staticmethod - def from_proto(msg: 
ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest) -> \ - "StreamReadMessage.StartPartitionSessionRequest": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest, + ) -> "StreamReadMessage.StartPartitionSessionRequest": return StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession.from_proto(msg.partition_session), + partition_session=StreamReadMessage.PartitionSession.from_proto( + msg.partition_session + ), committed_offset=msg.committed_offset, - partition_offsets=OffsetsRange.from_proto(msg.partition_offsets) + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), ) @dataclass @@ -203,7 +226,9 @@ class StartPartitionSessionResponse(IToProto): read_offset: int commit_offset: int - def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() res.partition_session_id = self.partition_session_id res.read_offset = self.read_offset @@ -233,8 +258,12 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: res.read_request.CopyFrom(self.client_message.to_proto()) elif isinstance(self.client_message, StreamReadMessage.InitRequest): res.init_request.CopyFrom(self.client_message.to_proto()) - elif isinstance(self.client_message, StreamReadMessage.StartPartitionSessionResponse): - res.start_partition_session_response.CopyFrom(self.client_message.to_proto()) + elif isinstance( + self.client_message, StreamReadMessage.StartPartitionSessionResponse + ): + res.start_partition_session_response.CopyFrom( + self.client_message.to_proto() + ) else: raise NotImplementedError() return res @@ -245,23 +274,31 @@ class FromServer(IFromProto): server_status: ServerStatus @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamReadMessage.FromServer) -> "StreamReadMessage.FromServer": + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.FromServer, + ) -> "StreamReadMessage.FromServer": mess_type = msg.WhichOneof("server_message") server_status = ServerStatus.from_proto(msg) if mess_type == "read_response": return StreamReadMessage.FromServer( server_status=server_status, - server_message=StreamReadMessage.ReadResponse.from_proto(msg.read_response), + server_message=StreamReadMessage.ReadResponse.from_proto( + msg.read_response + ), ) elif mess_type == "init_response": return StreamReadMessage.FromServer( server_status=server_status, - server_message=StreamReadMessage.InitResponse.from_proto(msg.init_response), + server_message=StreamReadMessage.InitResponse.from_proto( + msg.init_response + ), ) elif mess_type == "start_partition_session_request": return StreamReadMessage.FromServer( server_status=server_status, - server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto(msg.start_partition_session_request) + server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto( + msg.start_partition_session_request + ), ) # todo replace exception to log diff --git a/ydb/_topic_wrapper/test_helpers.py b/ydb/_topic_wrapper/test_helpers.py index b0c75a03..a278046b 100644 --- a/ydb/_topic_wrapper/test_helpers.py +++ b/ydb/_topic_wrapper/test_helpers.py @@ -2,8 +2,6 @@ import time import typing -import pytest - from .common import IGrpcWrapperAsyncIO, IToProto diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 01be0a8c..b3999659 100644 --- 
a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -1,7 +1,7 @@ import asyncio import datetime from collections import deque -from typing import Deque, AsyncIterator, Union, List, Optional, Callable +from typing import Deque, AsyncIterator, Union, List, Optional import ydb from .topic_writer import ( @@ -26,7 +26,8 @@ UpdateTokenResponse, GrpcWrapperAsyncIO, IGrpcWrapperAsyncIO, - SupportedDriverType, TokenGetterFuncType, + SupportedDriverType, + TokenGetterFuncType, ) from .._topic_wrapper.writer import StreamWriteMessage, WriterMessagesFromServerToClient @@ -470,5 +471,3 @@ def _ensure_ok(message: WriterMessagesFromServerToClient): def write(self, messages: List[InternalMessage]): for request in messages_to_proto_requests(messages): self._stream.write(request) - - diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 38a0a2dd..1d1ca225 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -14,7 +14,7 @@ from .. import aio from .. import StatusCode, issues -from .._topic_wrapper.common import ServerStatus, IGrpcWrapperAsyncIO, IToProto, Codec +from .._topic_wrapper.common import ServerStatus, Codec from .topic_writer import ( InternalMessage, PublicMessage, diff --git a/ydb/issues.py b/ydb/issues.py index 6df634ea..727aff1b 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import abc - from google.protobuf import text_format import enum from six.moves import queue diff --git a/ydb/topic.py b/ydb/topic.py index 880424ea..1b2722ca 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,4 +1,4 @@ -from typing import List, Callable, Union, Mapping, Any, Optional +from typing import List, Callable, Union, Mapping, Any from . 
import aio, Credentials from ._topic_reader.topic_reader import ( @@ -11,9 +11,8 @@ ) from ._topic_reader.topic_reader_asyncio import ( - PublicAsyncIOReader as TopicReaderAsyncIO + PublicAsyncIOReader as TopicReaderAsyncIO, ) -from ._topic_wrapper.common import TokenGetterFuncType from ._topic_writer.topic_writer import ( # noqa: F401 Writer as TopicWriter, From 779fe85eba0a0946277b8b1c4feb9b96f15deafa Mon Sep 17 00:00:00 2001 From: robot Date: Tue, 7 Feb 2023 15:31:40 +0000 Subject: [PATCH 036/147] Release: 3.0.1b4 --- CHANGELOG.md | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04d9d60b..c0b52c2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b4 ## * Initial implementation of topic reader ## 3.0.1b3 ## diff --git a/setup.py b/setup.py index 6716c84a..389711ba 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name="ydb", - version="3.0.1b3", # AUTOVERSION + version="3.0.1b4", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", From 996051160ab15401ca386896a7b8e7a3aa94394a Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 8 Feb 2023 14:15:54 +0300 Subject: [PATCH 037/147] move topic grpc wrapper code to one file --- tox.ini | 4 +- ydb/_grpc/common/__init__.py | 10 +- .../grpcwrapper}/__init__.py | 0 .../grpcwrapper/common_utils.py} | 84 +-- ydb/_grpc/grpcwrapper/ydb_topic.py | 639 ++++++++++++++++++ ydb/_topic_common/__init__.py | 0 ydb/_topic_common/common.py | 3 + .../common_test.py | 6 +- .../control_plane.py | 0 .../test_helpers.py | 2 +- ydb/_topic_reader/topic_reader.py | 4 +- ydb/_topic_reader/topic_reader_asyncio.py | 6 +- .../topic_reader_asyncio_test.py | 11 +- ydb/_topic_wrapper/reader.py | 326 --------- ydb/_topic_wrapper/writer.py | 289 -------- ydb/_topic_writer/topic_writer.py | 4 +- ydb/_topic_writer/topic_writer_asyncio.py | 13 +- .../topic_writer_asyncio_test.py | 6 +- ydb/_topic_writer/topic_writer_sync.py | 2 +- 19 files changed, 691 insertions(+), 718 deletions(-) rename ydb/{_topic_wrapper => _grpc/grpcwrapper}/__init__.py (100%) rename ydb/{_topic_wrapper/common.py => _grpc/grpcwrapper/common_utils.py} (77%) create mode 100644 ydb/_grpc/grpcwrapper/ydb_topic.py create mode 100644 ydb/_topic_common/__init__.py create mode 100644 ydb/_topic_common/common.py rename ydb/{_topic_wrapper => _topic_common}/common_test.py (96%) rename ydb/{_topic_wrapper => _topic_common}/control_plane.py (100%) rename ydb/{_topic_wrapper => _topic_common}/test_helpers.py (92%) delete mode 100644 ydb/_topic_wrapper/reader.py delete mode 100644 ydb/_topic_wrapper/writer.py diff --git a/tox.ini b/tox.ini index 026ec732..28181d20 100644 --- a/tox.ini +++ b/tox.ini @@ -46,12 +46,12 @@ deps = [testenv:black-format] skip_install = true commands = - black ydb examples tests --extend-exclude ydb/_grpc + black ydb examples tests --extend-exclude "ydb/_grpc/v3|ydb/_grpc/v4" [testenv:black] skip_install = true commands = - black --diff --check ydb examples tests --extend-exclude ydb/_grpc + black --diff --check ydb examples tests --extend-exclude "ydb/_grpc/v3|ydb/_grpc/v4" [testenv:pylint] deps = pylint diff --git a/ydb/_grpc/common/__init__.py b/ydb/_grpc/common/__init__.py index 10138358..4a5ef87b 100644 --- a/ydb/_grpc/common/__init__.py +++ b/ydb/_grpc/common/__init__.py @@ -9,12 +9,14 @@ protobuf_version = Version(google.protobuf.__version__) if protobuf_version < Version("4.0"): - from ydb._grpc.v3 import * # noqa - from 
ydb._grpc.v3 import protos # noqa + from ydb._grpc.v3 import * # noqa + from ydb._grpc.v3 import protos # noqa + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v3"] sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v3.protos"] else: - from ydb._grpc.v4 import * # noqa - from ydb._grpc.v4 import protos # noqa + from ydb._grpc.v4 import * # noqa + from ydb._grpc.v4 import protos # noqa + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v4"] sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v4.protos"] diff --git a/ydb/_topic_wrapper/__init__.py b/ydb/_grpc/grpcwrapper/__init__.py similarity index 100% rename from ydb/_topic_wrapper/__init__.py rename to ydb/_grpc/grpcwrapper/__init__.py diff --git a/ydb/_topic_wrapper/common.py b/ydb/_grpc/grpcwrapper/common_utils.py similarity index 77% rename from ydb/_topic_wrapper/common.py rename to ydb/_grpc/grpcwrapper/common_utils.py index e666dc2f..f4c0f6e2 100644 --- a/ydb/_topic_wrapper/common.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -2,46 +2,20 @@ import asyncio import typing from dataclasses import dataclass -from enum import IntEnum import grpc from google.protobuf.message import Message import ydb.aio -from .. import issues, connection - -# Workaround for good autocomplete in IDE and universal import at runtime +# Workaround for good IDE and universal for runtime # noinspection PyUnreachableCode if False: - from ydb._grpc.v4.protos import ( - ydb_issue_message_pb2, - ydb_topic_pb2, - ) + from ..v4.protos import ydb_topic_pb2, ydb_issue_message_pb2 else: - # noinspection PyUnresolvedReferences - from ydb._grpc.common.protos import ( - ydb_issue_message_pb2, - ydb_topic_pb2, - ) - - -class Codec(IntEnum): - CODEC_UNSPECIFIED = 0 - CODEC_RAW = 1 - CODEC_GZIP = 2 - CODEC_LZOP = 3 - CODEC_ZSTD = 4 + from ..common.protos import ydb_topic_pb2, ydb_issue_message_pb2 - -class IToProto(abc.ABC): - @abc.abstractmethod - def to_proto(self) -> Message: - pass - - -class UnknownGrpcMessageError(ydb.Error): - pass +from ... 
import issues, connection class IFromProto(abc.ABC): @@ -51,17 +25,14 @@ def from_proto(msg: Message) -> typing.Any: pass -@dataclass -class OffsetsRange(IFromProto): - start: int - end: int +class IToProto(abc.ABC): + @abc.abstractmethod + def to_proto(self) -> Message: + pass - @staticmethod - def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange": - return OffsetsRange( - start=msg.start, - end=msg.end, - ) + +class UnknownGrpcMessageError(issues.Error): + pass class QueueToIteratorAsyncIO: @@ -119,19 +90,6 @@ async def __anext__(self): raise StopIteration() -class IteratorToQueueAsyncIO: - __slots__ = ("_iterator",) - - def __init__(self, iterator: typing.AsyncIterator[typing.Any]): - self._iterator = iterator - - async def get(self) -> typing.Any: - try: - return self._iterator.__anext__() - except StopAsyncIteration: - raise asyncio.QueueEmpty() - - class IGrpcWrapperAsyncIO(abc.ABC): @abc.abstractmethod async def receive(self) -> typing.Any: @@ -239,26 +197,6 @@ def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): return res -@dataclass -class UpdateTokenRequest(IToProto): - token: str - - def to_proto(self) -> Message: - res = ydb_topic_pb2.UpdateTokenRequest() - res.token = self.token - return res - - -@dataclass -class UpdateTokenResponse(IFromProto): - @staticmethod - def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: - return UpdateTokenResponse() - - -TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] - - def callback_from_asyncio( callback: typing.Union[typing.Callable, typing.Coroutine] ) -> [asyncio.Future, asyncio.Task]: diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py new file mode 100644 index 00000000..df43b803 --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -0,0 +1,639 @@ +import datetime +import enum +import typing +from dataclasses import dataclass, field +from enum import IntEnum +from typing import List, Union, Dict + +from google.protobuf.duration_pb2 import Duration as ProtoDuration +from google.protobuf.message import Message + +# Workaround for good IDE and universal for runtime +# noinspection PyUnreachableCode +if False: + from ..v4.protos import ydb_topic_pb2 +else: + from ..common.protos import ydb_topic_pb2 + +from .common_utils import IFromProto, IToProto, ServerStatus, UnknownGrpcMessageError + + +class Codec(IntEnum): + CODEC_UNSPECIFIED = 0 + CODEC_RAW = 1 + CODEC_GZIP = 2 + CODEC_LZOP = 3 + CODEC_ZSTD = 4 + + +@dataclass +class OffsetsRange(IFromProto): + start: int + end: int + + @staticmethod + def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange": + return OffsetsRange( + start=msg.start, + end=msg.end, + ) + + +@dataclass +class UpdateTokenRequest(IToProto): + token: str + + def to_proto(self) -> Message: + res = ydb_topic_pb2.UpdateTokenRequest() + res.token = self.token + return res + + +@dataclass +class UpdateTokenResponse(IFromProto): + @staticmethod + def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: + return UpdateTokenResponse() + + +######################################################################################################################## +# StreamWrite +######################################################################################################################## + + +class StreamWriteMessage: + @dataclass() + class InitRequest(IToProto): + path: str + producer_id: str + write_session_meta: typing.Dict[str, str] + partitioning: "StreamWriteMessage.PartitioningType" + get_last_seq_no: 
bool + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.InitRequest: + proto = ydb_topic_pb2.StreamWriteMessage.InitRequest() + proto.path = self.path + proto.producer_id = self.producer_id + + if self.partitioning is None: + pass + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): + proto.message_group_id = self.partitioning.message_group_id + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): + proto.partition_id = self.partitioning.partition_id + else: + raise Exception( + "Bad partitioning type at StreamWriteMessage.InitRequest" + ) + + if self.write_session_meta: + for key in self.write_session_meta: + proto.write_session_meta[key] = self.write_session_meta[key] + + proto.get_last_seq_no = self.get_last_seq_no + return proto + + @dataclass + class InitResponse(IFromProto): + last_seq_no: Union[int, None] + session_id: str + partition_id: int + supported_codecs: typing.List[int] + status: ServerStatus = None + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.InitResponse, + ) -> "StreamWriteMessage.InitResponse": + codecs = [] # type: typing.List[int] + if msg.supported_codecs: + for codec in msg.supported_codecs.codecs: + codecs.append(codec) + + return StreamWriteMessage.InitResponse( + last_seq_no=msg.last_seq_no, + session_id=msg.session_id, + partition_id=msg.partition_id, + supported_codecs=codecs, + ) + + @dataclass + class WriteRequest(IToProto): + messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] + codec: int + + @dataclass + class MessageData(IToProto): + seq_no: int + created_at: datetime.datetime + data: bytes + uncompressed_size: int + partitioning: "StreamWriteMessage.PartitioningType" + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData() + proto.seq_no = self.seq_no + proto.created_at.FromDatetime(self.created_at) + proto.data = self.data + proto.uncompressed_size = self.uncompressed_size + + if self.partitioning is None: + pass + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): + proto.partition_id = self.partitioning.partition_id + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): + proto.message_group_id = self.partitioning.message_group_id + else: + raise Exception( + "Bad partition at StreamWriteMessage.WriteRequest.MessageData" + ) + + return proto + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() + proto.codec = self.codec + + for message in self.messages: + proto_mess = proto.messages.add() + proto_mess.CopyFrom(message.to_proto()) + + return proto + + @dataclass + class WriteResponse(IFromProto): + partition_id: int + acks: typing.List["StreamWriteMessage.WriteResponse.WriteAck"] + write_statistics: "StreamWriteMessage.WriteResponse.WriteStatistics" + status: ServerStatus = field(default=None) + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse, + ) -> "StreamWriteMessage.WriteResponse": + acks = [] + for proto_ack in msg.acks: + ack = StreamWriteMessage.WriteResponse.WriteAck.from_proto(proto_ack) + acks.append(ack) + write_statistics = StreamWriteMessage.WriteResponse.WriteStatistics( + persisting_time=msg.write_statistics.persisting_time.ToTimedelta(), + min_queue_wait_time=msg.write_statistics.min_queue_wait_time.ToTimedelta(), + 
max_queue_wait_time=msg.write_statistics.max_queue_wait_time.ToTimedelta(), + partition_quota_wait_time=msg.write_statistics.partition_quota_wait_time.ToTimedelta(), + topic_quota_wait_time=msg.write_statistics.topic_quota_wait_time.ToTimedelta(), + ) + return StreamWriteMessage.WriteResponse( + partition_id=msg.partition_id, + acks=acks, + write_statistics=write_statistics, + status=None, + ) + + @dataclass + class WriteAck(IFromProto): + seq_no: int + message_write_status: Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusWritten", + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped", + int, + ] + + @classmethod + def from_proto( + cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck + ): + if proto_ack.HasField("written"): + message_write_status = ( + StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + proto_ack.written.offset + ) + ) + elif proto_ack.HasField("skipped"): + reason = proto_ack.skipped.reason + try: + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped( + reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code( + reason + ) + ) + except ValueError: + message_write_status = reason + else: + raise NotImplementedError("unexpected ack status") + + return StreamWriteMessage.WriteResponse.WriteAck( + seq_no=proto_ack.seq_no, + message_write_status=message_write_status, + ) + + @dataclass + class StatusWritten: + offset: int + + @dataclass + class StatusSkipped: + reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" + + class Reason(enum.Enum): + UNSPECIFIED = 0 + ALREADY_WRITTEN = 1 + + @classmethod + def from_protobuf_code( + cls, code: int + ) -> Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason", + int, + ]: + try: + return StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason( + code + ) + except ValueError: + return code + + @dataclass + class WriteStatistics: + persisting_time: datetime.timedelta + min_queue_wait_time: datetime.timedelta + max_queue_wait_time: datetime.timedelta + partition_quota_wait_time: datetime.timedelta + topic_quota_wait_time: datetime.timedelta + + @dataclass + class PartitioningMessageGroupID: + message_group_id: str + + @dataclass + class PartitioningPartitionID: + partition_id: int + + PartitioningType = Union[PartitioningMessageGroupID, PartitioningPartitionID, None] + + @dataclass + class FromClient(IToProto): + value: "WriterMessagesFromClientToServer" + + def __init__(self, value: "WriterMessagesFromClientToServer"): + self.value = value + + def to_proto(self) -> Message: + res = ydb_topic_pb2.StreamWriteMessage.FromClient() + value = self.value + if isinstance(value, StreamWriteMessage.WriteRequest): + res.write_request.CopyFrom(value.to_proto()) + elif isinstance(value, StreamWriteMessage.InitRequest): + res.init_request.CopyFrom(value.to_proto()) + elif isinstance(value, UpdateTokenRequest): + res.update_token_request.CopyFrom(value.to_proto()) + else: + raise Exception("Unknown outcoming grpc message: %s" % value) + return res + + class FromServer(IFromProto): + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.FromServer) -> typing.Any: + message_type = msg.WhichOneof("server_message") + if message_type == "write_response": + res = StreamWriteMessage.WriteResponse.from_proto(msg.write_response) + elif message_type == "init_response": + res = StreamWriteMessage.InitResponse.from_proto(msg.init_response) + elif message_type == "update_token_response": + res = 
UpdateTokenResponse.from_proto(msg.update_token_response) + else: + # todo log instead of exception - for allow add messages in the future + raise UnknownGrpcMessageError("Unexpected proto message: %s" % msg) + + res.status = ServerStatus(msg.status, msg.issues) + return res + + +WriterMessagesFromClientToServer = Union[ + StreamWriteMessage.InitRequest, StreamWriteMessage.WriteRequest, UpdateTokenRequest +] +WriterMessagesFromServerToClient = Union[ + StreamWriteMessage.InitResponse, + StreamWriteMessage.WriteResponse, + UpdateTokenResponse, +] + + +######################################################################################################################## +# StreamRead +######################################################################################################################## + + +class StreamReadMessage: + @dataclass + class PartitionSession(IFromProto): + partition_session_id: int + path: str + partition_id: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.PartitionSession, + ) -> "StreamReadMessage.PartitionSession": + return StreamReadMessage.PartitionSession( + partition_session_id=msg.partition_session_id, + path=msg.path, + partition_id=msg.partition_id, + ) + + @dataclass + class InitRequest(IToProto): + topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"] + consumer: str + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest: + res = ydb_topic_pb2.StreamReadMessage.InitRequest() + res.consumer = self.consumer + for settings in self.topics_read_settings: + res.topics_read_settings.append(settings.to_proto()) + return res + + @dataclass + class TopicReadSettings(IToProto): + path: str + partition_ids: List[int] = field(default_factory=list) + max_lag_seconds: Union[datetime.timedelta, None] = None + read_from: Union[int, float, datetime.datetime, None] = None + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings: + res = ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings() + res.path = self.path + res.partition_ids.extend(self.partition_ids) + if self.max_lag_seconds is not None: + res.max_lag = ProtoDuration() + res.max_lag.FromTimedelta(self.max_lag_seconds) + return res + + @dataclass + class InitResponse(IFromProto): + session_id: str + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.InitResponse, + ) -> "StreamReadMessage.InitResponse": + return StreamReadMessage.InitResponse(session_id=msg.session_id) + + @dataclass + class ReadRequest(IToProto): + bytes_size: int + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.ReadRequest: + res = ydb_topic_pb2.StreamReadMessage.ReadRequest() + res.bytes_size = self.bytes_size + return res + + @dataclass + class ReadResponse(IFromProto): + partition_data: List["StreamReadMessage.ReadResponse.PartitionData"] + bytes_size: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse, + ) -> "StreamReadMessage.ReadResponse": + partition_data = [] + for proto_partition_data in msg.partition_data: + partition_data.append( + StreamReadMessage.ReadResponse.PartitionData.from_proto( + proto_partition_data + ) + ) + return StreamReadMessage.ReadResponse( + partition_data=partition_data, + bytes_size=msg.bytes_size, + ) + + @dataclass + class MessageData(IFromProto): + offset: int + seq_no: int + created_at: datetime.datetime + data: bytes + uncompresed_size: int + message_group_id: str + + @staticmethod + def from_proto( + msg: 
ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData, + ) -> "StreamReadMessage.ReadResponse.MessageData": + return StreamReadMessage.ReadResponse.MessageData( + offset=msg.offset, + seq_no=msg.seq_no, + created_at=msg.created_at.ToDatetime(), + data=msg.data, + uncompresed_size=msg.uncompressed_size, + message_group_id=msg.message_group_id, + ) + + @dataclass + class Batch(IFromProto): + message_data: List["StreamReadMessage.ReadResponse.MessageData"] + producer_id: str + write_session_meta: Dict[str, str] + codec: int + written_at: datetime.datetime + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch, + ) -> "StreamReadMessage.ReadResponse.Batch": + message_data = [] + for message in msg.message_data: + message_data.append( + StreamReadMessage.ReadResponse.MessageData.from_proto(message) + ) + return StreamReadMessage.ReadResponse.Batch( + message_data=message_data, + producer_id=msg.producer_id, + write_session_meta=dict(msg.write_session_meta), + codec=msg.codec, + written_at=msg.written_at.ToDatetime(), + ) + + @dataclass + class PartitionData(IFromProto): + partition_session_id: int + batches: List["StreamReadMessage.ReadResponse.Batch"] + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData, + ) -> "StreamReadMessage.ReadResponse.PartitionData": + batches = [] + for proto_batch in msg.batches: + batches.append( + StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch) + ) + return StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=msg.partition_session_id, + batches=batches, + ) + + @dataclass + class CommitOffsetRequest: + commit_offsets: List["PartitionCommitOffset"] + + @dataclass + class PartitionCommitOffset: + partition_session_id: int + offsets: List["OffsetsRange"] + + @dataclass + class CommitOffsetResponse: + partitions_committed_offsets: List["PartitionCommittedOffset"] + + @dataclass + class PartitionCommittedOffset: + partition_session_id: int + committed_offset: int + + @dataclass + class PartitionSessionStatusRequest: + partition_session_id: int + + @dataclass + class PartitionSessionStatusResponse: + partition_session_id: int + partition_offsets: "OffsetsRange" + committed_offset: int + write_time_high_watermark: float + + @dataclass + class StartPartitionSessionRequest(IFromProto): + partition_session: "StreamReadMessage.PartitionSession" + committed_offset: int + partition_offsets: "OffsetsRange" + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest, + ) -> "StreamReadMessage.StartPartitionSessionRequest": + return StreamReadMessage.StartPartitionSessionRequest( + partition_session=StreamReadMessage.PartitionSession.from_proto( + msg.partition_session + ), + committed_offset=msg.committed_offset, + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), + ) + + @dataclass + class StartPartitionSessionResponse(IToProto): + partition_session_id: int + read_offset: int + commit_offset: int + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: + res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() + res.partition_session_id = self.partition_session_id + res.read_offset = self.read_offset + res.commit_offset = self.commit_offset + return res + + @dataclass + class StopPartitionSessionRequest: + partition_session_id: int + graceful: bool + committed_offset: int + + @dataclass + class StopPartitionSessionResponse: + partition_session_id: int 
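+
+    # Typical client-side flow with these wrappers (a sketch for orientation
+    # only; the topic path, consumer name and byte budget are illustrative):
+    #
+    #   init = StreamReadMessage.InitRequest(
+    #       topics_read_settings=[
+    #           StreamReadMessage.InitRequest.TopicReadSettings(
+    #               path="/local/some-topic",
+    #           ),
+    #       ],
+    #       consumer="some-consumer",
+    #   )
+    #   stream.write(StreamReadMessage.FromClient(init))
+    #   read_req = StreamReadMessage.ReadRequest(bytes_size=1024 * 1024)
+    #   stream.write(StreamReadMessage.FromClient(read_req))
+    #
+    # The server answers with FromServer-wrapped messages: InitResponse first,
+    # then StartPartitionSessionRequest per partition and ReadResponse with data.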
+ + @dataclass + class FromClient(IToProto): + client_message: "ReaderMessagesFromClientToServer" + + def __init__(self, client_message: "ReaderMessagesFromClientToServer"): + self.client_message = client_message + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: + res = ydb_topic_pb2.StreamReadMessage.FromClient() + if isinstance(self.client_message, StreamReadMessage.ReadRequest): + res.read_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.InitRequest): + res.init_request.CopyFrom(self.client_message.to_proto()) + elif isinstance( + self.client_message, StreamReadMessage.StartPartitionSessionResponse + ): + res.start_partition_session_response.CopyFrom( + self.client_message.to_proto() + ) + else: + raise NotImplementedError() + return res + + @dataclass + class FromServer(IFromProto): + server_message: "ReaderMessagesFromServerToClient" + server_status: ServerStatus + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.FromServer, + ) -> "StreamReadMessage.FromServer": + mess_type = msg.WhichOneof("server_message") + server_status = ServerStatus.from_proto(msg) + if mess_type == "read_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.ReadResponse.from_proto( + msg.read_response + ), + ) + elif mess_type == "init_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.InitResponse.from_proto( + msg.init_response + ), + ) + elif mess_type == "start_partition_session_request": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto( + msg.start_partition_session_request + ), + ) + + # todo replace exception to log + raise NotImplementedError() + + +ReaderMessagesFromClientToServer = Union[ + StreamReadMessage.InitRequest, + StreamReadMessage.ReadRequest, + StreamReadMessage.CommitOffsetRequest, + StreamReadMessage.PartitionSessionStatusRequest, + UpdateTokenRequest, + StreamReadMessage.StartPartitionSessionResponse, + StreamReadMessage.StopPartitionSessionResponse, +] + +ReaderMessagesFromServerToClient = Union[ + StreamReadMessage.InitResponse, + StreamReadMessage.ReadResponse, + StreamReadMessage.CommitOffsetResponse, + StreamReadMessage.PartitionSessionStatusResponse, + UpdateTokenResponse, + StreamReadMessage.StartPartitionSessionRequest, + StreamReadMessage.StopPartitionSessionRequest, +] diff --git a/ydb/_topic_common/__init__.py b/ydb/_topic_common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py new file mode 100644 index 00000000..c92d9737 --- /dev/null +++ b/ydb/_topic_common/common.py @@ -0,0 +1,3 @@ +import typing + +TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] diff --git a/ydb/_topic_wrapper/common_test.py b/ydb/_topic_common/common_test.py similarity index 96% rename from ydb/_topic_wrapper/common_test.py rename to ydb/_topic_common/common_test.py index d490c5ec..ce19f4a0 100644 --- a/ydb/_topic_wrapper/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -3,7 +3,11 @@ import grpc import pytest -from .common import callback_from_asyncio, GrpcWrapperAsyncIO, ServerStatus +from .._grpc.grpcwrapper.common_utils import ( + GrpcWrapperAsyncIO, + ServerStatus, + callback_from_asyncio, +) from .. 
import issues # Workaround for good autocomplete in IDE and universal import at runtime diff --git a/ydb/_topic_wrapper/control_plane.py b/ydb/_topic_common/control_plane.py similarity index 100% rename from ydb/_topic_wrapper/control_plane.py rename to ydb/_topic_common/control_plane.py diff --git a/ydb/_topic_wrapper/test_helpers.py b/ydb/_topic_common/test_helpers.py similarity index 92% rename from ydb/_topic_wrapper/test_helpers.py rename to ydb/_topic_common/test_helpers.py index a278046b..bea6fea5 100644 --- a/ydb/_topic_wrapper/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -2,7 +2,7 @@ import time import typing -from .common import IGrpcWrapperAsyncIO, IToProto +from .._grpc.grpcwrapper.common_utils import IToProto, IGrpcWrapperAsyncIO class StreamMock(IGrpcWrapperAsyncIO): diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 322df7e8..7bb6d934 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -11,8 +11,8 @@ from ..table import RetrySettings from .datatypes import ICommittable, PublicBatch, PublicMessage -from .._topic_wrapper.common import OffsetsRange, TokenGetterFuncType -from .._topic_wrapper.reader import StreamReadMessage +from .._topic_common.common import TokenGetterFuncType +from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange class Selector: diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 50e1a331..95bd1008 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -12,13 +12,15 @@ from ..issues import Error as YdbError, _process_response from .datatypes import PartitionSession, PublicMessage, PublicBatch, ICommittable from .topic_reader import PublicReaderSettings, CommitResult, SessionStat -from .._topic_wrapper.common import ( +from .._topic_common.common import ( TokenGetterFuncType, +) +from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, GrpcWrapperAsyncIO, ) -from .._topic_wrapper.reader import StreamReadMessage +from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage from .._errors import check_retriable_error diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 7a002a0b..0fae1bec 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -8,14 +8,9 @@ from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings from .topic_reader_asyncio import ReaderStream, PartitionSession, ReaderReconnector -from .._topic_wrapper.common import ( - OffsetsRange, - Codec, - ServerStatus, - SupportedDriverType, -) -from .._topic_wrapper.reader import StreamReadMessage -from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast +from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus +from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec, OffsetsRange +from .._topic_common.test_helpers import StreamMock, wait_condition, wait_for_fast # Workaround for good autocomplete in IDE and universal import at runtime # noinspection PyUnreachableCode diff --git a/ydb/_topic_wrapper/reader.py b/ydb/_topic_wrapper/reader.py deleted file mode 100644 index 88a12778..00000000 --- a/ydb/_topic_wrapper/reader.py +++ /dev/null @@ -1,326 +0,0 @@ -import datetime -from dataclasses import dataclass, field -from typing import List, Union, Dict 
- -from ydb._topic_wrapper.common import ( - OffsetsRange, - IToProto, - UpdateTokenRequest, - UpdateTokenResponse, - IFromProto, - ServerStatus, -) -from google.protobuf.duration_pb2 import Duration as ProtoDuration - -# Workaround for good autocomplete in IDE and universal import at runtime -# noinspection PyUnreachableCode -if False: - from ydb._grpc.v4.protos import ydb_topic_pb2 -else: - from ydb._grpc.common.protos import ydb_topic_pb2 - - -class StreamReadMessage: - @dataclass - class PartitionSession(IFromProto): - partition_session_id: int - path: str - partition_id: int - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.PartitionSession, - ) -> "StreamReadMessage.PartitionSession": - return StreamReadMessage.PartitionSession( - partition_session_id=msg.partition_session_id, - path=msg.path, - partition_id=msg.partition_id, - ) - - @dataclass - class InitRequest(IToProto): - topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"] - consumer: str - - def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest: - res = ydb_topic_pb2.StreamReadMessage.InitRequest() - res.consumer = self.consumer - for settings in self.topics_read_settings: - res.topics_read_settings.append(settings.to_proto()) - return res - - @dataclass - class TopicReadSettings(IToProto): - path: str - partition_ids: List[int] = field(default_factory=list) - max_lag_seconds: Union[datetime.timedelta, None] = None - read_from: Union[int, float, datetime.datetime, None] = None - - def to_proto( - self, - ) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings: - res = ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings() - res.path = self.path - res.partition_ids.extend(self.partition_ids) - if self.max_lag_seconds is not None: - res.max_lag = ProtoDuration() - res.max_lag.FromTimedelta(self.max_lag_seconds) - return res - - @dataclass - class InitResponse(IFromProto): - session_id: str - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.InitResponse, - ) -> "StreamReadMessage.InitResponse": - return StreamReadMessage.InitResponse(session_id=msg.session_id) - - @dataclass - class ReadRequest(IToProto): - bytes_size: int - - def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.ReadRequest: - res = ydb_topic_pb2.StreamReadMessage.ReadRequest() - res.bytes_size = self.bytes_size - return res - - @dataclass - class ReadResponse(IFromProto): - partition_data: List["StreamReadMessage.ReadResponse.PartitionData"] - bytes_size: int - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.ReadResponse, - ) -> "StreamReadMessage.ReadResponse": - partition_data = [] - for proto_partition_data in msg.partition_data: - partition_data.append( - StreamReadMessage.ReadResponse.PartitionData.from_proto( - proto_partition_data - ) - ) - return StreamReadMessage.ReadResponse( - partition_data=partition_data, - bytes_size=msg.bytes_size, - ) - - @dataclass - class MessageData(IFromProto): - offset: int - seq_no: int - created_at: datetime.datetime - data: bytes - uncompresed_size: int - message_group_id: str - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData, - ) -> "StreamReadMessage.ReadResponse.MessageData": - return StreamReadMessage.ReadResponse.MessageData( - offset=msg.offset, - seq_no=msg.seq_no, - created_at=msg.created_at.ToDatetime(), - data=msg.data, - uncompresed_size=msg.uncompressed_size, - message_group_id=msg.message_group_id, - ) - - @dataclass - class 
Batch(IFromProto): - message_data: List["StreamReadMessage.ReadResponse.MessageData"] - producer_id: str - write_session_meta: Dict[str, str] - codec: int - written_at: datetime.datetime - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch, - ) -> "StreamReadMessage.ReadResponse.Batch": - message_data = [] - for message in msg.message_data: - message_data.append( - StreamReadMessage.ReadResponse.MessageData.from_proto(message) - ) - return StreamReadMessage.ReadResponse.Batch( - message_data=message_data, - producer_id=msg.producer_id, - write_session_meta=dict(msg.write_session_meta), - codec=msg.codec, - written_at=msg.written_at.ToDatetime(), - ) - - @dataclass - class PartitionData(IFromProto): - partition_session_id: int - batches: List["StreamReadMessage.ReadResponse.Batch"] - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData, - ) -> "StreamReadMessage.ReadResponse.PartitionData": - batches = [] - for proto_batch in msg.batches: - batches.append( - StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch) - ) - return StreamReadMessage.ReadResponse.PartitionData( - partition_session_id=msg.partition_session_id, - batches=batches, - ) - - @dataclass - class CommitOffsetRequest: - commit_offsets: List["PartitionCommitOffset"] - - @dataclass - class PartitionCommitOffset: - partition_session_id: int - offsets: List[OffsetsRange] - - @dataclass - class CommitOffsetResponse: - partitions_committed_offsets: List["PartitionCommittedOffset"] - - @dataclass - class PartitionCommittedOffset: - partition_session_id: int - committed_offset: int - - @dataclass - class PartitionSessionStatusRequest: - partition_session_id: int - - @dataclass - class PartitionSessionStatusResponse: - partition_session_id: int - partition_offsets: OffsetsRange - committed_offset: int - write_time_high_watermark: float - - @dataclass - class StartPartitionSessionRequest(IFromProto): - partition_session: "StreamReadMessage.PartitionSession" - committed_offset: int - partition_offsets: OffsetsRange - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest, - ) -> "StreamReadMessage.StartPartitionSessionRequest": - return StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession.from_proto( - msg.partition_session - ), - committed_offset=msg.committed_offset, - partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), - ) - - @dataclass - class StartPartitionSessionResponse(IToProto): - partition_session_id: int - read_offset: int - commit_offset: int - - def to_proto( - self, - ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: - res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() - res.partition_session_id = self.partition_session_id - res.read_offset = self.read_offset - res.commit_offset = self.commit_offset - return res - - @dataclass - class StopPartitionSessionRequest: - partition_session_id: int - graceful: bool - committed_offset: int - - @dataclass - class StopPartitionSessionResponse: - partition_session_id: int - - @dataclass - class FromClient(IToProto): - client_message: "ReaderMessagesFromClientToServer" - - def __init__(self, client_message: "ReaderMessagesFromClientToServer"): - self.client_message = client_message - - def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: - res = ydb_topic_pb2.StreamReadMessage.FromClient() - if isinstance(self.client_message, 
StreamReadMessage.ReadRequest): - res.read_request.CopyFrom(self.client_message.to_proto()) - elif isinstance(self.client_message, StreamReadMessage.InitRequest): - res.init_request.CopyFrom(self.client_message.to_proto()) - elif isinstance( - self.client_message, StreamReadMessage.StartPartitionSessionResponse - ): - res.start_partition_session_response.CopyFrom( - self.client_message.to_proto() - ) - else: - raise NotImplementedError() - return res - - @dataclass - class FromServer(IFromProto): - server_message: "ReaderMessagesFromServerToClient" - server_status: ServerStatus - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamReadMessage.FromServer, - ) -> "StreamReadMessage.FromServer": - mess_type = msg.WhichOneof("server_message") - server_status = ServerStatus.from_proto(msg) - if mess_type == "read_response": - return StreamReadMessage.FromServer( - server_status=server_status, - server_message=StreamReadMessage.ReadResponse.from_proto( - msg.read_response - ), - ) - elif mess_type == "init_response": - return StreamReadMessage.FromServer( - server_status=server_status, - server_message=StreamReadMessage.InitResponse.from_proto( - msg.init_response - ), - ) - elif mess_type == "start_partition_session_request": - return StreamReadMessage.FromServer( - server_status=server_status, - server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto( - msg.start_partition_session_request - ), - ) - - # todo replace exception to log - raise NotImplementedError() - - -ReaderMessagesFromClientToServer = Union[ - StreamReadMessage.InitRequest, - StreamReadMessage.ReadRequest, - StreamReadMessage.CommitOffsetRequest, - StreamReadMessage.PartitionSessionStatusRequest, - UpdateTokenRequest, - StreamReadMessage.StartPartitionSessionResponse, - StreamReadMessage.StopPartitionSessionResponse, -] - -ReaderMessagesFromServerToClient = Union[ - StreamReadMessage.InitResponse, - StreamReadMessage.ReadResponse, - StreamReadMessage.CommitOffsetResponse, - StreamReadMessage.PartitionSessionStatusResponse, - UpdateTokenResponse, - StreamReadMessage.StartPartitionSessionRequest, - StreamReadMessage.StopPartitionSessionRequest, -] diff --git a/ydb/_topic_wrapper/writer.py b/ydb/_topic_wrapper/writer.py deleted file mode 100644 index 6710f544..00000000 --- a/ydb/_topic_wrapper/writer.py +++ /dev/null @@ -1,289 +0,0 @@ -import datetime -import enum -import typing -from dataclasses import dataclass, field -from typing import Union - -from google.protobuf.message import Message - -from ydb._topic_wrapper.common import ( - IToProto, - IFromProto, - ServerStatus, - UpdateTokenRequest, - UpdateTokenResponse, - UnknownGrpcMessageError, -) - -# Workaround for good autocomplete in IDE and universal import at runtime -if False: - from ydb._grpc.v4.protos import ydb_topic_pb2 -else: - from ydb._grpc.common.protos import ydb_topic_pb2 - - -class StreamWriteMessage: - @dataclass() - class InitRequest(IToProto): - path: str - producer_id: str - write_session_meta: typing.Dict[str, str] - partitioning: "StreamWriteMessage.PartitioningType" - get_last_seq_no: bool - - def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.InitRequest: - proto = ydb_topic_pb2.StreamWriteMessage.InitRequest() - proto.path = self.path - proto.producer_id = self.producer_id - - if self.partitioning is None: - pass - elif isinstance( - self.partitioning, StreamWriteMessage.PartitioningMessageGroupID - ): - proto.message_group_id = self.partitioning.message_group_id - elif isinstance( - self.partitioning, 
StreamWriteMessage.PartitioningPartitionID - ): - proto.partition_id = self.partitioning.partition_id - else: - raise Exception( - "Bad partitioning type at StreamWriteMessage.InitRequest" - ) - - if self.write_session_meta: - for key in self.write_session_meta: - proto.write_session_meta[key] = self.write_session_meta[key] - - proto.get_last_seq_no = self.get_last_seq_no - return proto - - @dataclass - class InitResponse(IFromProto): - last_seq_no: Union[int, None] - session_id: str - partition_id: int - supported_codecs: typing.List[int] - status: ServerStatus = None - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamWriteMessage.InitResponse, - ) -> "StreamWriteMessage.InitResponse": - codecs = [] # type: typing.List[int] - if msg.supported_codecs: - for codec in msg.supported_codecs.codecs: - codecs.append(codec) - - return StreamWriteMessage.InitResponse( - last_seq_no=msg.last_seq_no, - session_id=msg.session_id, - partition_id=msg.partition_id, - supported_codecs=codecs, - ) - - @dataclass - class WriteRequest(IToProto): - messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] - codec: int - - @dataclass - class MessageData(IToProto): - seq_no: int - created_at: datetime.datetime - data: bytes - uncompressed_size: int - partitioning: "StreamWriteMessage.PartitioningType" - - def to_proto( - self, - ) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: - proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData() - proto.seq_no = self.seq_no - proto.created_at.FromDatetime(self.created_at) - proto.data = self.data - proto.uncompressed_size = self.uncompressed_size - - if self.partitioning is None: - pass - elif isinstance( - self.partitioning, StreamWriteMessage.PartitioningPartitionID - ): - proto.partition_id = self.partitioning.partition_id - elif isinstance( - self.partitioning, StreamWriteMessage.PartitioningMessageGroupID - ): - proto.message_group_id = self.partitioning.message_group_id - else: - raise Exception( - "Bad partition at StreamWriteMessage.WriteRequest.MessageData" - ) - - return proto - - def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: - proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() - proto.codec = self.codec - - for message in self.messages: - proto_mess = proto.messages.add() - proto_mess.CopyFrom(message.to_proto()) - - return proto - - @dataclass - class WriteResponse(IFromProto): - partition_id: int - acks: typing.List["StreamWriteMessage.WriteResponse.WriteAck"] - write_statistics: "StreamWriteMessage.WriteResponse.WriteStatistics" - status: ServerStatus = field(default=None) - - @staticmethod - def from_proto( - msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse, - ) -> "StreamWriteMessage.WriteResponse": - acks = [] - for proto_ack in msg.acks: - ack = StreamWriteMessage.WriteResponse.WriteAck.from_proto(proto_ack) - acks.append(ack) - write_statistics = StreamWriteMessage.WriteResponse.WriteStatistics( - persisting_time=msg.write_statistics.persisting_time.ToTimedelta(), - min_queue_wait_time=msg.write_statistics.min_queue_wait_time.ToTimedelta(), - max_queue_wait_time=msg.write_statistics.max_queue_wait_time.ToTimedelta(), - partition_quota_wait_time=msg.write_statistics.partition_quota_wait_time.ToTimedelta(), - topic_quota_wait_time=msg.write_statistics.topic_quota_wait_time.ToTimedelta(), - ) - return StreamWriteMessage.WriteResponse( - partition_id=msg.partition_id, - acks=acks, - write_statistics=write_statistics, - status=None, - ) - - @dataclass - class 
WriteAck(IFromProto): - seq_no: int - message_write_status: Union[ - "StreamWriteMessage.WriteResponse.WriteAck.StatusWritten", - "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped", - int, - ] - - @classmethod - def from_proto( - cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck - ): - if proto_ack.HasField("written"): - message_write_status = ( - StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( - proto_ack.written.offset - ) - ) - elif proto_ack.HasField("skipped"): - reason = proto_ack.skipped.reason - try: - message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped( - reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code( - reason - ) - ) - except ValueError: - message_write_status = reason - else: - raise NotImplementedError("unexpected ack status") - - return StreamWriteMessage.WriteResponse.WriteAck( - seq_no=proto_ack.seq_no, - message_write_status=message_write_status, - ) - - @dataclass - class StatusWritten: - offset: int - - @dataclass - class StatusSkipped: - reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" - - class Reason(enum.Enum): - UNSPECIFIED = 0 - ALREADY_WRITTEN = 1 - - @classmethod - def from_protobuf_code( - cls, code: int - ) -> Union[ - "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason", - int, - ]: - try: - return StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason( - code - ) - except ValueError: - return code - - @dataclass - class WriteStatistics: - persisting_time: datetime.timedelta - min_queue_wait_time: datetime.timedelta - max_queue_wait_time: datetime.timedelta - partition_quota_wait_time: datetime.timedelta - topic_quota_wait_time: datetime.timedelta - - @dataclass - class PartitioningMessageGroupID: - message_group_id: str - - @dataclass - class PartitioningPartitionID: - partition_id: int - - PartitioningType = Union[PartitioningMessageGroupID, PartitioningPartitionID, None] - - @dataclass - class FromClient(IToProto): - value: "WriterMessagesFromClientToServer" - - def __init__(self, value: "WriterMessagesFromClientToServer"): - self.value = value - - def to_proto(self) -> Message: - res = ydb_topic_pb2.StreamWriteMessage.FromClient() - value = self.value - if isinstance(value, StreamWriteMessage.WriteRequest): - res.write_request.CopyFrom(value.to_proto()) - elif isinstance(value, StreamWriteMessage.InitRequest): - res.init_request.CopyFrom(value.to_proto()) - elif isinstance(value, UpdateTokenRequest): - res.update_token_request.CopyFrom(value.to_proto()) - else: - raise Exception("Unknown outcoming grpc message: %s" % value) - return res - - class FromServer(IFromProto): - @staticmethod - def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.FromServer) -> typing.Any: - message_type = msg.WhichOneof("server_message") - if message_type == "write_response": - res = StreamWriteMessage.WriteResponse.from_proto(msg.write_response) - elif message_type == "init_response": - res = StreamWriteMessage.InitResponse.from_proto(msg.init_response) - elif message_type == "update_token_response": - res = UpdateTokenResponse.from_proto(msg.update_token_response) - else: - # todo log instead of exception - for allow add messages in the future - raise UnknownGrpcMessageError("Unexpected proto message: %s" % msg) - - res.status = ServerStatus(msg.status, msg.issues) - return res - - -WriterMessagesFromClientToServer = Union[ - StreamWriteMessage.InitRequest, StreamWriteMessage.WriteRequest, UpdateTokenRequest -] 
-WriterMessagesFromServerToClient = Union[ - StreamWriteMessage.InitResponse, - StreamWriteMessage.WriteResponse, - UpdateTokenResponse, -] diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index ecc20e10..f3b0b3ab 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -8,8 +8,8 @@ import typing import ydb.aio -from .._topic_wrapper.common import IToProto, Codec -from .._topic_wrapper.writer import StreamWriteMessage +from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage +from .._grpc.grpcwrapper.common_utils import IToProto class Writer: diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index b3999659..a231a6b5 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -22,14 +22,19 @@ check_retriable_error, RetrySettings, ) -from .._topic_wrapper.common import ( +from .._topic_common.common import ( + TokenGetterFuncType, +) +from .._grpc.grpcwrapper.ydb_topic import ( UpdateTokenResponse, - GrpcWrapperAsyncIO, + StreamWriteMessage, + WriterMessagesFromServerToClient, +) +from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, - TokenGetterFuncType, + GrpcWrapperAsyncIO, ) -from .._topic_wrapper.writer import StreamWriteMessage, WriterMessagesFromServerToClient class WriterAsyncIO: diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 1d1ca225..32fe9c02 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -14,7 +14,8 @@ from .. import aio from .. import StatusCode, issues -from .._topic_wrapper.common import ServerStatus, Codec +from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage +from .._grpc.grpcwrapper.common_utils import ServerStatus from .topic_writer import ( InternalMessage, PublicMessage, @@ -24,9 +25,8 @@ PublicWriteResult, TopicWriterError, ) -from .._topic_wrapper.test_helpers import StreamMock +from .._topic_common.test_helpers import StreamMock -from .._topic_wrapper.writer import StreamWriteMessage from .topic_writer_asyncio import ( WriterAsyncIOStream, WriterAsyncIOReconnector, diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 9c39e5e6..d8c66213 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -5,7 +5,7 @@ import threading from typing import Union, List, Optional, Coroutine -from .._topic_wrapper.common import SupportedDriverType +from .._grpc.grpcwrapper.common_utils import SupportedDriverType from .topic_writer import ( PublicWriterSettings, TopicWriterError, From dbc22a16daeed986e920b5feb6584ae02bd1b147 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 8 Feb 2023 13:45:20 +0300 Subject: [PATCH 038/147] add topic control plane async --- CHANGELOG.md | 2 + tests/conftest.py | 35 +- tests/topics/test_control_plane.py | 40 ++ ydb/_apis.py | 3 + ydb/_grpc/grpcwrapper/common_utils.py | 86 +++- ydb/_grpc/grpcwrapper/ydb_scheme.py | 36 ++ ydb/_grpc/grpcwrapper/ydb_topic.py | 428 +++++++++++++++++- .../grpcwrapper/ydb_topic_public_types.py | 154 +++++++ ydb/_topic_common/common.py | 29 ++ ydb/_topic_common/control_plane.py | 13 - ydb/_utilities.py | 1 + ydb/topic.py | 87 +++- 12 files changed, 853 insertions(+), 61 deletions(-) create mode 100644 tests/topics/test_control_plane.py create mode 100644 
ydb/_grpc/grpcwrapper/ydb_scheme.py
 create mode 100644 ydb/_grpc/grpcwrapper/ydb_topic_public_types.py
 delete mode 100644 ydb/_topic_common/control_plane.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c0b52c2b..61f06737 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,5 @@
+* Add control plane operations for topic api: create, describe, drop
+
 ## 3.0.1b4 ##
 * Initial implementation of topic reader
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 17d2801a..c4a62526 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,7 +4,7 @@
 import pytest
 import ydb
 import time
-import subprocess
+from ydb import issues
 
 
 @pytest.fixture(autouse=True, scope="session")
@@ -105,30 +105,21 @@ def topic_consumer():
 
 
 @pytest.fixture()
-def topic_path(endpoint, topic_consumer) -> str:
-    subprocess.run(
-        """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic drop /local/test-topic"""
-        % endpoint,
-        shell=True,
-        capture_output=True,
-    )
-    res = subprocess.run(
-        """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic create /local/test-topic"""
-        % endpoint,
-        shell=True,
-        capture_output=True,
-    )
-    assert res.returncode == 0, res.stderr + res.stdout
+@pytest.mark.asyncio()
+async def topic_path(driver, topic_consumer, database) -> str:
+    topic_path = database + "/test-topic"
+
+    try:
+        await driver.topic_client.drop_topic(topic_path)
+    except issues.SchemeError:
+        pass
 
-    res = subprocess.run(
-        """docker-compose exec -T ydb /ydb -e grpc://%s -d /local topic consumer add --consumer %s /local/test-topic"""
-        % (endpoint, topic_consumer),
-        shell=True,
-        capture_output=True,
+    await driver.topic_client.create_topic(
+        path=topic_path,
+        consumers=[topic_consumer],
     )
-    assert res.returncode == 0, res.stderr + res.stdout
-
-    return "/local/test-topic"
+    return topic_path
 
 
 @pytest.fixture()
diff --git a/tests/topics/test_control_plane.py b/tests/topics/test_control_plane.py
new file mode 100644
index 00000000..615bd271
--- /dev/null
+++ b/tests/topics/test_control_plane.py
@@ -0,0 +1,40 @@
+import os.path
+
+import pytest
+
+from ydb import issues
+
+
+@pytest.mark.asyncio
+class TestTopicClientControlPlaneAsyncIO:
+    async def test_create_topic(self, driver, database):
+        client = driver.topic_client
+
+        topic_path = database + "/my-test-topic"
+
+        await client.create_topic(topic_path)
+
+        with pytest.raises(issues.SchemeError):
+            # double create is ok, so force an error by creating at a bad path
+            await client.create_topic(database)
+
+    async def test_drop_topic(self, driver, topic_path):
+        client = driver.topic_client
+
+        await client.drop_topic(topic_path)
+
+        with pytest.raises(issues.SchemeError):
+            await client.drop_topic(topic_path)
+
+    async def test_describe_topic(self, driver, topic_path: str, topic_consumer):
+        res = await driver.topic_client.describe(topic_path)
+
+        assert res.self.name == os.path.basename(topic_path)
+
+        has_consumer = False
+        for consumer in res.consumers:
+            if consumer.name == topic_consumer:
+                has_consumer = True
+                break
+
+        assert has_consumer
diff --git a/ydb/_apis.py b/ydb/_apis.py
index 342c485f..557871f8 100644
--- a/ydb/_apis.py
+++ b/ydb/_apis.py
@@ -103,5 +103,8 @@ class TableService(object):
 
 class TopicService(object):
     Stub = ydb_topic_v1_pb2_grpc.TopicServiceStub
+    CreateTopic = "CreateTopic"
+    DescribeTopic = "DescribeTopic"
+    DropTopic = "DropTopic"
     StreamRead = "StreamRead"
     StreamWrite = "StreamWrite"
diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py
index f4c0f6e2..3af5f363 100644
---
a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -1,10 +1,25 @@ +from __future__ import annotations + import abc import asyncio +import datetime import typing +from typing import ( + Optional, + Any, + Iterator, + AsyncIterator, + Callable, + Iterable, + Union, + Coroutine, +) from dataclasses import dataclass import grpc from google.protobuf.message import Message +from google.protobuf.duration_pb2 import Duration as ProtoDuration +from google.protobuf.timestamp_pb2 import Timestamp as ProtoTimeStamp import ydb.aio @@ -21,14 +36,35 @@ class IFromProto(abc.ABC): @staticmethod @abc.abstractmethod - def from_proto(msg: Message) -> typing.Any: - pass + def from_proto(msg: Message) -> Any: + ... + + +class IFromProtoWithProtoType(IFromProto): + @staticmethod + @abc.abstractmethod + def empty_proto_message() -> Message: + ... class IToProto(abc.ABC): @abc.abstractmethod def to_proto(self) -> Message: - pass + ... + + +class IFromPublic(abc.ABC): + + @staticmethod + @abc.abstractmethod + def from_public(o: typing.Any) -> typing.Any: + ... + + +class IToPublic(abc.ABC): + @abc.abstractmethod + def to_public(self) -> typing.Any: + ... class UnknownGrpcMessageError(issues.Error): @@ -76,7 +112,7 @@ def __next__(self): class SyncIteratorToAsyncIterator: - def __init__(self, sync_iterator: typing.Iterator): + def __init__(self, sync_iterator: Iterator): self._sync_iterator = sync_iterator def __aiter__(self): @@ -92,7 +128,7 @@ async def __anext__(self): class IGrpcWrapperAsyncIO(abc.ABC): @abc.abstractmethod - async def receive(self) -> typing.Any: + async def receive(self) -> Any: ... @abc.abstractmethod @@ -100,13 +136,13 @@ def write(self, wrap_message: IToProto): ... -SupportedDriverType = typing.Union[ydb.Driver, ydb.aio.Driver] +SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): from_client_grpc: asyncio.Queue - from_server_grpc: typing.AsyncIterator - convert_server_grpc_to_wrapper: typing.Callable[[typing.Any], typing.Any] + from_server_grpc: AsyncIterator + convert_server_grpc_to_wrapper: Callable[[Any], Any] _connection_state: str def __init__(self, convert_server_grpc_to_wrapper): @@ -140,7 +176,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): ) self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__()) - async def receive(self) -> typing.Any: + async def receive(self) -> Any: # todo handle grpc exceptions and convert it to internal exceptions try: grpc_message = await self.from_server_grpc.__anext__() @@ -168,7 +204,7 @@ class ServerStatus(IFromProto): def __init__( self, status: issues.StatusCode, - issues: typing.Iterable[typing.Any], + issues: Iterable[Any], ): self.status = status self.issues = issues @@ -178,7 +214,7 @@ def __str__(self): @staticmethod def from_proto( - msg: typing.Union[ + msg: Union[ ydb_topic_pb2.StreamReadMessage.FromServer, ydb_topic_pb2.StreamWriteMessage.FromServer, ] @@ -198,7 +234,7 @@ def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): def callback_from_asyncio( - callback: typing.Union[typing.Callable, typing.Coroutine] + callback: Union[Callable, Coroutine] ) -> [asyncio.Future, asyncio.Task]: loop = asyncio.get_running_loop() @@ -206,3 +242,29 @@ def callback_from_asyncio( return loop.create_task(callback()) else: return loop.run_in_executor(None, callback) + + +def proto_duration_from_timedelta(t: Optional[datetime.timedelta]) -> ProtoDuration: + if t is None: + return None + res = ProtoDuration() + 
res.FromTimedelta(t)
+    return res
+
+
+def proto_timestamp_from_datetime(t: Optional[datetime.datetime]) -> ProtoTimeStamp:
+    if t is None:
+        return None
+
+    res = ProtoTimeStamp()
+    res.FromDatetime(t)
+    return res
+
+
+def datetime_from_proto_timestamp(
+    ts: Optional[ProtoTimeStamp],
+) -> Optional[datetime.datetime]:
+    if ts is None:
+        return None
+    return ts.ToDatetime()
+
+
+def timedelta_from_proto_duration(
+    d: Optional[ProtoDuration],
+) -> Optional[datetime.timedelta]:
+    if d is None:
+        return None
+    return d.ToTimedelta()
diff --git a/ydb/_grpc/grpcwrapper/ydb_scheme.py b/ydb/_grpc/grpcwrapper/ydb_scheme.py
new file mode 100644
index 00000000..b9922035
--- /dev/null
+++ b/ydb/_grpc/grpcwrapper/ydb_scheme.py
@@ -0,0 +1,36 @@
+import datetime
+import enum
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Entry:
+    name: str
+    owner: str
+    type: "Entry.Type"
+    effective_permissions: "Permissions"
+    permissions: "Permissions"
+    size_bytes: int
+    created_at: datetime.datetime
+
+    class Type(enum.IntEnum):
+        UNSPECIFIED = 0
+        DIRECTORY = 1
+        TABLE = 2
+        PERS_QUEUE_GROUP = 3
+        DATABASE = 4
+        RTMR_VOLUME = 5
+        BLOCK_STORE_VOLUME = 6
+        COORDINATION_NODE = 7
+        COLUMN_STORE = 12
+        COLUMN_TABLE = 13
+        SEQUENCE = 15
+        REPLICATION = 16
+        TOPIC = 17
+
+
+@dataclass
+class Permissions:
+    subject: str
+    permission_names: List[str]
diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py
index df43b803..888363fe 100644
--- a/ydb/_grpc/grpcwrapper/ydb_topic.py
+++ b/ydb/_grpc/grpcwrapper/ydb_topic.py
@@ -3,28 +3,69 @@
 import typing
 from dataclasses import dataclass, field
 from enum import IntEnum
-from typing import List, Union, Dict
+from typing import List, Union, Dict, Optional
 
-from google.protobuf.duration_pb2 import Duration as ProtoDuration
 from google.protobuf.message import Message
 
+from . import ydb_topic_public_types
import scheme + # Workaround for good IDE and universal for runtime # noinspection PyUnreachableCode if False: - from ..v4.protos import ydb_topic_pb2 + from ..v4.protos import ydb_scheme_pb2, ydb_topic_pb2 else: - from ..common.protos import ydb_topic_pb2 - -from .common_utils import IFromProto, IToProto, ServerStatus, UnknownGrpcMessageError - - -class Codec(IntEnum): + from ..common.protos import ydb_scheme_pb2, ydb_topic_pb2 + +from .common_utils import ( + IFromProto, + IFromProtoWithProtoType, + IToProto, + IToPublic, + IFromPublic, + ServerStatus, + UnknownGrpcMessageError, + proto_duration_from_timedelta, + proto_timestamp_from_datetime, datetime_from_proto_timestamp, timedelta_from_proto_duration, +) + + +class Codec(int, IToPublic): CODEC_UNSPECIFIED = 0 CODEC_RAW = 1 CODEC_GZIP = 2 CODEC_LZOP = 3 CODEC_ZSTD = 4 + @staticmethod + def from_proto_iterable(codecs: typing.Iterable[int]) -> List["Codec"]: + return [Codec(int(codec)) for codec in codecs] + + def to_public(self) -> ydb_topic_public_types.PublicCodec: + return ydb_topic_public_types.PublicCodec(int(self)) + + +@dataclass +class SupportedCodecs(IToProto, IFromProto, IToPublic): + codecs: List[Codec] + + def to_proto(self) -> ydb_topic_pb2.SupportedCodecs: + return ydb_topic_pb2.SupportedCodecs( + codecs=self.codecs, + ) + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.SupportedCodecs]) -> "SupportedCodecs": + if msg is None: + return SupportedCodecs(codecs=[]) + + return SupportedCodecs( + codecs=Codec.from_proto_iterable(msg.codecs), + ) + + def to_public(self) -> List[ydb_topic_public_types.PublicCodec]: + return list(map(Codec.to_public, self.codecs)) + @dataclass class OffsetsRange(IFromProto): @@ -175,7 +216,7 @@ class WriteResponse(IFromProto): partition_id: int acks: typing.List["StreamWriteMessage.WriteResponse.WriteAck"] write_statistics: "StreamWriteMessage.WriteResponse.WriteStatistics" - status: ServerStatus = field(default=None) + status: Optional[ServerStatus] = field(default=None) @staticmethod def from_proto( @@ -376,8 +417,7 @@ def to_proto( res.path = self.path res.partition_ids.extend(self.partition_ids) if self.max_lag_seconds is not None: - res.max_lag = ProtoDuration() - res.max_lag.FromTimedelta(self.max_lag_seconds) + res.max_lag = proto_duration_from_timedelta(self.max_lag_seconds) return res @dataclass @@ -637,3 +677,367 @@ def from_proto( StreamReadMessage.StartPartitionSessionRequest, StreamReadMessage.StopPartitionSessionRequest, ] + + +@dataclass +class MultipleWindowsStat(IFromProto, IToPublic): + per_minute: int + per_hour: int + per_day: int + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.MultipleWindowsStat]) -> Optional["MultipleWindowsStat"]: + if msg is None: + return None + return MultipleWindowsStat( + per_minute=msg.per_minute, + per_hour=msg.per_hour, + per_day=msg.per_day, + ) + + def to_public(self) -> ydb_topic_public_types.PublicMultipleWindowsStat: + return ydb_topic_public_types.PublicMultipleWindowsStat( + per_minute=self.per_minute, + per_hour=self.per_hour, + per_day=self.per_day, + ) + +@dataclass +class Consumer(IToProto, IFromProto, IFromPublic, IToPublic): + name: str + important: bool + read_from: typing.Optional[datetime.datetime] + supported_codecs: SupportedCodecs + attributes: Dict[str, str] + consumer_stats: typing.Optional["Consumer.ConsumerStats"] + + def to_proto(self) -> ydb_topic_pb2.Consumer: + return ydb_topic_pb2.Consumer( + name=self.name, + important=self.important, + read_from=proto_timestamp_from_datetime(self.read_from), 
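+            # read_from may be None: proto_timestamp_from_datetime returns
+            # None in that case, which protobuf leaves as an unset field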
+            supported_codecs=self.supported_codecs.to_proto(),
+            attributes=self.attributes,
+            # consumer_stats - readonly field
+        )
+
+    @staticmethod
+    def from_proto(msg: Optional[ydb_topic_pb2.Consumer]) -> Optional["Consumer"]:
+        return Consumer(
+            name=msg.name,
+            important=msg.important,
+            read_from=datetime_from_proto_timestamp(msg.read_from),
+            supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs),
+            attributes=dict(msg.attributes),
+            consumer_stats=Consumer.ConsumerStats.from_proto(msg.consumer_stats),
+        )
+
+    @staticmethod
+    def from_public(consumer: ydb_topic_public_types.PublicConsumer):
+        if consumer is None:
+            return None
+
+        supported_codecs = []
+        if consumer.supported_codecs is not None:
+            supported_codecs = consumer.supported_codecs
+
+        return Consumer(
+            name=consumer.name,
+            important=consumer.important,
+            read_from=consumer.read_from,
+            supported_codecs=SupportedCodecs(
+                codecs=supported_codecs
+            ),
+            attributes=consumer.attributes,
+            consumer_stats=None,
+        )
+
+    def to_public(self) -> ydb_topic_public_types.PublicConsumer:
+        return ydb_topic_public_types.PublicConsumer(
+            name=self.name,
+            important=self.important,
+            read_from=self.read_from,
+            supported_codecs=self.supported_codecs.to_public(),
+            attributes=self.attributes,
+        )
+
+    @dataclass
+    class ConsumerStats(IFromProto):
+        min_partitions_last_read_time: datetime.datetime
+        max_read_time_lag: datetime.timedelta
+        max_write_time_lag: datetime.timedelta
+        bytes_read: MultipleWindowsStat
+
+        @staticmethod
+        def from_proto(msg: ydb_topic_pb2.Consumer.ConsumerStats) -> "Consumer.ConsumerStats":
+            return Consumer.ConsumerStats(
+                min_partitions_last_read_time=datetime_from_proto_timestamp(msg.min_partitions_last_read_time),
+                max_read_time_lag=timedelta_from_proto_duration(msg.max_read_time_lag),
+                max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag),
+                bytes_read=MultipleWindowsStat.from_proto(msg.bytes_read),
+            )
+
+
+@dataclass
+class PartitioningSettings(IToProto, IFromProto):
+    min_active_partitions: int
+    partition_count_limit: int
+
+    @staticmethod
+    def from_proto(msg: ydb_topic_pb2.PartitioningSettings) -> "PartitioningSettings":
+        return PartitioningSettings(
+            min_active_partitions=msg.min_active_partitions,
+            partition_count_limit=msg.partition_count_limit,
+        )
+
+    def to_proto(self) -> ydb_topic_pb2.PartitioningSettings:
+        return ydb_topic_pb2.PartitioningSettings(
+            min_active_partitions=self.min_active_partitions,
+            partition_count_limit=self.partition_count_limit,
+        )
+
+
+class MeteringMode(int, IFromProto, IFromPublic, IToPublic):
+    UNSPECIFIED = 0
+    RESERVED_CAPACITY = 1
+    REQUEST_UNITS = 2
+
+    @staticmethod
+    def from_public(m: Optional[ydb_topic_public_types.PublicMeteringMode]) -> Optional["MeteringMode"]:
+        if m is None:
+            return None
+
+        return MeteringMode(m)
+
+    @staticmethod
+    def from_proto(code: Optional[int]) -> Optional["MeteringMode"]:
+        if code is None:
+            return None
+
+        return MeteringMode(code)
+
+    def to_public(self) -> ydb_topic_public_types.PublicMeteringMode:
+        try:
+            return ydb_topic_public_types.PublicMeteringMode(int(self))
+        except ValueError:  # an unknown enum value raises ValueError
+            return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED
+
+
+@dataclass
+class CreateTopicRequest(IToProto, IFromPublic):
+    path: str
+    partitioning_settings: "PartitioningSettings"
+    retention_period: typing.Optional[datetime.timedelta]
+    retention_storage_mb: typing.Optional[int]
+    supported_codecs: "SupportedCodecs"
+    partition_write_speed_bytes_per_second: typing.Optional[int]
+    partition_write_burst_bytes: 
typing.Optional[int] + attributes: Dict[str, str] + consumers: List["Consumer"] + metering_mode: "MeteringMode" + + def to_proto(self) -> ydb_topic_pb2.CreateTopicRequest: + return ydb_topic_pb2.CreateTopicRequest( + path=self.path, + partitioning_settings=self.partitioning_settings.to_proto(), + retention_period=proto_duration_from_timedelta(self.retention_period), + retention_storage_mb=self.retention_storage_mb, + supported_codecs=self.supported_codecs.to_proto(), + partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=self.partition_write_burst_bytes, + attributes=self.attributes, + consumers=[consumer.to_proto() for consumer in self.consumers], + metering_mode=self.metering_mode, + ) + + @staticmethod + def from_public(req: ydb_topic_public_types.CreateTopicRequestParams): + supported_codecs = [] + + if req.supported_codecs is not None: + supported_codecs = req.supported_codecs + + consumers = [] + if req.consumers is not None: + for consumer in req.consumers: + if isinstance(consumer, str): + consumer = ydb_topic_public_types.PublicConsumer(name=consumer) + consumers.append(Consumer.from_public(consumer)) + + return CreateTopicRequest( + path=req.path, + partitioning_settings=PartitioningSettings( + min_active_partitions=req.min_active_partitions, + partition_count_limit=req.partition_count_limit, + ), + retention_period=req.retention_period, + retention_storage_mb=req.retention_storage_mb, + supported_codecs=SupportedCodecs( + codecs=supported_codecs, + ), + partition_write_speed_bytes_per_second = req.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=req.partition_write_burst_bytes, + attributes=req.attributes, + consumers=consumers, + metering_mode=MeteringMode.from_public(req.metering_mode), + ) + + + +@dataclass +class CreateTopicResult: + pass + + +@dataclass +class DescribeTopicRequest: + path: str + include_stats: bool + + +@dataclass +class DescribeTopicResult(IFromProtoWithProtoType, IToPublic): + self_proto: ydb_scheme_pb2.Entry + partitioning_settings: PartitioningSettings + partitions: List["DescribeTopicResult.PartitionInfo"] + retention_period: datetime.timedelta + retention_storage_mb: int + supported_codecs: SupportedCodecs + partition_write_speed_bytes_per_second: int + partition_write_burst_bytes: int + attributes: Dict[str, str] + consumers: List["Consumer"] + metering_mode: MeteringMode + topic_stats: "DescribeTopicResult.TopicStats" + + @staticmethod + def from_proto(msg: ydb_topic_pb2.DescribeTopicResult) -> "DescribeTopicResult": + return DescribeTopicResult( + self_proto=msg.self, + partitioning_settings=PartitioningSettings.from_proto(msg.partitioning_settings), + partitions=list(map(DescribeTopicResult.PartitionInfo.from_proto, msg.partitions)), + retention_period=msg.retention_period, + retention_storage_mb=msg.retention_storage_mb, + supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs), + partition_write_speed_bytes_per_second=msg.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=msg.partition_write_burst_bytes, + attributes=dict(msg.attributes), + consumers=list(map(Consumer.from_proto, msg.consumers)), + metering_mode=MeteringMode.from_proto(msg.metering_mode), + topic_stats=DescribeTopicResult.TopicStats.from_proto(msg.topic_stats), + ) + + @staticmethod + def empty_proto_message() -> ydb_topic_pb2.DescribeTopicResult: + return ydb_topic_pb2.DescribeTopicResult() + + def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult: + 
return ydb_topic_public_types.PublicDescribeTopicResult( + self=scheme._wrap_scheme_entry(self.self_proto), + min_active_partitions=self.partitioning_settings.min_active_partitions, + partition_count_limit=self.partitioning_settings.partition_count_limit, + partitions=list(map(DescribeTopicResult.PartitionInfo.to_public, self.partitions)), + retention_period=self.retention_period, + retention_storage_mb=self.retention_storage_mb, + supported_codecs=self.supported_codecs.to_public(), + partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=self.partition_write_burst_bytes, + attributes=self.attributes, + consumers=list(map(Consumer.to_public, self.consumers)), + metering_mode=self.metering_mode.to_public(), + topic_stats=self.topic_stats.to_public(), + ) + + @dataclass + class PartitionInfo(IFromProto, IToPublic): + partition_id: int + active: bool + child_partition_ids: List[int] + parent_partition_ids: List[int] + partition_stats: "PartitionStats" + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.DescribeTopicResult.PartitionInfo]) -> Optional["DescribeTopicResult.PartitionInfo"]: + if msg is None: + return None + + return DescribeTopicResult.PartitionInfo( + partition_id=msg.partition_id, + active=msg.active, + child_partition_ids=list(msg.child_partition_ids), + parent_partition_ids=list(msg.parent_partition_ids), + partition_stats=PartitionStats.from_proto(msg.partition_stats) + ) + + def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo: + partition_stats = None + if self.partition_stats is not None: + partition_stats = self.partition_stats.to_public() + return ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo( + partition_id=self.partition_id, + active=self.active, + child_partition_ids=self.child_partition_ids, + parent_partition_ids=self.parent_partition_ids, + partition_stats=partition_stats, + ) + + @dataclass + class TopicStats(IFromProto, IToPublic): + store_size_bytes: int + min_last_write_time: datetime.datetime + max_write_time_lag: datetime.timedelta + bytes_written: "MultipleWindowsStat" + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.DescribeTopicResult.TopicStats]) -> Optional["DescribeTopicResult.TopicStats"]: + if msg is None: + return None + + return DescribeTopicResult.TopicStats( + store_size_bytes=msg.store_size_bytes, + min_last_write_time=datetime_from_proto_timestamp(msg.min_last_write_time), + max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag), + bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written), + ) + + def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult.TopicStats: + return ydb_topic_public_types.PublicDescribeTopicResult.TopicStats( + store_size_bytes=self.store_size_bytes, + min_last_write_time=self.min_last_write_time, + max_write_time_lag=self.max_write_time_lag, + bytes_written=self.bytes_written.to_public(), + ) + + +@dataclass +class PartitionStats(IFromProto, IToPublic): + partition_offsets: OffsetsRange + store_size_bytes: int + last_write_time: datetime.datetime + max_write_time_lag: datetime.timedelta + bytes_written: "MultipleWindowsStat" + partition_node_id: int + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.PartitionStats]) -> Optional["PartitionStats"]: + if msg is None: + return None + return PartitionStats( + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), + store_size_bytes=msg.store_size_bytes, + 
last_write_time=datetime_from_proto_timestamp(msg.last_write_time), + max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag), + bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written), + partition_node_id=msg.partition_node_id, + ) + + def to_public(self) -> ydb_topic_public_types.PublicPartitionStats: + return ydb_topic_public_types.PublicPartitionStats( + partition_start=self.partition_offsets.start, + partition_end=self.partition_offsets.end, + store_size_bytes=self.store_size_bytes, + last_write_time=self.last_write_time, + max_write_time_lag=self.max_write_time_lag, + bytes_written=self.bytes_written.to_public(), + partition_node_id=self.partition_node_id, + ) diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py new file mode 100644 index 00000000..21ae0ed7 --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py @@ -0,0 +1,154 @@ +import datetime +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Optional, List, Union, Dict + +# Workaround for good IDE and universal for runtime +# noinspection PyUnreachableCode +if False: + from ..v4.protos import ydb_topic_pb2 +else: + from ..common.protos import ydb_topic_pb2 + +from .common_utils import IToProto, IFromProto, proto_timestamp_from_datetime +from ...scheme import SchemeEntry, _wrap_scheme_entry + + +@dataclass +# need similar struct to PublicDescribeTopicResult +class CreateTopicRequestParams: + path: str + min_active_partitions: Optional[int] + partition_count_limit: Optional[int] + retention_period: Optional[datetime.timedelta] + retention_storage_mb: Optional[int] + supported_codecs: Optional[List[Union["PublicCodec", int]]] + partition_write_speed_bytes_per_second: Optional[int] + partition_write_burst_bytes: Optional[int] + attributes: Optional[Dict[str, str]] + consumers: Optional[List[Union["PublicConsumer", str]]] + metering_mode: Optional["PublicMeteringMode"] + + +class PublicCodec(int): + UNSPECIFIED = 0 + RAW = 1 + GZIP = 2 + LZOP = 3 + ZSTD = 4 + + +class PublicMeteringMode(IntEnum): + UNSPECIFIED = 0 + RESERVED_CAPACITY = 1 + REQUEST_UNITS = 2 + + +@dataclass +class PublicConsumer: + name: str + important: bool = False + read_from: Optional[datetime.datetime] = None + supported_codecs: List[Union[PublicCodec, int]] = field( + default_factory=lambda: list() + ) + attributes: Dict[str, str] = field(default_factory=lambda: dict()) + + +def consumers_to_proto( + consumers: Optional[List[Union[PublicConsumer, str]]] +) -> List[ydb_topic_pb2.Consumer]: + res = [] + if not consumers: + return res + + for consumer in consumers: + if isinstance(consumer, str): + consumer = PublicConsumer(name=consumer) + res.append(consumer_to_proto(consumer)) + + return res + + +def consumer_to_proto(consumer: PublicConsumer) -> ydb_topic_pb2.Consumer: + return ydb_topic_pb2.Consumer( + name=consumer.name, + important=consumer.important, + read_from=proto_timestamp_from_datetime(consumer.read_from), + supported_codecs=ydb_topic_pb2.SupportedCodecs( + codecs=consumer.supported_codecs, + ), + attributes=consumer.attributes, + ) + + +@dataclass +class DropTopicRequestParams(IToProto): + path: str + + def to_proto(self) -> ydb_topic_pb2.DropTopicRequest: + return ydb_topic_pb2.DropTopicRequest(path=self.path) + + +@dataclass +class DescribeTopicRequestParams(IToProto): + path: str + include_stats: bool + + def to_proto(self) -> ydb_topic_pb2.DescribeTopicRequest: + return ydb_topic_pb2.DescribeTopicRequest( + 
path=self.path, + include_stats=self.include_stats + ) + + +@dataclass +# Need similar struct to CreateTopicRequestParams +class PublicDescribeTopicResult: + self: SchemeEntry + min_active_partitions: int # Minimum partition count auto merge would stop working at + partition_count_limit: int # Limit for total partition count, including active (open for write) and read-only partitions. + partitions: List["PublicDescribeTopicResult.PartitionInfo"] # Partitions description + + retention_period: datetime.timedelta # How long data in partition should be stored + retention_storage_mb: int # How much data in partition should be stored. Zero value means infinite limit. + supported_codecs: List[PublicCodec] # List of allowed codecs for writers. + partition_write_speed_bytes_per_second: int # Partition write speed in bytes per second + partition_write_burst_bytes: int # Burst size for write in partition, in bytes + attributes: Dict[str, str] # User and server attributes of topic. Server attributes starts from "_" and will be validated by server. + consumers: List[PublicConsumer] # List of consumers for this topic + metering_mode: PublicMeteringMode # Metering settings + topic_stats: "PublicDescribeTopicResult.TopicStats" # Statistics of topic + + @dataclass + class PartitionInfo: + partition_id: int # Partition identifier + active: bool # Is partition open for write + child_partition_ids: List[int] # Ids of partitions which was formed when this partition was split or merged + parent_partition_ids: List[int] # Ids of partitions from which this partition was formed by split or merge + partition_stats: Optional["PublicPartitionStats"] # Stats for partition, filled only when include_stats in request is true + + @dataclass + class TopicStats: + store_size_bytes: int # Approximate size of topic + min_last_write_time: datetime.datetime # Minimum of timestamps of last write among all partitions. + max_write_time_lag: datetime.timedelta # Maximum of differences between write timestamp and create timestamp for all messages, written during last minute. + bytes_written: "PublicMultipleWindowsStat" # How much bytes were written statistics. + + +@dataclass +class PublicPartitionStats: + partition_start: int # first message offset in the partition + partition_end: int # last+1 message offset in the partition + store_size_bytes: int # Approximate size of partition + last_write_time: datetime.datetime # Timestamp of last write + max_write_time_lag: datetime.timedelta # Maximum of differences between write timestamp and create timestamp for all messages, written during last minute. + bytes_written: "PublicMultipleWindowsStat" # How much bytes were written during several windows in this partition. + partition_node_id: int # Host where tablet for this partition works. Useful for debugging purposes. + + +@dataclass +class PublicMultipleWindowsStat: + per_minute: int + per_hour: int + per_day: int diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index c92d9737..bef0320e 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -1,3 +1,32 @@ import typing +from .. 
import operation, issues +from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType + +# Workaround for good IDE and universal for runtime +# noinspection PyUnreachableCode +if typing.TYPE_CHECKING: + from .._grpc.v4.protos import ydb_topic_pb2, ydb_operation_pb2 +else: + from .._grpc.common.protos import ydb_topic_pb2, ydb_operation_pb2 + + TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] + + +def wrap_operation(rpc_state, response_pb, driver=None): + return operation.Operation(rpc_state, response_pb, driver) + + +ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType) + + +def create_result_wrapper(result_type: typing.Type[ResultType]) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]: + def wrapper(rpc_state, response_pb, driver=None): + issues._process_response(response_pb.operation) + msg = result_type.empty_proto_message() + response_pb.operation.result.Unpack(msg) + return result_type.from_proto(msg) + + return wrapper + diff --git a/ydb/_topic_common/control_plane.py b/ydb/_topic_common/control_plane.py deleted file mode 100644 index 052e8aeb..00000000 --- a/ydb/_topic_common/control_plane.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass -from typing import Union, List - - -@dataclass -class CreateTopicRequest: - path: str - consumers: Union[List["Consumer"], None] = None - - -@dataclass -class Consumer: - name: str diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 32419b1b..e2a9f98f 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -5,6 +5,7 @@ import functools import hashlib import collections +import warnings from . import ydb_version try: diff --git a/ydb/topic.py b/ydb/topic.py index 1b2722ca..a51ac082 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,6 +1,19 @@ -from typing import List, Callable, Union, Mapping, Any +import datetime +import warnings +from typing import List, Callable, Union, Mapping, Any, Optional, Dict + +from . import aio, Credentials, _apis + +from . import scheme + +from ._grpc.grpcwrapper.ydb_topic_public_types import ( + DropTopicRequestParams as _DropTopicRequestParams, + PublicCodec as TopicCodec, + PublicConsumer as TopicConsumer, + PublicMeteringMode as TopicMeteringMode, + DescribeTopicRequestParams as _DescribeTopicRequestParams, +) -from . import aio, Credentials from ._topic_reader.topic_reader import ( PublicReaderSettings as TopicReaderSettings, Reader as TopicReader, @@ -21,9 +34,16 @@ RetryPolicy as TopicWriterRetryPolicy, ) +from ._topic_common.common import ( + wrap_operation as _wrap_operation, + create_result_wrapper as _create_result_wrapper, +) from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO +from ._grpc.grpcwrapper import ydb_topic as _ydb_topic +from ._grpc.grpcwrapper import ydb_topic_public_types as _ydb_topic_public_types + class TopicClientAsyncIO: _driver: aio.Driver @@ -32,6 +52,69 @@ class TopicClientAsyncIO: def __init__(self, driver: aio.Driver, settings: "TopicClientSettings" = None): self._driver = driver + async def create_topic( + self, + path: str, + min_active_partitions: Optional[ + int + ] = None, # Minimum partition count auto merge would stop working at. + partition_count_limit: Optional[ + int + ] = None, # Limit for total partition count, including active (open for write) and read-only partitions. 
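+        # NOTE: argument names must match CreateTopicRequestParams fields,
+        # because they are forwarded via locals() in the method body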
+ retention_period: Optional[ + datetime.timedelta + ] = None, # How long data in partition should be stored + retention_storage_mb: Optional[ + int + ] = None, # How much data in partition should be stored + # List of allowed codecs for writers. + # Writes with codec not from this list are forbidden. + supported_codecs: Optional[List[Union[TopicCodec, int]]] = None, + partition_write_speed_bytes_per_second: Optional[ + int + ] = None, # Partition write speed in bytes per second + partition_write_burst_bytes: Optional[ + int + ] = None, # Burst size for write in partition, in bytes + # User and server attributes of topic. Server attributes starts from "_" and will be validated by server. + attributes: Optional[Dict[str, str]] = None, + # List of consumers for this topic + consumers: Optional[List[Union[TopicConsumer, str]]] = None, + # Metering mode for the topic in a serverless database + metering_mode: Optional[TopicMeteringMode] = None, + ): + args = locals().copy() + del args["self"] + req = _ydb_topic_public_types.CreateTopicRequestParams(**args) + req = _ydb_topic.CreateTopicRequest.from_public(req) + await self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.CreateTopic, + _wrap_operation, + ) + + async def describe(self, path: str, include_stats: bool = False): + args = locals().copy() + del args["self"] + req = _DescribeTopicRequestParams(**args) + res = await self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DescribeTopic, + _create_result_wrapper(_ydb_topic.DescribeTopicResult), + ) # type: _ydb_topic.DescribeTopicResult + return res.to_public() + + async def drop_topic(self, path: str): + req = _DropTopicRequestParams(path=path) + await self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DropTopic, + _wrap_operation, + ) + def topic_reader( self, consumer: str, From 16f347cbb02ff8dd6f111023b7a3791758451733 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 13 Feb 2023 15:03:33 +0300 Subject: [PATCH 039/147] add sync client --- tests/conftest.py | 17 +++++++ tests/topics/test_control_plane.py | 34 +++++++++++++ ydb/aio/driver.py | 4 +- ydb/driver.py | 3 ++ ydb/topic.py | 81 ++++++++++++++++++++++++++++-- 5 files changed, 134 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c4a62526..183ad3a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,6 +99,23 @@ async def driver(endpoint, database, event_loop): await driver.stop(timeout=10) +@pytest.fixture() +async def driver_sync(endpoint, database, event_loop): + driver_config = ydb.DriverConfig( + endpoint, + database, + credentials=ydb.construct_credentials_from_environ(), + root_certificates=ydb.load_ydb_root_certificate(), + ) + + driver = ydb.Driver(driver_config=driver_config) + driver.wait(timeout=15) + + yield driver + + driver.stop(timeout=10) + + @pytest.fixture() def topic_consumer(): return "fixture-consumer" diff --git a/tests/topics/test_control_plane.py b/tests/topics/test_control_plane.py index 615bd271..446e64d2 100644 --- a/tests/topics/test_control_plane.py +++ b/tests/topics/test_control_plane.py @@ -38,3 +38,37 @@ async def test_describe_topic(self, driver, topic_path: str, topic_consumer): break assert has_consumer + + +class TestTopicClientControlPlane: + def test_create_topic(self, driver_sync, database): + client = driver_sync.topic_client + + topic_path = database + "/my-test-topic" + + client.create_topic(topic_path) + + with pytest.raises(issues.SchemeError): + # double 
create is ok - try create topic with bad path + client.create_topic(database) + + def test_drop_topic(self, driver_sync, topic_path): + client = driver_sync.topic_client + + client.drop_topic(topic_path) + + with pytest.raises(issues.SchemeError): + client.drop_topic(topic_path) + + def test_describe_topic(self, driver_sync, topic_path: str, topic_consumer): + res = driver_sync.topic_client.describe(topic_path) + + assert res.self.name == os.path.basename(topic_path) + + has_consumer = False + for consumer in res.consumers: + if consumer.name == topic_consumer: + has_consumer = True + break + + assert has_consumer diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py index b6641e27..042170d8 100644 --- a/ydb/aio/driver.py +++ b/ydb/aio/driver.py @@ -4,8 +4,6 @@ import ydb from .. import _utilities from ydb.driver import get_config -from .. import topic - def default_credentials(credentials=None): if credentials is not None: @@ -81,6 +79,8 @@ def __init__( credentials=None, **kwargs ): + from .. import topic # local import for prevent cycle import error + config = get_config( driver_config, connection_string, diff --git a/ydb/driver.py b/ydb/driver.py index e66a5fc9..9aa6aab3 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -231,6 +231,8 @@ def __init__( :param database: A database path :param credentials: A credentials. If not specifed credentials constructed by default. """ + from . import topic # local import for prevent cycle import error + driver_config = get_config( driver_config, connection_string, @@ -246,3 +248,4 @@ def __init__( self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, driver_config.table_client_settings) + self.topic_client = topic.TopicClient(self, driver_config.topic_client_settings) diff --git a/ydb/topic.py b/ydb/topic.py index a51ac082..b0b339dc 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -5,6 +5,7 @@ from . import aio, Credentials, _apis from . import scheme +from . import driver from ._grpc.grpcwrapper.ydb_topic_public_types import ( DropTopicRequestParams as _DropTopicRequestParams, @@ -43,6 +44,14 @@ from ._grpc.grpcwrapper import ydb_topic as _ydb_topic from ._grpc.grpcwrapper import ydb_topic_public_types as _ydb_topic_public_types +from ._grpc.grpcwrapper.ydb_topic_public_types import ( + PublicDescribeTopicResult as TopicDescription, + PublicMultipleWindowsStat as TopicStatWindow, + PublicPartitionStats as TopicPartitionStats, + PublicCodec as TopicCodec, + PublicConsumer as TopicConsumer, + PublicMeteringMode as TopicMeteringMode, +) class TopicClientAsyncIO: @@ -94,7 +103,7 @@ async def create_topic( _wrap_operation, ) - async def describe(self, path: str, include_stats: bool = False): + async def describe(self, path: str, include_stats: bool = False) -> TopicDescription: args = locals().copy() del args["self"] req = _DescribeTopicRequestParams(**args) @@ -164,8 +173,74 @@ def topic_writer( class TopicClient: - def __init__(self, driver, topic_client_settings: "TopicClientSettings" = None): - pass + _driver: driver.Driver + _credentials: Union[Credentials, None] + + def __init__(self, driver: driver.Driver, topic_client_settings: "TopicClientSettings" = None): + self._driver = driver + + def create_topic( + self, + path: str, + min_active_partitions: Optional[ + int + ] = None, # Minimum partition count auto merge would stop working at. + partition_count_limit: Optional[ + int + ] = None, # Limit for total partition count, including active (open for write) and read-only partitions. 
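+        # same contract as TopicClientAsyncIO.create_topic: argument names
+        # are forwarded via locals() into CreateTopicRequestParams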
+ retention_period: Optional[ + datetime.timedelta + ] = None, # How long data in partition should be stored + retention_storage_mb: Optional[ + int + ] = None, # How much data in partition should be stored + # List of allowed codecs for writers. + # Writes with codec not from this list are forbidden. + supported_codecs: Optional[List[Union[TopicCodec, int]]] = None, + partition_write_speed_bytes_per_second: Optional[ + int + ] = None, # Partition write speed in bytes per second + partition_write_burst_bytes: Optional[ + int + ] = None, # Burst size for write in partition, in bytes + # User and server attributes of topic. Server attributes starts from "_" and will be validated by server. + attributes: Optional[Dict[str, str]] = None, + # List of consumers for this topic + consumers: Optional[List[Union[TopicConsumer, str]]] = None, + # Metering mode for the topic in a serverless database + metering_mode: Optional[TopicMeteringMode] = None, + ): + args = locals().copy() + del args["self"] + req = _ydb_topic_public_types.CreateTopicRequestParams(**args) + req = _ydb_topic.CreateTopicRequest.from_public(req) + self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.CreateTopic, + _wrap_operation, + ) + + def describe(self, path: str, include_stats: bool = False) -> TopicDescription: + args = locals().copy() + del args["self"] + req = _DescribeTopicRequestParams(**args) + res = self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DescribeTopic, + _create_result_wrapper(_ydb_topic.DescribeTopicResult), + ) # type: _ydb_topic.DescribeTopicResult + return res.to_public() + + def drop_topic(self, path: str): + req = _DropTopicRequestParams(path=path) + self._driver( + req.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.DropTopic, + _wrap_operation, + ) def topic_reader( self, From 6611c291ab06a52e2f6159d2eeeca322c3d34d35 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 13 Feb 2023 15:30:50 +0300 Subject: [PATCH 040/147] linters --- tests/topics/test_control_plane.py | 1 + ydb/_grpc/grpcwrapper/common_utils.py | 10 +- ydb/_grpc/grpcwrapper/ydb_topic.py | 75 +++++--- .../grpcwrapper/ydb_topic_public_types.py | 163 +++++++++++------- ydb/_topic_common/common.py | 13 +- ydb/_utilities.py | 1 - ydb/aio/driver.py | 1 + ydb/topic.py | 122 +++++++------ 8 files changed, 224 insertions(+), 162 deletions(-) diff --git a/tests/topics/test_control_plane.py b/tests/topics/test_control_plane.py index 446e64d2..8e1d6f23 100644 --- a/tests/topics/test_control_plane.py +++ b/tests/topics/test_control_plane.py @@ -62,6 +62,7 @@ def test_drop_topic(self, driver_sync, topic_path): def test_describe_topic(self, driver_sync, topic_path: str, topic_consumer): res = driver_sync.topic_client.describe(topic_path) + res.partition_count_limit assert res.self.name == os.path.basename(topic_path) diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 3af5f363..5e771051 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -54,7 +54,6 @@ def to_proto(self) -> Message: class IFromPublic(abc.ABC): - @staticmethod @abc.abstractmethod def from_public(o: typing.Any) -> typing.Any: @@ -259,12 +258,17 @@ def proto_timestamp_from_datetime(t: Optional[datetime.datetime]) -> ProtoTimeSt res.FromDatetime(t) -def datetime_from_proto_timestamp(ts: Optional[ProtoTimeStamp]) -> Optional[datetime.datetime]: +def datetime_from_proto_timestamp( + ts: Optional[ProtoTimeStamp], +) -> 
Optional[datetime.datetime]: if ts is None: return None return ts.ToDatetime() -def timedelta_from_proto_duration(d: Optional[ProtoDuration]) -> Optional[datetime.timedelta]: + +def timedelta_from_proto_duration( + d: Optional[ProtoDuration], +) -> Optional[datetime.timedelta]: if d is None: return None return d.ToTimedelta() diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index 888363fe..8e3129f7 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -2,7 +2,6 @@ import enum import typing from dataclasses import dataclass, field -from enum import IntEnum from typing import List, Union, Dict, Optional from google.protobuf.message import Message @@ -26,7 +25,9 @@ ServerStatus, UnknownGrpcMessageError, proto_duration_from_timedelta, - proto_timestamp_from_datetime, datetime_from_proto_timestamp, timedelta_from_proto_duration, + proto_timestamp_from_datetime, + datetime_from_proto_timestamp, + timedelta_from_proto_duration, ) @@ -686,7 +687,9 @@ class MultipleWindowsStat(IFromProto, IToPublic): per_day: int @staticmethod - def from_proto(msg: Optional[ydb_topic_pb2.MultipleWindowsStat]) -> Optional["MultipleWindowsStat"]: + def from_proto( + msg: Optional[ydb_topic_pb2.MultipleWindowsStat], + ) -> Optional["MultipleWindowsStat"]: if msg is None: return None return MultipleWindowsStat( @@ -702,6 +705,7 @@ def to_public(self) -> ydb_topic_public_types.PublicMultipleWindowsStat: per_day=self.per_day, ) + @dataclass class Consumer(IToProto, IFromProto, IFromPublic, IToPublic): name: str @@ -745,9 +749,7 @@ def from_public(consumer: ydb_topic_public_types.PublicConsumer): name=consumer.name, important=consumer.important, read_from=consumer.read_from, - supported_codecs=SupportedCodecs( - codecs=supported_codecs - ), + supported_codecs=SupportedCodecs(codecs=supported_codecs), attributes=consumer.attributes, consumer_stats=None, ) @@ -769,11 +771,17 @@ class ConsumerStats(IFromProto): bytes_read: MultipleWindowsStat @staticmethod - def from_proto(msg: ydb_topic_pb2.Consumer.ConsumerStats) -> "Consumer.ConsumerStats": + def from_proto( + msg: ydb_topic_pb2.Consumer.ConsumerStats, + ) -> "Consumer.ConsumerStats": return Consumer.ConsumerStats( - min_partitions_last_read_time=datetime_from_proto_timestamp(msg.min_partitions_last_read_time), + min_partitions_last_read_time=datetime_from_proto_timestamp( + msg.min_partitions_last_read_time + ), max_read_time_lag=timedelta_from_proto_duration(msg.max_read_time_lag), - max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag), + max_write_time_lag=timedelta_from_proto_duration( + msg.max_write_time_lag + ), bytes_read=MultipleWindowsStat.from_proto(msg.bytes_read), ) @@ -803,7 +811,9 @@ class MeteringMode(int, IFromProto, IFromPublic, IToPublic): REQUEST_UNITS = 2 @staticmethod - def from_public(m: Optional[ydb_topic_public_types.PublicMeteringMode]) -> Optional["MeteringMode"]: + def from_public( + m: Optional[ydb_topic_public_types.PublicMeteringMode], + ) -> Optional["MeteringMode"]: if m is None: return None @@ -875,7 +885,7 @@ def from_public(req: ydb_topic_public_types.CreateTopicRequestParams): supported_codecs=SupportedCodecs( codecs=supported_codecs, ), - partition_write_speed_bytes_per_second = req.partition_write_speed_bytes_per_second, + partition_write_speed_bytes_per_second=req.partition_write_speed_bytes_per_second, partition_write_burst_bytes=req.partition_write_burst_bytes, attributes=req.attributes, consumers=consumers, @@ -883,7 +893,6 @@ def 
from_public(req: ydb_topic_public_types.CreateTopicRequestParams): ) - @dataclass class CreateTopicResult: pass @@ -914,8 +923,12 @@ class DescribeTopicResult(IFromProtoWithProtoType, IToPublic): def from_proto(msg: ydb_topic_pb2.DescribeTopicResult) -> "DescribeTopicResult": return DescribeTopicResult( self_proto=msg.self, - partitioning_settings=PartitioningSettings.from_proto(msg.partitioning_settings), - partitions=list(map(DescribeTopicResult.PartitionInfo.from_proto, msg.partitions)), + partitioning_settings=PartitioningSettings.from_proto( + msg.partitioning_settings + ), + partitions=list( + map(DescribeTopicResult.PartitionInfo.from_proto, msg.partitions) + ), retention_period=msg.retention_period, retention_storage_mb=msg.retention_storage_mb, supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs), @@ -936,7 +949,9 @@ def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult: self=scheme._wrap_scheme_entry(self.self_proto), min_active_partitions=self.partitioning_settings.min_active_partitions, partition_count_limit=self.partitioning_settings.partition_count_limit, - partitions=list(map(DescribeTopicResult.PartitionInfo.to_public, self.partitions)), + partitions=list( + map(DescribeTopicResult.PartitionInfo.to_public, self.partitions) + ), retention_period=self.retention_period, retention_storage_mb=self.retention_storage_mb, supported_codecs=self.supported_codecs.to_public(), @@ -957,7 +972,9 @@ class PartitionInfo(IFromProto, IToPublic): partition_stats: "PartitionStats" @staticmethod - def from_proto(msg: Optional[ydb_topic_pb2.DescribeTopicResult.PartitionInfo]) -> Optional["DescribeTopicResult.PartitionInfo"]: + def from_proto( + msg: Optional[ydb_topic_pb2.DescribeTopicResult.PartitionInfo], + ) -> Optional["DescribeTopicResult.PartitionInfo"]: if msg is None: return None @@ -966,10 +983,12 @@ def from_proto(msg: Optional[ydb_topic_pb2.DescribeTopicResult.PartitionInfo]) - active=msg.active, child_partition_ids=list(msg.child_partition_ids), parent_partition_ids=list(msg.parent_partition_ids), - partition_stats=PartitionStats.from_proto(msg.partition_stats) + partition_stats=PartitionStats.from_proto(msg.partition_stats), ) - def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo: + def to_public( + self, + ) -> ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo: partition_stats = None if self.partition_stats is not None: partition_stats = self.partition_stats.to_public() @@ -989,18 +1008,26 @@ class TopicStats(IFromProto, IToPublic): bytes_written: "MultipleWindowsStat" @staticmethod - def from_proto(msg: Optional[ydb_topic_pb2.DescribeTopicResult.TopicStats]) -> Optional["DescribeTopicResult.TopicStats"]: + def from_proto( + msg: Optional[ydb_topic_pb2.DescribeTopicResult.TopicStats], + ) -> Optional["DescribeTopicResult.TopicStats"]: if msg is None: return None return DescribeTopicResult.TopicStats( store_size_bytes=msg.store_size_bytes, - min_last_write_time=datetime_from_proto_timestamp(msg.min_last_write_time), - max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag), + min_last_write_time=datetime_from_proto_timestamp( + msg.min_last_write_time + ), + max_write_time_lag=timedelta_from_proto_duration( + msg.max_write_time_lag + ), bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written), ) - def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult.TopicStats: + def to_public( + self, + ) -> ydb_topic_public_types.PublicDescribeTopicResult.TopicStats: return 
ydb_topic_public_types.PublicDescribeTopicResult.TopicStats( store_size_bytes=self.store_size_bytes, min_last_write_time=self.min_last_write_time, @@ -1019,7 +1046,9 @@ class PartitionStats(IFromProto, IToPublic): partition_node_id: int @staticmethod - def from_proto(msg: Optional[ydb_topic_pb2.PartitionStats]) -> Optional["PartitionStats"]: + def from_proto( + msg: Optional[ydb_topic_pb2.PartitionStats], + ) -> Optional["PartitionStats"]: if msg is None: return None return PartitionStats( diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py index 21ae0ed7..43aa8449 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py @@ -10,8 +10,8 @@ else: from ..common.protos import ydb_topic_pb2 -from .common_utils import IToProto, IFromProto, proto_timestamp_from_datetime -from ...scheme import SchemeEntry, _wrap_scheme_entry +from .common_utils import IToProto +from ...scheme import SchemeEntry @dataclass @@ -48,38 +48,22 @@ class PublicMeteringMode(IntEnum): class PublicConsumer: name: str important: bool = False - read_from: Optional[datetime.datetime] = None - supported_codecs: List[Union[PublicCodec, int]] = field( - default_factory=lambda: list() - ) - attributes: Dict[str, str] = field(default_factory=lambda: dict()) - - -def consumers_to_proto( - consumers: Optional[List[Union[PublicConsumer, str]]] -) -> List[ydb_topic_pb2.Consumer]: - res = [] - if not consumers: - return res + """ + Consumer may be marked as 'important'. It means messages for this consumer will never expire due to retention. + User should take care that such consumer never stalls, to prevent running out of disk space. + """ - for consumer in consumers: - if isinstance(consumer, str): - consumer = PublicConsumer(name=consumer) - res.append(consumer_to_proto(consumer)) - - return res + read_from: Optional[datetime.datetime] = None + "All messages with smaller server written_at timestamp will be skipped." + supported_codecs: List[PublicCodec] = field(default_factory=lambda: list()) + """ + List of supported codecs by this consumer. + supported_codecs on topic must be contained inside this list. + """ -def consumer_to_proto(consumer: PublicConsumer) -> ydb_topic_pb2.Consumer: - return ydb_topic_pb2.Consumer( - name=consumer.name, - important=consumer.important, - read_from=proto_timestamp_from_datetime(consumer.read_from), - supported_codecs=ydb_topic_pb2.SupportedCodecs( - codecs=consumer.supported_codecs, - ), - attributes=consumer.attributes, - ) + attributes: Dict[str, str] = field(default_factory=lambda: dict()) + "Attributes of consumer" @dataclass @@ -97,8 +81,7 @@ class DescribeTopicRequestParams(IToProto): def to_proto(self) -> ydb_topic_pb2.DescribeTopicRequest: return ydb_topic_pb2.DescribeTopicRequest( - path=self.path, - include_stats=self.include_stats + path=self.path, include_stats=self.include_stats ) @@ -106,45 +89,101 @@ def to_proto(self) -> ydb_topic_pb2.DescribeTopicRequest: # Need similar struct to CreateTopicRequestParams class PublicDescribeTopicResult: self: SchemeEntry - min_active_partitions: int # Minimum partition count auto merge would stop working at - partition_count_limit: int # Limit for total partition count, including active (open for write) and read-only partitions. 
- partitions: List["PublicDescribeTopicResult.PartitionInfo"] # Partitions description - - retention_period: datetime.timedelta # How long data in partition should be stored - retention_storage_mb: int # How much data in partition should be stored. Zero value means infinite limit. - supported_codecs: List[PublicCodec] # List of allowed codecs for writers. - partition_write_speed_bytes_per_second: int # Partition write speed in bytes per second - partition_write_burst_bytes: int # Burst size for write in partition, in bytes - attributes: Dict[str, str] # User and server attributes of topic. Server attributes starts from "_" and will be validated by server. - consumers: List[PublicConsumer] # List of consumers for this topic - metering_mode: PublicMeteringMode # Metering settings - topic_stats: "PublicDescribeTopicResult.TopicStats" # Statistics of topic + "Description of scheme object" + + min_active_partitions: int + "Minimum partition count auto merge would stop working at" + + partition_count_limit: int + "Limit for total partition count, including active (open for write) and read-only partitions" + + partitions: List["PublicDescribeTopicResult.PartitionInfo"] + "Partitions description" + + retention_period: datetime.timedelta + "How long data in partition should be stored" + + retention_storage_mb: int + "How much data in partition should be stored. Zero value means infinite limit" + + supported_codecs: List[PublicCodec] + "List of allowed codecs for writers" + + partition_write_speed_bytes_per_second: int + "Partition write speed in bytes per second" + + partition_write_burst_bytes: int + "Burst size for write in partition, in bytes" + + attributes: Dict[str, str] + """User and server attributes of topic. Server attributes starts from "_" and will be validated by server.""" + + consumers: List[PublicConsumer] + """List of consumers for this topic""" + + metering_mode: PublicMeteringMode + "Metering settings" + + topic_stats: "PublicDescribeTopicResult.TopicStats" + "Statistics of topic" @dataclass class PartitionInfo: - partition_id: int # Partition identifier - active: bool # Is partition open for write - child_partition_ids: List[int] # Ids of partitions which was formed when this partition was split or merged - parent_partition_ids: List[int] # Ids of partitions from which this partition was formed by split or merge - partition_stats: Optional["PublicPartitionStats"] # Stats for partition, filled only when include_stats in request is true + partition_id: int + "Partition identifier" + + active: bool + "Is partition open for write" + + child_partition_ids: List[int] + "Ids of partitions which was formed when this partition was split or merged" + + parent_partition_ids: List[int] + "Ids of partitions from which this partition was formed by split or merge" + + partition_stats: Optional["PublicPartitionStats"] + "Stats for partition, filled only when include_stats in request is true" @dataclass class TopicStats: - store_size_bytes: int # Approximate size of topic - min_last_write_time: datetime.datetime # Minimum of timestamps of last write among all partitions. - max_write_time_lag: datetime.timedelta # Maximum of differences between write timestamp and create timestamp for all messages, written during last minute. - bytes_written: "PublicMultipleWindowsStat" # How much bytes were written statistics. + store_size_bytes: int + "Approximate size of topic" + + min_last_write_time: datetime.datetime + "Minimum of timestamps of last write among all partitions." 
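+    # i.e. the last-write timestamp of the most lagged partition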
+ + max_write_time_lag: datetime.timedelta + """ + Maximum of differences between write timestamp and create timestamp for all messages, + written during last minute. + """ + + bytes_written: "PublicMultipleWindowsStat" + "How much bytes were written statistics." @dataclass class PublicPartitionStats: - partition_start: int # first message offset in the partition - partition_end: int # last+1 message offset in the partition - store_size_bytes: int # Approximate size of partition - last_write_time: datetime.datetime # Timestamp of last write - max_write_time_lag: datetime.timedelta # Maximum of differences between write timestamp and create timestamp for all messages, written during last minute. - bytes_written: "PublicMultipleWindowsStat" # How much bytes were written during several windows in this partition. - partition_node_id: int # Host where tablet for this partition works. Useful for debugging purposes. + partition_start: int + "first message offset in the partition" + + partition_end: int + "offset after last stored message offset in the partition (last offset + 1)" + + store_size_bytes: int + "Approximate size of partition" + + last_write_time: datetime.datetime + "Timestamp of last write" + + max_write_time_lag: datetime.timedelta + "Maximum of differences between write timestamp and create timestamp for all messages, written during last minute." + + bytes_written: "PublicMultipleWindowsStat" + "How much bytes were written during several windows in this partition." + + partition_node_id: int + "Host where tablet for this partition works. Useful for debugging purposes." @dataclass diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index bef0320e..8dcafcb7 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -3,14 +3,6 @@ from .. import operation, issues from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType -# Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if typing.TYPE_CHECKING: - from .._grpc.v4.protos import ydb_topic_pb2, ydb_operation_pb2 -else: - from .._grpc.common.protos import ydb_topic_pb2, ydb_operation_pb2 - - TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] @@ -21,7 +13,9 @@ def wrap_operation(rpc_state, response_pb, driver=None): ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType) -def create_result_wrapper(result_type: typing.Type[ResultType]) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]: +def create_result_wrapper( + result_type: typing.Type[ResultType], +) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]: def wrapper(rpc_state, response_pb, driver=None): issues._process_response(response_pb.operation) msg = result_type.empty_proto_message() @@ -29,4 +23,3 @@ def wrapper(rpc_state, response_pb, driver=None): return result_type.from_proto(msg) return wrapper - diff --git a/ydb/_utilities.py b/ydb/_utilities.py index e2a9f98f..32419b1b 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -5,7 +5,6 @@ import functools import hashlib import collections -import warnings from . import ydb_version try: diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py index 042170d8..319cb14c 100644 --- a/ydb/aio/driver.py +++ b/ydb/aio/driver.py @@ -5,6 +5,7 @@ from .. 
import _utilities from ydb.driver import get_config + def default_credentials(credentials=None): if credentials is not None: return credentials diff --git a/ydb/topic.py b/ydb/topic.py index b0b339dc..42d283bc 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,20 +1,10 @@ import datetime -import warnings from typing import List, Callable, Union, Mapping, Any, Optional, Dict from . import aio, Credentials, _apis -from . import scheme from . import driver -from ._grpc.grpcwrapper.ydb_topic_public_types import ( - DropTopicRequestParams as _DropTopicRequestParams, - PublicCodec as TopicCodec, - PublicConsumer as TopicConsumer, - PublicMeteringMode as TopicMeteringMode, - DescribeTopicRequestParams as _DescribeTopicRequestParams, -) - from ._topic_reader.topic_reader import ( PublicReaderSettings as TopicReaderSettings, Reader as TopicReader, @@ -44,7 +34,7 @@ from ._grpc.grpcwrapper import ydb_topic as _ydb_topic from ._grpc.grpcwrapper import ydb_topic_public_types as _ydb_topic_public_types -from ._grpc.grpcwrapper.ydb_topic_public_types import ( +from ._grpc.grpcwrapper.ydb_topic_public_types import ( # noqa: F401 PublicDescribeTopicResult as TopicDescription, PublicMultipleWindowsStat as TopicStatWindow, PublicPartitionStats as TopicPartitionStats, @@ -64,34 +54,35 @@ def __init__(self, driver: aio.Driver, settings: "TopicClientSettings" = None): async def create_topic( self, path: str, - min_active_partitions: Optional[ - int - ] = None, # Minimum partition count auto merge would stop working at. - partition_count_limit: Optional[ - int - ] = None, # Limit for total partition count, including active (open for write) and read-only partitions. - retention_period: Optional[ - datetime.timedelta - ] = None, # How long data in partition should be stored - retention_storage_mb: Optional[ - int - ] = None, # How much data in partition should be stored - # List of allowed codecs for writers. - # Writes with codec not from this list are forbidden. + min_active_partitions: Optional[int] = None, + partition_count_limit: Optional[int] = None, + retention_period: Optional[datetime.timedelta] = None, + retention_storage_mb: Optional[int] = None, supported_codecs: Optional[List[Union[TopicCodec, int]]] = None, - partition_write_speed_bytes_per_second: Optional[ - int - ] = None, # Partition write speed in bytes per second - partition_write_burst_bytes: Optional[ - int - ] = None, # Burst size for write in partition, in bytes - # User and server attributes of topic. Server attributes starts from "_" and will be validated by server. + partition_write_speed_bytes_per_second: Optional[int] = None, + partition_write_burst_bytes: Optional[int] = None, attributes: Optional[Dict[str, str]] = None, - # List of consumers for this topic consumers: Optional[List[Union[TopicConsumer, str]]] = None, - # Metering mode for the topic in a serverless database metering_mode: Optional[TopicMeteringMode] = None, ): + """ + create topic command + + :param path: full path to topic + :param min_active_partitions: Minimum partition count auto merge would stop working at. + :param partition_count_limit: Limit for total partition count, including active (open for write) + and read-only partitions. + :param retention_period: How long data in partition should be stored + :param retention_storage_mb: How much data in partition should be stored + :param supported_codecs: List of allowed codecs for writers. Writes with codec not from this list are forbidden. + Empty list mean disable codec compatibility checks for the topic. 
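+            Example: supported_codecs=[TopicCodec.RAW, TopicCodec.GZIP]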
+ :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second + :param partition_write_burst_bytes: Burst size for write in partition, in bytes + :param attributes: User and server attributes of topic. + Server attributes starts from "_" and will be validated by server. + :param consumers: List of consumers for this topic + :param metering_mode: Metering mode for the topic in a serverless database + """ args = locals().copy() del args["self"] req = _ydb_topic_public_types.CreateTopicRequestParams(**args) @@ -103,10 +94,12 @@ async def create_topic( _wrap_operation, ) - async def describe(self, path: str, include_stats: bool = False) -> TopicDescription: + async def describe( + self, path: str, include_stats: bool = False + ) -> TopicDescription: args = locals().copy() del args["self"] - req = _DescribeTopicRequestParams(**args) + req = _ydb_topic_public_types.DescribeTopicRequestParams(**args) res = await self._driver( req.to_proto(), _apis.TopicService.Stub, @@ -116,7 +109,7 @@ async def describe(self, path: str, include_stats: bool = False) -> TopicDescrip return res.to_public() async def drop_topic(self, path: str): - req = _DropTopicRequestParams(path=path) + req = _ydb_topic_public_types.DropTopicRequestParams(path=path) await self._driver( req.to_proto(), _apis.TopicService.Stub, @@ -176,40 +169,43 @@ class TopicClient: _driver: driver.Driver _credentials: Union[Credentials, None] - def __init__(self, driver: driver.Driver, topic_client_settings: "TopicClientSettings" = None): + def __init__( + self, driver: driver.Driver, topic_client_settings: "TopicClientSettings" = None + ): self._driver = driver def create_topic( self, path: str, - min_active_partitions: Optional[ - int - ] = None, # Minimum partition count auto merge would stop working at. - partition_count_limit: Optional[ - int - ] = None, # Limit for total partition count, including active (open for write) and read-only partitions. - retention_period: Optional[ - datetime.timedelta - ] = None, # How long data in partition should be stored - retention_storage_mb: Optional[ - int - ] = None, # How much data in partition should be stored - # List of allowed codecs for writers. - # Writes with codec not from this list are forbidden. + min_active_partitions: Optional[int] = None, + partition_count_limit: Optional[int] = None, + retention_period: Optional[datetime.timedelta] = None, + retention_storage_mb: Optional[int] = None, supported_codecs: Optional[List[Union[TopicCodec, int]]] = None, - partition_write_speed_bytes_per_second: Optional[ - int - ] = None, # Partition write speed in bytes per second - partition_write_burst_bytes: Optional[ - int - ] = None, # Burst size for write in partition, in bytes - # User and server attributes of topic. Server attributes starts from "_" and will be validated by server. + partition_write_speed_bytes_per_second: Optional[int] = None, + partition_write_burst_bytes: Optional[int] = None, attributes: Optional[Dict[str, str]] = None, - # List of consumers for this topic consumers: Optional[List[Union[TopicConsumer, str]]] = None, - # Metering mode for the topic in a serverless database metering_mode: Optional[TopicMeteringMode] = None, ): + """ + create topic command + + :param path: full path to topic + :param min_active_partitions: Minimum partition count auto merge would stop working at. + :param partition_count_limit: Limit for total partition count, including active (open for write) + and read-only partitions. 
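+            (a partition becomes read-only after it is split or merged)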
+ :param retention_period: How long data in partition should be stored + :param retention_storage_mb: How much data in partition should be stored + :param supported_codecs: List of allowed codecs for writers. Writes with codec not from this list are forbidden. + Empty list mean disable codec compatibility checks for the topic. + :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second + :param partition_write_burst_bytes: Burst size for write in partition, in bytes + :param attributes: User and server attributes of topic. + Server attributes starts from "_" and will be validated by server. + :param consumers: List of consumers for this topic + :param metering_mode: Metering mode for the topic in a serverless database + """ args = locals().copy() del args["self"] req = _ydb_topic_public_types.CreateTopicRequestParams(**args) @@ -224,7 +220,7 @@ def create_topic( def describe(self, path: str, include_stats: bool = False) -> TopicDescription: args = locals().copy() del args["self"] - req = _DescribeTopicRequestParams(**args) + req = _ydb_topic_public_types.DescribeTopicRequestParams(**args) res = self._driver( req.to_proto(), _apis.TopicService.Stub, @@ -234,7 +230,7 @@ def describe(self, path: str, include_stats: bool = False) -> TopicDescription: return res.to_public() def drop_topic(self, path: str): - req = _DropTopicRequestParams(path=path) + req = _ydb_topic_public_types.DropTopicRequestParams(path=path) self._driver( req.to_proto(), _apis.TopicService.Stub, From 76f8b704bb11f08efa90fa0ac3f3eda716c6862c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 13 Feb 2023 16:13:52 +0300 Subject: [PATCH 041/147] fix usage internal codec --- tests/conftest.py | 3 +-- ydb/_topic_writer/topic_writer.py | 2 +- ydb/_topic_writer/topic_writer_asyncio.py | 2 ++ 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 183ad3a9..aa28a9bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -142,11 +142,10 @@ async def topic_path(driver, topic_consumer, database) -> str: @pytest.fixture() @pytest.mark.asyncio() async def topic_with_messages(driver, topic_path): - pass writer = driver.topic_client.topic_writer( topic_path, producer_and_message_group_id="fixture-producer-id" ) - await writer.write_with_ack( + res = await writer.write_with_ack( ydb.TopicWriterMessage(data="123".encode()), ydb.TopicWriterMessage(data="456".encode()), ) diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index f3b0b3ab..b66d8205 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -303,7 +303,7 @@ def messages_to_proto_requests( req = StreamWriteMessage.FromClient( StreamWriteMessage.WriteRequest( messages=[msg.to_message_data()], - codec=Codec.CODEC_RAW.value, + codec=Codec.CODEC_RAW, ) ) res.append(req) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index a231a6b5..f1e6c455 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -363,6 +363,8 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): m = await self._new_messages.get() # type: InternalMessage if m.seq_no > last_seq_no: writer.write([m]) + except Exception as e: + await self._stop(e) finally: pass From 25bd6bb6874504ebd71d576a36d56bfd1e37960a Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 13 Feb 2023 18:41:45 +0300 Subject: [PATCH 042/147] style remove codec.value --- tests/conftest.py | 2 
+- ydb/_topic_writer/topic_writer_asyncio_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index aa28a9bd..674422ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -145,7 +145,7 @@ async def topic_with_messages(driver, topic_path): writer = driver.topic_client.topic_writer( topic_path, producer_and_message_group_id="fixture-producer-id" ) - res = await writer.write_with_ack( + await writer.write_with_ack( ydb.TopicWriterMessage(data="123".encode()), ydb.TopicWriterMessage(data="456".encode()), ) diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 32fe9c02..c99e392f 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -59,7 +59,7 @@ async def writer_and_stream(self, stream) -> WriterWithMockedStream: last_seq_no=4, session_id="123", partition_id=3, - supported_codecs=[Codec.CODEC_RAW.value, Codec.CODEC_GZIP.value], + supported_codecs=[Codec.CODEC_RAW, Codec.CODEC_GZIP], status=ServerStatus(StatusCode.SUCCESS, []), ) ) @@ -131,7 +131,7 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): expected_message = StreamWriteMessage.FromClient( StreamWriteMessage.WriteRequest( - codec=Codec.CODEC_RAW.value, + codec=Codec.CODEC_RAW, messages=[ StreamWriteMessage.WriteRequest.MessageData( seq_no=1, From 1e74b11be42ecd9e7fdfb1e0d3cdf3b8d48fc521 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 14 Feb 2023 14:39:37 +0300 Subject: [PATCH 043/147] simplify topic writer --- ydb/_topic_writer/topic_writer_asyncio.py | 113 +++++++++++----------- 1 file changed, 56 insertions(+), 57 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f1e6c455..217249a1 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -40,7 +40,6 @@ class WriterAsyncIO: _loop: asyncio.AbstractEventLoop _reconnector: "WriterAsyncIOReconnector" - _lock: asyncio.Lock _closed: bool @property @@ -48,7 +47,6 @@ def last_seqno(self) -> int: raise NotImplementedError() def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings): - self._lock = asyncio.Lock() self._loop = asyncio.get_running_loop() self._closed = False self._reconnector = WriterAsyncIOReconnector( @@ -68,10 +66,10 @@ def __del__(self): self._loop.call_soon(self.close) async def close(self): - async with self._lock: - if self._closed: - return - self._closed = True + if self._closed: + return + + self._closed = True await self._reconnector.close() @@ -164,65 +162,81 @@ class WriterAsyncIOReconnector: _update_token_interval: int _token_get_function: TokenGetterFuncType _init_message: StreamWriteMessage.InitRequest - _new_messages: asyncio.Queue _init_info: asyncio.Future _stream_connected: asyncio.Event _settings: WriterSettings - _lock: asyncio.Lock _last_known_seq_no: int _messages: Deque[InternalMessage] _messages_future: Deque[asyncio.Future] - _stop_reason: Optional[Exception] + _new_messages: asyncio.Queue + _stop_reason: asyncio.Future _background_tasks: List[asyncio.Task] def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._driver = driver self._credentials = driver._credentials self._init_message = settings.create_init_request() - self._new_messages = asyncio.Queue() self._init_info = asyncio.Future() self._stream_connected = asyncio.Event() self._settings = settings - 
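# Editor's note on the Codec.CODEC_RAW.value -> Codec.CODEC_RAW change above:
# the wrapper's Codec behaves like an int enum, so protobuf fields accept the
# member itself and the explicit .value was redundant. A standalone analog
# (ExampleCodec is an illustrative stand-in, not the SDK type):
from enum import IntEnum


class ExampleCodec(IntEnum):
    CODEC_RAW = 1
    CODEC_GZIP = 2


# IntEnum members compare equal to, and serialize as, their integer values.
assert ExampleCodec.CODEC_RAW == 1
assert int(ExampleCodec.CODEC_GZIP) == 2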
self._lock = asyncio.Lock() self._last_known_seq_no = 0 self._messages = deque() self._messages_future = deque() - self._stop_reason = None + self._new_messages = asyncio.Queue() + self._stop_reason = asyncio.Future() self._background_tasks = [ asyncio.create_task(self._connection_loop(), name="connection_loop") ] async def close(self): - await self._check_stop() - await self._stop(TopicWriterStopped()) + self._check_stop() + self._stop(TopicWriterStopped()) + + background_tasks = self._background_tasks + + for task in background_tasks: + task.cancel() + + await asyncio.wait(self._background_tasks) async def wait_init(self) -> PublicWriterInitInfo: - return await self._init_info + done, _ = await asyncio.wait( + [self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED + ) + res = done.pop() # type: asyncio.Future + res_val = res.result() + + if isinstance(res_val, Exception): + raise res_val + + return res_val + + async def wait_stop(self) -> Exception: + return await self._stop_reason async def write_with_ack( self, messages: List[PublicMessage] ) -> List[asyncio.Future]: # todo check internal buffer limit - await self._check_stop() + self._check_stop() if self._settings.auto_seqno: await self.wait_init() - async with self._lock: - internal_messages = self._prepare_internal_messages_locked(messages) - messages_future = [asyncio.Future() for _ in internal_messages] + internal_messages = self._prepare_internal_messages(messages) + messages_future = [asyncio.Future() for _ in internal_messages] - self._messages.extend(internal_messages) - self._messages_future.extend(messages_future) + self._messages.extend(internal_messages) + self._messages_future.extend(messages_future) for m in internal_messages: self._new_messages.put_nowait(m) return messages_future - def _prepare_internal_messages_locked(self, messages: List[PublicMessage]): + def _prepare_internal_messages(self, messages: List[PublicMessage]): if self._settings.auto_created_at: now = datetime.datetime.now() else: @@ -263,10 +277,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]): return res - async def _check_stop(self): - async with self._lock: - if self._stop_reason is not None: - raise self._stop_reason + def _check_stop(self): + if self._stop_reason.done(): + raise self._stop_reason.result() async def _connection_loop(self): retry_settings = RetrySettings() # todo @@ -275,23 +288,16 @@ async def _connection_loop(self): attempt = 0 # todo calc and reset pending = [] - async def on_stop(e): - for t in pending: - self._background_tasks.append(t) - pending.clear() - await self._stop(e) - # noinspection PyBroadException try: stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._get_token ) try: - async with self._lock: - self._last_known_seq_no = stream_writer.last_seqno - self._init_info.set_result( - PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) - ) + self._last_known_seq_no = stream_writer.last_seqno + self._init_info.set_result( + PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) + ) except asyncio.InvalidStateError: pass @@ -316,13 +322,13 @@ async def on_stop(e): err_info = check_retriable_error(err, retry_settings, attempt) if not err_info.is_retriable: - await on_stop(err) + self._stop(err) return await asyncio.sleep(err_info.sleep_timeout_seconds) - except Exception as e: - await on_stop(e) + except (asyncio.CancelledError, Exception) as err: + self._stop(err) return finally: if len(pending) > 0: @@ -333,11 +339,11 @@ async 
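# Editor's note: the refactor above replaces the lock-guarded stop flag with
# an asyncio.Future, so waiters can race "result is ready" against "writer was
# stopped" exactly as wait_init now does. The pattern in isolation (a sketch,
# not SDK code; the stop future resolves with the stop reason):
import asyncio


async def result_or_stop(result: asyncio.Future, stop: asyncio.Future):
    done, _ = await asyncio.wait([result, stop], return_when=asyncio.FIRST_COMPLETED)
    value = done.pop().result()
    if isinstance(value, BaseException):
        raise value
    return value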
def on_stop(e): async def _read_loop(self, writer: "WriterAsyncIOStream"): while True: resp = await writer.receive() - async with self._lock: - for ack in resp.acks: - self._handle_receive_ack_need_lock(ack) - def _handle_receive_ack_need_lock(self, ack): + for ack in resp.acks: + self._handle_receive_ack(ack) + + def _handle_receive_ack(self, ack): current_message = self._messages.popleft() message_future = self._messages_future.popleft() if current_message.seq_no != ack.seq_no: @@ -351,8 +357,7 @@ def _handle_receive_ack_need_lock(self, ack): async def _send_loop(self, writer: "WriterAsyncIOStream"): try: - async with self._lock: - messages = list(self._messages) + messages = list(self._messages) last_seq_no = 0 for m in messages: @@ -364,24 +369,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): if m.seq_no > last_seq_no: writer.write([m]) except Exception as e: - await self._stop(e) + self._stop(e) finally: pass - async def _stop(self, reason: Exception): + def _stop(self, reason: Exception): if reason is None: raise Exception("writer stop reason can not be None") - async with self._lock: - if self._stop_reason is not None: - return - self._stop_reason = reason - background_tasks = self._background_tasks - - for task in background_tasks: - task.cancel() + if self._stop_reason.done(): + return - await asyncio.wait(self._background_tasks) + self._stop_reason.set_result(reason) def _get_token(self) -> str: raise NotImplementedError() From 1b95a9760b72eabec0dc3b89d49718959af05b34 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Feb 2023 14:10:30 +0300 Subject: [PATCH 044/147] fix wait_init for raise exception better. --- ydb/_topic_writer/topic_writer_asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 217249a1..24cde408 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -208,7 +208,7 @@ async def wait_init(self) -> PublicWriterInitInfo: res = done.pop() # type: asyncio.Future res_val = res.result() - if isinstance(res_val, Exception): + if isinstance(res_val, BaseException): raise res_val return res_val From a4146ffab0f983080cc135437b3046ed7739f59c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Feb 2023 17:25:43 +0300 Subject: [PATCH 045/147] fix closed for internal writer --- ydb/_topic_writer/topic_writer_asyncio.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 24cde408..5d4583fc 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -157,6 +157,7 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: + _closed: bool _credentials: Union[ydb.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int @@ -174,6 +175,7 @@ class WriterAsyncIOReconnector: _background_tasks: List[asyncio.Task] def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + self._closed = False self._driver = driver self._credentials = driver._credentials self._init_message = settings.create_init_request() @@ -191,7 +193,11 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): ] async def close(self): - self._check_stop() + if self._closed: + return + + self._closed = True + self._stop(TopicWriterStopped()) background_tasks = self._background_tasks From 
44edc47113fc649949751b80445685e7d037d959 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Feb 2023 17:53:30 +0300 Subject: [PATCH 046/147] raise exception on first close if writer was stopped by error before close. --- ydb/_topic_writer/topic_writer_asyncio.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5d4583fc..9c66109e 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -197,7 +197,6 @@ async def close(self): return self._closed = True - self._stop(TopicWriterStopped()) background_tasks = self._background_tasks @@ -207,6 +206,12 @@ async def close(self): await asyncio.wait(self._background_tasks) + # if work was stopped before close by error - raise the error + try: + self._check_stop() + except TopicWriterStopped: + pass + async def wait_init(self) -> PublicWriterInitInfo: done, _ = await asyncio.wait( [self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED From 304d491660e4da8e868a9d2a83aa00d5c8614132 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Feb 2023 19:03:01 +0300 Subject: [PATCH 047/147] remove unused code --- ydb/_topic_writer/topic_writer_asyncio.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 9c66109e..0eb9e5ca 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -429,29 +429,6 @@ async def create( await writer._start(stream, init_request) return writer - @staticmethod - async def _create_stream_from_async( - driver: ydb.aio.Driver, - init_request: StreamWriteMessage.InitRequest, - token_getter: TokenGetterFuncType, - ) -> "WriterAsyncIOStream": - return GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) - - @staticmethod - async def _create_from_sync( - driver: ydb.Driver, - init_request: StreamWriteMessage.InitRequest, - token_getter: TokenGetterFuncType, - ) -> "WriterAsyncIOStream": - stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) - await stream.start( - driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite - ) - - writer = WriterAsyncIOStream(token_getter) - await writer._start(stream, init_request) - return writer - async def receive(self) -> StreamWriteMessage.WriteResponse: while True: item = await self._stream.receive() From 342d5be35b65f51f772b32eba4c735e109679bfe Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Mon, 6 Feb 2023 18:17:21 +0300 Subject: [PATCH 048/147] fix passing set to execute params --- tests/aio/test_types.py | 3 +-- ydb/convert.py | 7 ++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/aio/test_types.py b/tests/aio/test_types.py index 3225df10..17093896 100644 --- a/tests/aio/test_types.py +++ b/tests/aio/test_types.py @@ -29,7 +29,7 @@ ('{"foo":"bar"}', "JsonDocument"), (uuid4(), "Uuid"), ([1, 2, 3], "List"), - # ({1, 2, 3}, "Set"), # FIXME: AttributeError: 'set' object has no attribute 'items' + ({1: None, 2: None, 3: None}, "Set"), ([b"a", b"b", b"c"], "List"), ({"a": 1001, "b": 1002}, "Dict"), (("a", 1001), "Tuple"), @@ -47,7 +47,6 @@ async def test_types(driver, database, value, ydb_type): prepared = await session.prepare( f"DECLARE $param as {ydb_type}; SELECT $param as value" ) - result = await session.transaction().execute( prepared, {"$param": value}, commit_tx=True ) diff --git 
a/ydb/convert.py b/ydb/convert.py index 2867f695..02c7de0c 100644 --- a/ydb/convert.py +++ b/ydb/convert.py @@ -214,9 +214,10 @@ def _dict_to_pb(type_pb, value): for key, payload in value.items(): kv_pair = value_pb.pairs.add() kv_pair.key.MergeFrom(_from_native_value(type_pb.dict_type.key, key)) - kv_pair.payload.MergeFrom( - _from_native_value(type_pb.dict_type.payload, payload) - ) + if payload: + kv_pair.payload.MergeFrom( + _from_native_value(type_pb.dict_type.payload, payload) + ) return value_pb From 6e641ebf746d16530fe64729ac64d09d2e66afd6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 8 Feb 2023 01:30:50 +0000 Subject: [PATCH 049/147] Bump cryptography from 3.4.7 to 39.0.1 Bumps [cryptography](https://github.com/pyca/cryptography) from 3.4.7 to 39.0.1. - [Release notes](https://github.com/pyca/cryptography/releases) - [Changelog](https://github.com/pyca/cryptography/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pyca/cryptography/compare/3.4.7...39.0.1) --- updated-dependencies: - dependency-name: cryptography dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- test-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-requirements.txt b/test-requirements.txt index 8eca233b..9f592875 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -5,7 +5,7 @@ cached-property==1.5.2 certifi==2022.12.7 cffi==1.14.6 charset-normalizer==2.0.1 -cryptography==3.4.7 +cryptography==39.0.1 distro==1.5.0 docker==5.0.0 docker-compose==1.29.2 From fe58bdfac37ccd0a4b69e5f43ea3002bffc1852c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 8 Feb 2023 00:45:21 +0000 Subject: [PATCH 050/147] Bump cryptography in /examples/reservations-bot-demo/cloud_function Bumps [cryptography](https://github.com/pyca/cryptography) from 3.3.2 to 39.0.1. - [Release notes](https://github.com/pyca/cryptography/releases) - [Changelog](https://github.com/pyca/cryptography/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pyca/cryptography/compare/3.3.2...39.0.1) --- updated-dependencies: - dependency-name: cryptography dependency-type: direct:production ... 
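# Editor's note, returning to the ydb/convert.py guard above: skipping the
# payload merge for falsy payloads is what lets a YQL Set be passed as a dict
# with None values (see the updated test_types case). A standalone analog:
def dict_to_pairs(value: dict) -> list:
    pairs = []
    for key, payload in value.items():
        pair = {"key": key}
        if payload:  # Set entries carry no payload, only keys
            pair["payload"] = payload
        pairs.append(pair)
    return pairs


assert dict_to_pairs({1: None, 2: None, 3: None}) == [
    {"key": 1},
    {"key": 2},
    {"key": 3},
]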
Signed-off-by: dependabot[bot] --- examples/reservations-bot-demo/cloud_function/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/reservations-bot-demo/cloud_function/requirements.txt b/examples/reservations-bot-demo/cloud_function/requirements.txt index 2d39d3b8..9b243fbb 100644 --- a/examples/reservations-bot-demo/cloud_function/requirements.txt +++ b/examples/reservations-bot-demo/cloud_function/requirements.txt @@ -1,7 +1,7 @@ certifi==2020.6.20 cffi==1.14.2 chardet==3.0.4 -cryptography==3.3.2 +cryptography==39.0.1 enum-compat==0.0.3 googleapis-common-protos==1.52.0 grpcio==1.31.0 From 84ad0e27877c7fdc25338e6790d97ad88660ee0e Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Fri, 20 Jan 2023 16:43:49 +0300 Subject: [PATCH 051/147] ydb sqlalchemy experiment --- .gitignore | 4 + examples/_sqlalchemy_example/example.py | 229 ++++++++++++++ examples/_sqlalchemy_example/fill_tables.py | 83 +++++ examples/_sqlalchemy_example/models.py | 34 ++ test-requirements.txt | 2 + tests/sqlalchemy/conftest.py | 22 ++ tests/sqlalchemy/test_dbapi.py | 84 +++++ tests/sqlalchemy/test_sqlalchemy.py | 27 ++ tox.ini | 8 +- ydb/_dbapi/__init__.py | 36 +++ ydb/_dbapi/connection.py | 73 +++++ ydb/_dbapi/cursor.py | 172 +++++++++++ ydb/_dbapi/errors.py | 92 ++++++ ydb/_sqlalchemy/__init__.py | 324 ++++++++++++++++++++ ydb/_sqlalchemy/types.py | 28 ++ 15 files changed, 1217 insertions(+), 1 deletion(-) create mode 100644 examples/_sqlalchemy_example/example.py create mode 100644 examples/_sqlalchemy_example/fill_tables.py create mode 100644 examples/_sqlalchemy_example/models.py create mode 100644 tests/sqlalchemy/conftest.py create mode 100644 tests/sqlalchemy/test_dbapi.py create mode 100644 tests/sqlalchemy/test_sqlalchemy.py create mode 100644 ydb/_dbapi/__init__.py create mode 100644 ydb/_dbapi/connection.py create mode 100644 ydb/_dbapi/cursor.py create mode 100644 ydb/_dbapi/errors.py create mode 100644 ydb/_sqlalchemy/__init__.py create mode 100644 ydb/_sqlalchemy/types.py diff --git a/.gitignore b/.gitignore index 45896947..55c4ea54 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,8 @@ ydb.egg-info/ /tox /venv /ydb_certs +/ydb_data /tmp +.coverage +/cov_html +/build diff --git a/examples/_sqlalchemy_example/example.py b/examples/_sqlalchemy_example/example.py new file mode 100644 index 00000000..00cd80d3 --- /dev/null +++ b/examples/_sqlalchemy_example/example.py @@ -0,0 +1,229 @@ +import datetime +import logging +import argparse +import sqlalchemy as sa +from sqlalchemy import orm, exc, sql +from sqlalchemy import Table, Column, Integer, String, Float, TIMESTAMP +from ydb._sqlalchemy import register_dialect + +from fill_tables import fill_all_tables, to_days +from models import Base, Series, Episodes + + +def describe_table(engine, name): + inspect = sa.inspect(engine) + print(f"describe table {name}:") + for col in inspect.get_columns(name): + print(f"\t{col['name']}: {col['type']}") + + +def simple_select(conn): + stm = sa.select(Series).where(Series.series_id == 1) + res = conn.execute(stm) + print(res.one()) + + +def simple_insert(conn): + stm = Episodes.__table__.insert().values( + series_id=3, season_id=6, episode_id=1, title="TBD" + ) + conn.execute(stm) + + +def test_types(conn): + types_tb = Table( + "test_types", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("str", String), + Column("num", Float), + Column("dt", TIMESTAMP), + ) + types_tb.drop(bind=conn.engine, checkfirst=True) + types_tb.create(bind=conn.engine, 
checkfirst=True) + + stm = types_tb.insert().values( + id=1, + str=b"Hello World!", + num=3.1415, + dt=datetime.datetime.now(), + ) + conn.execute(stm) + + # GROUP BY + stm = sa.select(types_tb.c.str, sa.func.max(types_tb.c.num)).group_by( + types_tb.c.str + ) + rs = conn.execute(stm) + for x in rs: + print(x) + + +def run_example_orm(engine): + Base.metadata.bind = engine + Base.metadata.drop_all() + Base.metadata.create_all() + + session = orm.sessionmaker(bind=engine)() + + rs = session.query(Episodes).all() + for e in rs: + print(f"{e.episode_id}: {e.title}") + + fill_all_tables(session.connection()) + + try: + session.add_all( + [ + Episodes( + series_id=2, + season_id=1, + episode_id=1, + title="Minimum Viable Product", + air_date=to_days("2014-04-06"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=2, + title="The Cap Table", + air_date=to_days("2014-04-13"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=3, + title="Articles of Incorporation", + air_date=to_days("2014-04-20"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=4, + title="Fiduciary Duties", + air_date=to_days("2014-04-27"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=5, + title="Signaling Risk", + air_date=to_days("2014-05-04"), + ), + ] + ) + session.commit() + except exc.DatabaseError: + print("Episodes already added!") + session.rollback() + + rs = session.query(Episodes).all() + for e in rs: + print(f"{e.episode_id}: {e.title}") + + rs = session.query(Episodes).filter(Episodes.title == "abc??").all() + for e in rs: + print(e.title) + + print("Episodes count:", session.query(Episodes).count()) + + max_episode = session.query(sql.expression.func.max(Episodes.episode_id)).scalar() + print("Maximum episodes id:", max_episode) + + session.add( + Episodes( + series_id=2, + season_id=1, + episode_id=max_episode + 1, + title="Signaling Risk", + air_date=to_days("2014-05-04"), + ) + ) + + print("Episodes count:", session.query(Episodes).count()) + + +def run_example_core(engine): + with engine.connect() as conn: + # raw sql + rs = conn.execute("SELECT 1 AS value") + print(rs.fetchone()["value"]) + + fill_all_tables(conn) + + for t in "series seasons episodes".split(): + describe_table(engine, t) + + tb = sa.Table("episodes", sa.MetaData(engine), autoload=True) + stm = ( + sa.select([tb.c.title]) + .where(sa.and_(tb.c.series_id == 1, tb.c.season_id == 3)) + .where(tb.c.title.like("%")) + .order_by(sa.asc(tb.c.title)) + # TODO: limit isn't working now + # .limit(3) + ) + rs = conn.execute(stm) + print(rs.fetchall()) + + simple_select(conn) + + simple_insert(conn) + + # simple join + stm = sa.select( + [Episodes.__table__.join(Series, Episodes.series_id == Series.series_id)] + ).where(sa.and_(Series.series_id == 1, Episodes.season_id == 1)) + rs = conn.execute(stm) + for row in rs: + print(f"{row.series_title}({row.episode_id}): {row.title}") + + rs = conn.execute(sa.select(Episodes).where(Episodes.series_id == 3)) + print(rs.fetchall()) + + # count + cnt = conn.execute(sa.func.count(Episodes.episode_id)).scalar() + print("Episodes cnt:", cnt) + + # simple delete + conn.execute(sa.delete(Episodes).where(Episodes.title == "TBD")) + cnt = conn.execute(sa.func.count(Episodes.episode_id)).scalar() + print("Episodes cnt:", cnt) + + test_types(conn) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""\033[92mYandex.Database examples sqlalchemy usage.\x1b[0m\n""", + ) + parser.add_argument( + "-d", + 
"--database", + help="Name of the database to use", + default="/local", + ) + parser.add_argument( + "-e", + "--endpoint", + help="Endpoint url to use", + default="grpc://localhost:2136", + ) + + args = parser.parse_args() + register_dialect() + engine = sa.create_engine( + "yql:///ydb/", + connect_args={"database": args.database, "endpoint": args.endpoint}, + ) + + logging.basicConfig(level=logging.INFO) + logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + + run_example_core(engine) + # run_example_orm(engine) + + +if __name__ == "__main__": + main() diff --git a/examples/_sqlalchemy_example/fill_tables.py b/examples/_sqlalchemy_example/fill_tables.py new file mode 100644 index 00000000..5a9eb954 --- /dev/null +++ b/examples/_sqlalchemy_example/fill_tables.py @@ -0,0 +1,83 @@ +import iso8601 + +import sqlalchemy as sa +from models import Base, Series, Seasons, Episodes + + +def to_days(date): + timedelta = iso8601.parse_date(date) - iso8601.parse_date("1970-1-1") + return timedelta.days + + +def fill_series(conn): + data = [ + ( + 1, + "IT Crowd", + "The IT Crowd is a British sitcom produced by Channel 4, written by Graham Linehan, produced by " + "Ash Atalla and starring Chris O'Dowd, Richard Ayoade, Katherine Parkinson, and Matt Berry.", + to_days("2006-02-03"), + ), + ( + 2, + "Silicon Valley", + "Silicon Valley is an American comedy television series created by Mike Judge, John Altschuler and " + "Dave Krinsky. The series focuses on five young men who founded a startup company in Silicon Valley.", + to_days("2014-04-06"), + ), + ] + conn.execute(sa.insert(Series).values(data)) + + +def fill_seasons(conn): + data = [ + (1, 1, "Season 1", to_days("2006-02-03"), to_days("2006-03-03")), + (1, 2, "Season 2", to_days("2007-08-24"), to_days("2007-09-28")), + (1, 3, "Season 3", to_days("2008-11-21"), to_days("2008-12-26")), + (1, 4, "Season 4", to_days("2010-06-25"), to_days("2010-07-30")), + (2, 1, "Season 1", to_days("2014-04-06"), to_days("2014-06-01")), + (2, 2, "Season 2", to_days("2015-04-12"), to_days("2015-06-14")), + (2, 3, "Season 3", to_days("2016-04-24"), to_days("2016-06-26")), + (2, 4, "Season 4", to_days("2017-04-23"), to_days("2017-06-25")), + (2, 5, "Season 5", to_days("2018-03-25"), to_days("2018-05-13")), + ] + conn.execute(sa.insert(Seasons).values(data)) + + +def fill_episodes(conn): + data = [ + (1, 1, 1, "Yesterday's Jam", to_days("2006-02-03")), + (1, 1, 2, "Calamity Jen", to_days("2006-02-03")), + (1, 1, 3, "Fifty-Fifty", to_days("2006-02-10")), + (1, 1, 4, "The Red Door", to_days("2006-02-17")), + (1, 1, 5, "The Haunting of Bill Crouse", to_days("2006-02-24")), + (1, 1, 6, "Aunt Irma Visits", to_days("2006-03-03")), + (1, 2, 1, "The Work Outing", to_days("2006-08-24")), + (1, 2, 2, "Return of the Golden Child", to_days("2007-08-31")), + (1, 2, 3, "Moss and the German", to_days("2007-09-07")), + (1, 2, 4, "The Dinner Party", to_days("2007-09-14")), + (1, 2, 5, "Smoke and Mirrors", to_days("2007-09-21")), + (1, 2, 6, "Men Without Women", to_days("2007-09-28")), + (1, 3, 1, "From Hell", to_days("2008-11-21")), + (1, 3, 2, "Are We Not Men?", to_days("2008-11-28")), + (1, 3, 3, "Tramps Like Us", to_days("2008-12-05")), + (1, 3, 4, "The Speech", to_days("2008-12-12")), + (1, 3, 5, "Friendface", to_days("2008-12-19")), + (1, 3, 6, "Calendar Geeks", to_days("2008-12-26")), + (1, 4, 1, "Jen The Fredo", to_days("2010-06-25")), + (1, 4, 2, "The Final Countdown", to_days("2010-07-02")), + (1, 4, 3, "Something Happened", to_days("2010-07-09")), + (1, 4, 4, "Italian 
For Beginners", to_days("2010-07-16")), + (1, 4, 5, "Bad Boys", to_days("2010-07-23")), + (1, 4, 6, "Reynholm vs Reynholm", to_days("2010-07-30")), + ] + conn.execute(sa.insert(Episodes).values(data)) + + +def fill_all_tables(conn): + Base.metadata.drop_all(conn.engine) + Base.metadata.create_all(conn.engine) + + fill_series(conn) + fill_seasons(conn) + fill_episodes(conn) diff --git a/examples/_sqlalchemy_example/models.py b/examples/_sqlalchemy_example/models.py new file mode 100644 index 00000000..a02349a9 --- /dev/null +++ b/examples/_sqlalchemy_example/models.py @@ -0,0 +1,34 @@ +import sqlalchemy.orm as orm +from sqlalchemy import Column, Integer, Unicode + + +Base = orm.declarative_base() + + +class Series(Base): + __tablename__ = "series" + + series_id = Column(Integer, primary_key=True) + title = Column(Unicode) + series_info = Column(Unicode) + release_date = Column(Integer) + + +class Seasons(Base): + __tablename__ = "seasons" + + series_id = Column(Integer, primary_key=True) + season_id = Column(Integer, primary_key=True) + title = Column(Unicode) + first_aired = Column(Integer) + last_aired = Column(Integer) + + +class Episodes(Base): + __tablename__ = "episodes" + + series_id = Column(Integer, primary_key=True) + season_id = Column(Integer, primary_key=True) + episode_id = Column(Integer, primary_key=True) + title = Column(Unicode) + air_date = Column(Integer) diff --git a/test-requirements.txt b/test-requirements.txt index 9f592875..d1ca4276 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -47,3 +47,5 @@ sqlalchemy==1.4.26 pylint-protobuf cython freezegun==1.2.2 +grpcio-tools +pytest-cov diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py new file mode 100644 index 00000000..6ebabac3 --- /dev/null +++ b/tests/sqlalchemy/conftest.py @@ -0,0 +1,22 @@ +import pytest +import sqlalchemy as sa + +from ydb._sqlalchemy import register_dialect + + +@pytest.fixture(scope="module") +def engine(endpoint, database): + register_dialect() + engine = sa.create_engine( + "yql:///ydb/", + connect_args={"database": database, "endpoint": endpoint}, + ) + + yield engine + engine.dispose() + + +@pytest.fixture(scope="module") +def connection(engine): + with engine.connect() as conn: + yield conn diff --git a/tests/sqlalchemy/test_dbapi.py b/tests/sqlalchemy/test_dbapi.py new file mode 100644 index 00000000..407cc9e4 --- /dev/null +++ b/tests/sqlalchemy/test_dbapi.py @@ -0,0 +1,84 @@ +from ydb import _dbapi as dbapi + + +def test_dbapi(endpoint, database): + conn = dbapi.connect(endpoint, database=database) + assert conn + + conn.commit() + conn.rollback() + + cur = conn.cursor() + assert cur + + cur.execute( + "CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", + context={"isddl": True}, + ) + + cur.execute('INSERT INTO test(id, text) VALUES (1, "foo")') + + cur.execute("SELECT id, text FROM test") + assert cur.fetchone() == (1, "foo"), "fetchone is ok" + + cur.execute("SELECT id, text FROM test WHERE id = %(id)s", {"id": 1}) + assert cur.fetchone() == (1, "foo"), "parametrized query is ok" + + cur.execute( + "INSERT INTO test(id, text) VALUES (%(id1)s, %(text1)s), (%(id2)s, %(text2)s)", + {"id1": 2, "text1": "", "id2": 3, "text2": "bar"}, + ) + + cur.execute( + "UPDATE test SET text = %(t)s WHERE id = %(id)s", {"id": 2, "t": "foo2"} + ) + + cur.execute("SELECT id FROM test") + assert cur.fetchall() == [(1,), (2,), (3,)], "fetchall is ok" + + cur.execute("SELECT id FROM test ORDER BY id DESC") + assert cur.fetchmany(2) == [(3,), (2,)], 
"fetchmany is ok" + assert cur.fetchmany(1) == [(1,)] + + cur.execute("SELECT id FROM test ORDER BY id LIMIT 2") + assert cur.fetchall() == [(1,), (2,)], "limit clause without params is ok" + + # TODO: Failed to convert type: Int64 to Uint64 + # cur.execute("SELECT id FROM test ORDER BY id LIMIT %(limit)s", {"limit": 2}) + # assert cur.fetchall() == [(1,), (2,)], "limit clause with params is ok" + + cur2 = conn.cursor() + cur2.execute( + "INSERT INTO test(id) VALUES (%(id1)s), (%(id2)s)", {"id1": 5, "id2": 6} + ) + + cur.execute("SELECT id FROM test ORDER BY id") + assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,)], "cursor2 commit changes" + + cur.execute("SELECT text FROM test WHERE id > %(min_id)s", {"min_id": 3}) + assert cur.fetchall() == [(None,), (None,)], "NULL returns as None" + + cur.execute("SELECT id, text FROM test WHERE text LIKE %(p)s", {"p": "foo%"}) + assert cur.fetchall() == [(1, "foo"), (2, "foo2")], "like clause works" + + cur.execute( + # DECLARE statement (DECLARE $data AS List>) + # will generate automatically + """INSERT INTO test SELECT id, text FROM AS_TABLE($data);""", + { + "data": [ + {"id": 17, "text": "seventeen"}, + {"id": 21, "text": "twenty one"}, + ] + }, + ) + + cur.execute("SELECT id FROM test ORDER BY id") + assert cur.rowcount == 7, "rowcount ok" + assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,), (17,), (21,)], "ok" + + cur.execute("DROP TABLE test", context={"isddl": True}) + + cur.close() + cur2.close() + conn.close() diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py new file mode 100644 index 00000000..914553ea --- /dev/null +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -0,0 +1,27 @@ +import sqlalchemy as sa +from sqlalchemy import MetaData, Table, Column, Integer, Unicode + +meta = MetaData() + + +def clear_sql(stm): + return stm.replace("\n", " ").replace(" ", " ").strip() + + +def test_sqlalchemy_core(connection): + # raw sql + rs = connection.execute("SELECT 1 AS value") + assert rs.fetchone()["value"] == 1 + + tb_test = Table( + "test", + meta, + Column("id", Integer, primary_key=True), + Column("text", Unicode), + ) + + stm = sa.select(tb_test) + assert clear_sql(str(stm)) == "SELECT test.id, test.text FROM test" + + stm = sa.insert(tb_test).values(id=2, text="foo") + assert clear_sql(str(stm)) == "INSERT INTO test (id, text) VALUES (:id, :text)" diff --git a/tox.ini b/tox.ini index 28181d20..7aca13db 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py,py-proto3,py-tls,py-tls-proto3,style,pylint,black,protoc +envlist = py,py-proto3,py-tls,py-tls-proto3,style,pylint,black,protoc,py-cov minversion = 4.2.6 skipsdist = True ignore_basepython_conflict = true @@ -25,6 +25,12 @@ deps = commands = pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} +[testenv:py-cov] +commands = + pytest -v -m "not tls" \ + --cov-report html:cov_html --cov=ydb \ + --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} + [testenv:py-proto3] commands = pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} diff --git a/ydb/_dbapi/__init__.py b/ydb/_dbapi/__init__.py new file mode 100644 index 00000000..8756b0f2 --- /dev/null +++ b/ydb/_dbapi/__init__.py @@ -0,0 +1,36 @@ +from .connection import Connection +from .errors import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, 
+) + +apilevel = "1.0" + +threadsafety = 0 + +paramstyle = "pyformat" + +errors = ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) + + +def connect(*args, **kwargs): + return Connection(*args, **kwargs) diff --git a/ydb/_dbapi/connection.py b/ydb/_dbapi/connection.py new file mode 100644 index 00000000..75bfeb58 --- /dev/null +++ b/ydb/_dbapi/connection.py @@ -0,0 +1,73 @@ +import posixpath + +import ydb +from .cursor import Cursor +from .errors import DatabaseError + + +class Connection: + def __init__(self, endpoint, database=None, **conn_kwargs): + self.endpoint = endpoint + self.database = database + self.driver = self._create_driver(self.endpoint, self.database, **conn_kwargs) + self.pool = ydb.SessionPool(self.driver) + + def cursor(self): + return Cursor(self) + + def describe(self, table_path): + full_path = posixpath.join(self.database, table_path) + try: + res = self.pool.retry_operation_sync( + lambda cli: cli.describe_table(full_path) + ) + return res.columns + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + except Exception: + raise DatabaseError(f"Failed to describe table {table_path}") + + def check_exists(self, table_path): + try: + self.driver.scheme_client.describe_path(table_path) + return True + except ydb.SchemeError: + return False + + def commit(self): + pass + + def rollback(self): + pass + + def close(self): + if self.pool: + self.pool.stop() + if self.driver: + self.driver.stop() + + @staticmethod + def _create_driver(endpoint, database, **conn_kwargs): + # TODO: add cache for initialized drivers/pools? + driver_config = ydb.DriverConfig( + endpoint, + database=database, + table_client_settings=ydb.TableClientSettings() + .with_native_date_in_result_sets(True) + .with_native_datetime_in_result_sets(True) + .with_native_timestamp_in_result_sets(True) + .with_native_interval_in_result_sets(True) + .with_native_json_in_result_sets(True), + **conn_kwargs, + ) + driver = ydb.Driver(driver_config) + try: + driver.wait(timeout=5, fail_fast=True) + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + except Exception: + driver.stop() + raise DatabaseError( + f"Failed to connect to YDB, details {driver.discovery_debug_details()}" + ) + return driver diff --git a/ydb/_dbapi/cursor.py b/ydb/_dbapi/cursor.py new file mode 100644 index 00000000..57659c7a --- /dev/null +++ b/ydb/_dbapi/cursor.py @@ -0,0 +1,172 @@ +import datetime +import itertools +import logging +import uuid +import decimal + +import ydb +from .errors import DatabaseError, ProgrammingError + + +logger = logging.getLogger(__name__) + + +def get_column_type(type_obj): + return str(ydb.convert.type_to_native(type_obj)) + + +def _generate_type_str(value): + tvalue = type(value) + + stype = { + bool: "Bool", + bytes: "String", + str: "Utf8", + int: "Int64", + float: "Double", + decimal.Decimal: "Decimal(22, 9)", + datetime.date: "Date", + datetime.datetime: "Timestamp", + datetime.timedelta: "Interval", + uuid.UUID: "Uuid", + }.get(tvalue) + + if tvalue == dict: + types_lst = ", ".join(f"{k}: {_generate_type_str(v)}" for k, v in value.items()) + stype = f"Struct<{types_lst}>" + + elif tvalue == tuple: + types_lst = ", ".join(_generate_type_str(x) for x in value) + stype = f"Tuple<{types_lst}>" + + elif tvalue == list: + nested_type = _generate_type_str(value[0]) + stype = f"List<{nested_type}>" + + elif tvalue == set: + nested_type = 
_generate_type_str(next(iter(value))) + stype = f"Set<{nested_type}>" + + if stype is None: + raise ProgrammingError( + "Cannot translate python type to ydb type.", tvalue, value + ) + + return stype + + +def _generate_declare_stms(params: dict) -> str: + return "".join( + f"DECLARE {k} AS {_generate_type_str(t)}; " for k, t in params.items() + ) + + +class Cursor(object): + def __init__(self, connection): + self.connection = connection + self.description = None + self.arraysize = 1 + self.rows = None + self._rows_prefetched = None + + def execute(self, sql, parameters=None, context=None): + self.description = None + sql_params = None + + if parameters: + sql = sql % {k: f"${k}" for k, v in parameters.items()} + sql_params = {f"${k}": v for k, v in parameters.items()} + declare_stms = _generate_declare_stms(sql_params) + sql = f"{declare_stms}{sql}" + + logger.info("execute sql: %s, params: %s", sql, sql_params) + + def _execute_in_pool(cli): + try: + if context and context.get("isddl"): + return cli.execute_scheme(sql) + else: + prepared_query = cli.prepare(sql) + return cli.transaction().execute( + prepared_query, sql_params, commit_tx=True + ) + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + + chunks = self.connection.pool.retry_operation_sync(_execute_in_pool) + rows = self._rows_iterable(chunks) + # Prefetch the description: + try: + first_row = next(rows) + except StopIteration: + pass + else: + rows = itertools.chain((first_row,), rows) + if self.rows is not None: + rows = itertools.chain(self.rows, rows) + + self.rows = rows + + def _rows_iterable(self, chunks_iterable): + try: + for chunk in chunks_iterable: + self.description = [ + ( + col.name, + get_column_type(col.type), + None, + None, + None, + None, + None, + ) + for col in chunk.columns + ] + for row in chunk.rows: + # returns tuple to be compatible with SqlAlchemy and because + # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany + yield row[::] + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + + def _ensure_prefetched(self): + if self.rows is not None and self._rows_prefetched is None: + self._rows_prefetched = list(self.rows) + self.rows = iter(self._rows_prefetched) + return self._rows_prefetched + + def executemany(self, sql, seq_of_parameters): + for parameters in seq_of_parameters: + self.execute(sql, parameters) + + def executescript(self, script): + return self.execute(script) + + def fetchone(self): + if self.rows is None: + return None + return next(self.rows, None) + + def fetchmany(self, size=None): + size = self.arraysize if size is None else size + return list(itertools.islice(self.rows, size)) + + def fetchall(self): + return list(self.rows) + + def nextset(self): + self.fetchall() + + def setinputsizes(self, sizes): + pass + + def setoutputsize(self, column=None): + pass + + def close(self): + self.rows = None + self._rows_prefetched = None + + @property + def rowcount(self): + return len(self._ensure_prefetched()) diff --git a/ydb/_dbapi/errors.py b/ydb/_dbapi/errors.py new file mode 100644 index 00000000..ddb55b4c --- /dev/null +++ b/ydb/_dbapi/errors.py @@ -0,0 +1,92 @@ +class Warning(Exception): + pass + + +class Error(Exception): + def __init__(self, message, issues=None, status=None): + super(Error, self).__init__(message) + + pretty_issues = _pretty_issues(issues) + self.issues = issues + self.message = pretty_issues or message + self.status = status + + +class InterfaceError(Error): + pass + + +class 
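# Editor's note: Cursor.execute above turns pyformat placeholders into YQL
# DECLARE parameters. A standalone re-implementation of just that rewrite
# (the Int64 declaration is hard-coded here; the real code infers types):
def to_yql(sql: str, parameters: dict):
    sql = sql % {k: f"${k}" for k in parameters}
    params = {f"${k}": v for k, v in parameters.items()}
    declares = "".join(f"DECLARE {k} AS Int64; " for k in params)
    return declares + sql, params


query, params = to_yql("SELECT id FROM test WHERE id = %(id)s", {"id": 1})
assert query == "DECLARE $id AS Int64; SELECT id FROM test WHERE id = $id"
assert params == {"$id": 1}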
DatabaseError(Error): + pass + + +class DataError(DatabaseError): + pass + + +class OperationalError(DatabaseError): + pass + + +class IntegrityError(DatabaseError): + pass + + +class InternalError(DatabaseError): + pass + + +class ProgrammingError(DatabaseError): + pass + + +class NotSupportedError(DatabaseError): + pass + + +def _pretty_issues(issues): + if issues is None: + return None + + children_messages = [_get_messages(issue, root=True) for issue in issues] + + if None in children_messages: + return None + + return "\n" + "\n".join(children_messages) + + +def _get_messages(issue, max_depth=100, indent=2, depth=0, root=False): + if depth >= max_depth: + return None + + margin_str = " " * depth * indent + pre_message = "" + children = "" + + if issue.issues: + collapsed_messages = [] + while not root and len(issue.issues) == 1: + collapsed_messages.append(issue.message) + issue = issue.issues[0] + + if collapsed_messages: + pre_message = f"{margin_str}{', '.join(collapsed_messages)}\n" + depth += 1 + margin_str = " " * depth * indent + + children_messages = [ + _get_messages(iss, max_depth=max_depth, indent=indent, depth=depth + 1) + for iss in issue.issues + ] + + if None in children_messages: + return None + + children = "\n".join(children_messages) + + return ( + f"{pre_message}{margin_str}{issue.message}\n{margin_str}" + f"severity level: {issue.severity}\n{margin_str}" + f"issue code: {issue.issue_code}\n{children}" + ) diff --git a/ydb/_sqlalchemy/__init__.py b/ydb/_sqlalchemy/__init__.py new file mode 100644 index 00000000..8336a9a8 --- /dev/null +++ b/ydb/_sqlalchemy/__init__.py @@ -0,0 +1,324 @@ +""" +Experimental +Work in progress, breaking changes are possible. +""" +import ydb +import ydb._dbapi as dbapi + +import sqlalchemy as sa +from sqlalchemy import dialects +from sqlalchemy import Table +from sqlalchemy.exc import CompileError +from sqlalchemy.sql import functions, literal_column +from sqlalchemy.sql.compiler import ( + IdentifierPreparer, + GenericTypeCompiler, + SQLCompiler, + DDLCompiler, +) +from sqlalchemy.sql.elements import ClauseList +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.util.compat import inspect_getfullargspec + +from ydb._sqlalchemy.types import UInt32, UInt64 + + +SQLALCHEMY_VERSION = tuple(sa.__version__.split(".")) +SA_14 = SQLALCHEMY_VERSION >= ("1", "4") + + +class YqlIdentifierPreparer(IdentifierPreparer): + def __init__(self, dialect): + super(YqlIdentifierPreparer, self).__init__( + dialect, + initial_quote="`", + final_quote="`", + ) + + def _requires_quotes(self, value): + # Force all identifiers to get quoted unless already quoted. 
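# Editor's note: the check that follows backtick-quotes every identifier
# unless it is already quoted. A standalone analog of that behaviour:
def quote_identifier(value: str, q: str = "`") -> str:
    if value.startswith(q) and value.endswith(q):
        return value  # already quoted, pass through unchanged
    return f"{q}{value}{q}"


assert quote_identifier("episodes") == "`episodes`"
assert quote_identifier("`episodes`") == "`episodes`"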
+ return not ( + value.startswith(self.initial_quote) and value.endswith(self.final_quote) + ) + + +class YqlTypeCompiler(GenericTypeCompiler): + def visit_VARCHAR(self, type_, **kw): + return "STRING" + + def visit_unicode(self, type_, **kw): + return "UTF8" + + def visit_NVARCHAR(self, type_, **kw): + return "UTF8" + + def visit_TEXT(self, type_, **kw): + return "UTF8" + + def visit_FLOAT(self, type_, **kw): + return "DOUBLE" + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + def visit_uint32(self, type_, **kw): + return "UInt32" + + def visit_uint64(self, type_, **kw): + return "UInt64" + + def visit_uint8(self, type_, **kw): + return "UInt8" + + def visit_INTEGER(self, type_, **kw): + return "Int64" + + def visit_NUMERIC(self, type_, **kw): + return "Int64" + + +class ParametrizedFunction(functions.Function): + __visit_name__ = "parametrized_function" + + def __init__(self, name, params, *args, **kwargs): + super(ParametrizedFunction, self).__init__(name, *args, **kwargs) + self._func_name = name + self._func_params = params + self.params_expr = ClauseList( + operator=functions.operators.comma_op, group_contents=True, *params + ).self_group() + + +class YqlCompiler(SQLCompiler): + def group_by_clause(self, select, **kw): + # Hack to ensure it is possible to define labels in groupby. + kw.update(within_columns_clause=True) + return super(YqlCompiler, self).group_by_clause(select, **kw) + + def visit_lambda(self, lambda_, **kw): + func = lambda_.func + spec = inspect_getfullargspec(func) + + if spec.varargs: + raise CompileError("Lambdas with *args are not supported") + + try: + keywords = spec.keywords + except AttributeError: + keywords = spec.varkw + + if keywords: + raise CompileError("Lambdas with **kwargs are not supported") + + text = "(" + ", ".join("$" + arg for arg in spec.args) + ")" + " -> " + + args = [literal_column("$" + arg) for arg in spec.args] + text += "{ RETURN " + self.process(func(*args), **kw) + " ;}" + + return text + + def visit_parametrized_function(self, func, **kwargs): + name = func.name + name_parts = [] + for name in name.split("::"): + fname = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + + name_parts.append(fname) + + name = "::".join(name_parts) + params = func.params_expr._compiler_dispatch(self, **kwargs) + args = self.function_argspec(func, **kwargs) + return "%(name)s%(params)s%(args)s" % dict(name=name, params=params, args=args) + + def visit_function(self, func, add_to_result_map=None, **kwargs): + # Copypaste of `sa.sql.compiler.SQLCompiler.visit_function` with + # `::` as namespace separator instead of `.` + if add_to_result_map is not None: + add_to_result_map(func.name, func.name, (), func.type) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) + else: + name = sa.sql.compiler.FUNCTIONS.get(func.__class__, None) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + name = name + "%(expr)s" + return "::".join( + [ + ( + self.preparer.quote(tok) + if self.preparer._requires_quotes_illegal_chars(tok) + or isinstance(name, sa.sql.elements.quoted_name) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": self.function_argspec(func, **kwargs)} + + +class 
YqlDdlCompiler(DDLCompiler): + pass + + +def upsert(table): + return sa.sql.Insert(table) + + +COLUMN_TYPES = { + ydb.PrimitiveType.Int8: sa.INTEGER, + ydb.PrimitiveType.Int16: sa.INTEGER, + ydb.PrimitiveType.Int32: sa.INTEGER, + ydb.PrimitiveType.Int64: sa.INTEGER, + ydb.PrimitiveType.Uint8: sa.INTEGER, + ydb.PrimitiveType.Uint16: sa.INTEGER, + ydb.PrimitiveType.Uint32: UInt32, + ydb.PrimitiveType.Uint64: UInt64, + ydb.PrimitiveType.Float: sa.FLOAT, + ydb.PrimitiveType.Double: sa.FLOAT, + ydb.PrimitiveType.String: sa.TEXT, + ydb.PrimitiveType.Utf8: sa.TEXT, + ydb.PrimitiveType.Json: sa.JSON, + ydb.PrimitiveType.JsonDocument: sa.JSON, + ydb.DecimalType: sa.DECIMAL, + ydb.PrimitiveType.Yson: sa.TEXT, + ydb.PrimitiveType.Date: sa.DATE, + ydb.PrimitiveType.Datetime: sa.DATETIME, + ydb.PrimitiveType.Timestamp: sa.DATETIME, + ydb.PrimitiveType.Interval: sa.INTEGER, + ydb.PrimitiveType.Bool: sa.BOOLEAN, + ydb.PrimitiveType.DyNumber: sa.TEXT, +} + + +def _get_column_type(t): + if isinstance(t, ydb.OptionalType): + t = t.item + + if isinstance(t, ydb.DecimalType): + return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale) + + return COLUMN_TYPES[t] + + +class YqlDialect(DefaultDialect): + name = "yql" + supports_alter = False + max_identifier_length = 63 + supports_sane_rowcount = False + supports_statement_cache = False + + supports_native_enum = False + supports_native_boolean = True + supports_smallserial = False + + supports_sequences = False + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + + supports_default_values = False + supports_empty_insert = False + supports_multivalues_insert = True + default_paramstyle = "qmark" + + isolation_level = None + + preparer = YqlIdentifierPreparer + statement_compiler = YqlCompiler + ddl_compiler = YqlDdlCompiler + type_compiler = YqlTypeCompiler + + driver = ydb.Driver + + @staticmethod + def dbapi(): + return dbapi + + def _check_unicode_returns(self, *args, **kwargs): + # Normally, this would do 2 SQL queries, which isn't quite necessary. + return "conditional" + + def get_columns(self, connection, table_name, schema=None, **kw): + if schema is not None: + raise dbapi.errors.NotSupportedError("unsupported on non empty schema") + + qt = table_name.name if isinstance(table_name, Table) else table_name + + if SA_14: + raw_conn = connection.connection + else: + raw_conn = connection.raw_connection() + + columns = raw_conn.describe(qt) + as_compatible = [] + for column in columns: + as_compatible.append( + { + "name": column.name, + "type": _get_column_type(column.type), + "nullable": True, + } + ) + + return as_compatible + + def has_table(self, connection, table_name, schema=None, **kwargs): + if schema is not None: + raise dbapi.errors.NotSupportedError("unsupported on non empty schema") + + quote = self.identifier_preparer.quote_identifier + qtable = quote(table_name) + + # TODO: use `get_columns` instead. + statement = "SELECT * FROM " + qtable + try: + connection.execute(statement) + return True + except Exception: + return False + + def get_pk_constraint(self, connection, table_name, schema=None, **kwargs): + # TODO: implement me + return [] + + def get_foreign_keys(self, connection, table_name, schema=None, **kwargs): + # foreign keys unsupported + return [] + + def get_indexes(self, connection, table_name, schema=None, **kwargs): + # TODO: implement me + return [] + + def do_commit(self, dbapi_connection) -> None: + # TODO: needs to implement? 
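# Editor's note: wiring this dialect into SQLAlchemy mirrors the tests'
# conftest (the database and endpoint values are placeholders):
import sqlalchemy as sa

from ydb._sqlalchemy import register_dialect

register_dialect()
engine = sa.create_engine(
    "yql:///ydb/",
    connect_args={"database": "/local", "endpoint": "grpc://localhost:2136"},
)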
+ pass + + def do_execute(self, cursor, statement, parameters, context=None) -> None: + c = None + if context is not None and context.isddl: + c = {"isddl": True} + cursor.execute(statement, parameters, c) + + +def register_dialect( + name="yql", + module=__name__, + cls="YqlDialect", +): + return dialects.registry.register(name, module, cls) diff --git a/ydb/_sqlalchemy/types.py b/ydb/_sqlalchemy/types.py new file mode 100644 index 00000000..21748ec1 --- /dev/null +++ b/ydb/_sqlalchemy/types.py @@ -0,0 +1,28 @@ +from sqlalchemy.types import Integer +from sqlalchemy.sql import type_api +from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy import util, exc + + +class UInt32(Integer): + __visit_name__ = "uint32" + + +class UInt64(Integer): + __visit_name__ = "uint64" + + +class UInt8(Integer): + __visit_name__ = "uint8" + + +class Lambda(ColumnElement): + + __visit_name__ = "lambda" + + def __init__(self, func): + if not util.callable(func): + raise exc.ArgumentError("func must be callable") + + self.type = type_api.NULLTYPE + self.func = func From 9a64e792d0d899556afb1cf284f0c39385c098c2 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Feb 2023 10:53:12 +0300 Subject: [PATCH 052/147] stop old tests if new commit pushed to same branch --- .github/workflows/tests.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 71bdf07f..9c0a2392 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -9,6 +9,11 @@ on: jobs: build: runs-on: ubuntu-latest + + concurrency: + group: unit-${{ github.ref }}-${{ matrix.environment }}-${{ matrix.python-version }} + cancel-in-progress: true + strategy: max-parallel: 4 matrix: From 1df46188a80e1caaa0f8af9e45a1fc0e15d90104 Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Mon, 20 Feb 2023 14:11:04 +0300 Subject: [PATCH 053/147] topic writer add flush --- examples/topic/writer_async_example.py | 2 +- tests/topics/test_topic_writer.py | 24 ++++++++++- ydb/_topic_writer/topic_writer.py | 14 +++--- ydb/_topic_writer/topic_writer_asyncio.py | 34 ++++++++------- .../topic_writer_asyncio_test.py | 43 ++++++++++--------- 5 files changed, 73 insertions(+), 44 deletions(-) diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index 6dd37490..29c79b08 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -65,7 +65,7 @@ async def send_message_without_block_if_internal_buffer_is_full( return False -def send_messages_with_manual_seqno(writer: ydb.TopicWriter): +async def send_messages_with_manual_seqno(writer: ydb.TopicWriter): await writer.write(ydb.TopicWriterMessage("mess")) # send text diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 3071c655..0bdce33f 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -9,7 +9,8 @@ async def test_send_message(self, driver: ydb.aio.Driver, topic_path): writer = driver.topic_client.topic_writer( topic_path, producer_and_message_group_id="test" ) - writer.write(ydb.TopicWriterMessage(data="123".encode())) + await writer.write(ydb.TopicWriterMessage(data="123".encode())) + await writer.close() async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): async with driver.topic_client.topic_writer( @@ -28,3 +29,24 @@ async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): ) as writer2: init_info = await writer2.wait_init() assert 
init_info.last_seqno == 5 + + async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): + async with driver.topic_client.topic_writer( + topic_path, + producer_and_message_group_id="test", + auto_seqno=False, + ) as writer: + last_seqno = 0 + for i in range(10): + last_seqno = i + 1 + await writer.write( + ydb.TopicWriterMessage(data=f"msg-{i}", seqno=last_seqno) + ) + + async with driver.topic_client.topic_writer( + topic_path, + producer_and_message_group_id="test", + get_last_seqno=True, + ) as writer: + init_info = await writer.wait_init() + assert init_info.last_seqno == last_seqno diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index b66d8205..cd60d00f 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -104,7 +104,7 @@ def async_flush(self): """ raise NotImplementedError() - def flush(self, timeout: Union[float, None] = None) -> concurrent.futures.Future: + def flush(self, timeout: Optional[float] = None) -> concurrent.futures.Future: """ Force send all messages from internal buffer and wait acks from server for all messages. @@ -122,7 +122,7 @@ def async_wait_init(self) -> concurrent.futures.Future: """ raise NotImplementedError() - def wait_init(self, timeout: Union[float, None] = None): + def wait_init(self, timeout: Optional[float] = None): """ Wait until underling connection established @@ -141,15 +141,15 @@ class PublicWriterSettings: session_metadata: Optional[Dict[str, str]] = None encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None serializer: Union[Callable[[Any], bytes], None] = None - send_buffer_count: Union[int, None] = 10000 - send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024 + send_buffer_count: Optional[int] = 10000 + send_buffer_bytes: Optional[int] = 100 * 1024 * 1024 partition_id: Optional[int] = None - codec: Union[int, None] = None + codec: Optional[int] = None codec_autoselect: bool = True auto_seqno: bool = True auto_created_at: bool = True get_last_seqno: bool = False - retry_policy: Union["RetryPolicy", None] = None + retry_policy: Optional["RetryPolicy"] = None update_token_interval: Union[int, float] = 3600 @@ -251,7 +251,7 @@ def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData: class MessageSendResult: - offset: Union[None, int] + offset: Optional[int] write_status: "MessageWriteStatus" diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 0eb9e5ca..669507d8 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -86,21 +86,14 @@ async def write_with_ack( For wait with timeout use asyncio.wait_for. """ - if isinstance(messages, PublicMessage): - futures = await self._reconnector.write_with_ack([messages]) - return await futures[0] - if isinstance(messages, list): - for m in messages: - if not isinstance(m, PublicMessage): - raise NotImplementedError() + futures = await self.write_with_ack_future(messages, *args) + if not isinstance(futures, list): + futures = [futures] - futures = await self._reconnector.write_with_ack(messages) - await asyncio.wait(futures) + await asyncio.wait(futures) + results = [f.result() for f in futures] - results = [f.result() for f in futures] - return results - - raise NotImplementedError() + return results if isinstance(messages, list) else results[0] async def write_with_ack_future( self, @@ -145,7 +138,7 @@ async def flush(self): For wait with timeout use asyncio.wait_for. 
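
        A minimal usage sketch (assumes this writer has queued messages;
        the timeout value is illustrative):

            await writer.write(ydb.TopicWriterMessage("data"))
            await asyncio.wait_for(writer.flush(), timeout=10)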
""" - raise NotImplementedError() + return await self._reconnector.flush() async def wait_init(self) -> PublicWriterInitInfo: """ @@ -192,10 +185,13 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): asyncio.create_task(self._connection_loop(), name="connection_loop") ] - async def close(self): + async def close(self, flush: bool = True): if self._closed: return + if flush: + await self.flush() + self._closed = True self._stop(TopicWriterStopped()) @@ -396,6 +392,14 @@ def _stop(self, reason: Exception): def _get_token(self) -> str: raise NotImplementedError() + async def flush(self): + self._check_stop() + if not self._messages_future: + return + + # wait last message + await asyncio.wait((self._messages_future[-1],)) + class WriterAsyncIOStream: # todo slots diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index c99e392f..6658adbd 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -239,6 +239,20 @@ def default_write_statistic( topic_quota_wait_time=datetime.timedelta(milliseconds=5), ) + def make_default_ack_message(self, seq_no=1) -> StreamWriteMessage.WriteResponse: + return StreamWriteMessage.WriteResponse( + partition_id=1, + acks=[ + StreamWriteMessage.WriteResponse.WriteAck( + seq_no=seq_no, + message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + offset=1 + ), + ) + ], + write_statistics=self.default_write_statistic, + ) + @pytest.fixture async def reconnector( self, default_driver, default_settings @@ -275,20 +289,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( assert [InternalMessage(message2)] == messages # ack first message - stream_writer.from_server.put_nowait( - StreamWriteMessage.WriteResponse( - partition_id=1, - acks=[ - StreamWriteMessage.WriteResponse.WriteAck( - seq_no=1, - message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( - offset=1 - ), - ) - ], - write_statistics=default_write_statistic, - ) - ) + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) stream_writer.from_server.put_nowait(issues.Overloaded("test")) @@ -297,6 +298,8 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( expected_messages = [InternalMessage(message2)] assert second_sent_msg == expected_messages + + second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) await reconnector.close() async def test_stop_on_unexpected_exception( @@ -323,7 +326,7 @@ async def wait_stop(): await asyncio.wait_for(wait_stop(), 1) with pytest.raises(TestException): - await reconnector.close() + await reconnector.close(flush=False) async def test_wait_init(self, default_driver, default_settings, get_stream_writer): init_seqno = 100 @@ -350,7 +353,7 @@ async def test_wait_init(self, default_driver, default_settings, get_stream_writ info = await reconnector.wait_init() assert info == expected_init_info - await reconnector.close() + await reconnector.close(flush=False) async def test_write_message( self, reconnector: WriterAsyncIOReconnector, get_stream_writer @@ -365,7 +368,7 @@ async def test_write_message( sent_messages = await asyncio.wait_for(stream_writer.from_client.get(), 1) assert sent_messages == [InternalMessage(message)] - await reconnector.close() + await reconnector.close(flush=False) async def test_auto_seq_no( self, default_driver, default_settings, get_stream_writer @@ -399,7 +402,7 @@ async def 
test_auto_seq_no( [PublicMessage(seqno=last_seq_no + 3, data="123")] ) - await reconnector.close() + await reconnector.close(flush=False) async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector): await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")]) @@ -412,7 +415,7 @@ async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector): await reconnector.write_with_ack([PublicMessage(seqno=11, data="123")]) - await reconnector.close() + await reconnector.close(flush=False) @freezegun.freeze_time("2022-01-13 20:50:00", tz_offset=0) async def test_auto_created_at( @@ -431,7 +434,7 @@ async def test_auto_created_at( assert [ InternalMessage(PublicMessage(seqno=4, data="123", created_at=now)) ] == sent - await reconnector.close() + await reconnector.close(flush=False) @pytest.mark.asyncio From 7703dda58149ad7931627160ce4f7e1151138ee5 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Feb 2023 21:21:32 +0300 Subject: [PATCH 054/147] Fix flaky fail unit test from cause in integration tests (#169) separate run unit and integration tests --- .github/workflows/tests.yaml | 17 +++++++++++++---- test-requirements.txt | 4 ++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9c0a2392..9682a48a 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,14 +11,21 @@ jobs: runs-on: ubuntu-latest concurrency: - group: unit-${{ github.ref }}-${{ matrix.environment }}-${{ matrix.python-version }} + group: unit-${{ github.ref }}-${{ matrix.environment }}-${{ matrix.python-version }}-${{ matrix.folder }} cancel-in-progress: true strategy: + fail-fast: false max-parallel: 4 matrix: python-version: [3.8] environment: [py, py-tls, py-proto3, py-tls-proto3] + folder: [ydb, tests] + exclude: + - environment: py-tls + folder: ydb + - environment: py-tls-proto3 + folder: ydb steps: - uses: actions/checkout@v1 @@ -26,9 +33,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + + - name: Install tox run: | python -m pip install --upgrade pip pip install tox==4.2.6 - - name: Test with tox - run: tox -e ${{ matrix.environment }} + + - name: Run unit tests + run: tox -e ${{ matrix.environment }} -- ${{ matrix.folder }} diff --git a/test-requirements.txt b/test-requirements.txt index d1ca4276..c4a58290 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -11,8 +11,8 @@ docker==5.0.0 docker-compose==1.29.2 dockerpty==0.4.1 docopt==0.6.2 -grpcio==1.42.0 -grpcio-tools==1.42.0 +grpcio==1.47.0 +grpcio-tools==1.47.0 idna==3.2 importlib-metadata==4.6.1 iniconfig==1.1.1 From 143d7cedfe35fe720319cc2786a2818d0a31772c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 27 Feb 2023 11:10:33 +0300 Subject: [PATCH 055/147] fix grpc import workaround template --- ydb/_apis.py | 5 +++-- ydb/_grpc/grpcwrapper/common_utils.py | 3 +-- ydb/_grpc/grpcwrapper/ydb_topic.py | 3 +-- ydb/_grpc/grpcwrapper/ydb_topic_public_types.py | 4 ++-- ydb/_topic_common/common_test.py | 6 +++--- ydb/_topic_reader/topic_reader_asyncio_test.py | 6 +++--- ydb/aio/connection.py | 5 +++-- ydb/credentials.py | 5 +++-- ydb/export.py | 4 ++-- ydb/import_client.py | 4 ++-- ydb/scripting.py | 5 +++-- 11 files changed, 26 insertions(+), 24 deletions(-) diff --git a/ydb/_apis.py b/ydb/_apis.py index 557871f8..27bc1bbe 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- +import typing + # 
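NOTE: typing.TYPE_CHECKING is always False at runtime, so the imports under it only serve IDEs and type checkers; at runtime the fallback branch with the common protos is the one that executes.
# 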
Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4 import ( ydb_cms_v1_pb2_grpc, ydb_discovery_v1_pb2_grpc, diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 5e771051..3a2e6c2b 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -24,8 +24,7 @@ import ydb.aio # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ..v4.protos import ydb_topic_pb2, ydb_issue_message_pb2 else: from ..common.protos import ydb_topic_pb2, ydb_issue_message_pb2 diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index 8e3129f7..e6a5a8e3 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -10,8 +10,7 @@ from ... import scheme # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ..v4.protos import ydb_scheme_pb2, ydb_topic_pb2 else: from ..common.protos import ydb_scheme_pb2, ydb_topic_pb2 diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py index 43aa8449..6d922137 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py @@ -1,11 +1,11 @@ import datetime +import typing from dataclasses import dataclass, field from enum import IntEnum from typing import Optional, List, Union, Dict # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ..v4.protos import ydb_topic_pb2 else: from ..common.protos import ydb_topic_pb2 diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index ce19f4a0..445abdcf 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -1,4 +1,5 @@ import asyncio +import typing import grpc import pytest @@ -10,9 +11,8 @@ ) from .. 
import issues -# Workaround for good autocomplete in IDE and universal import at runtime -# noinspection PyUnreachableCode -if False: +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: from ydb._grpc.v4.protos import ( ydb_status_codes_pb2, ydb_topic_pb2, diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 0fae1bec..f761a315 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1,5 +1,6 @@ import asyncio import datetime +import typing from unittest import mock import pytest @@ -12,9 +13,8 @@ from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec, OffsetsRange from .._topic_common.test_helpers import StreamMock, wait_condition, wait_for_fast -# Workaround for good autocomplete in IDE and universal import at runtime -# noinspection PyUnreachableCode -if False: +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: from .._grpc.v4.protos import ydb_status_codes_pb2 else: from .._grpc.common.protos import ydb_status_codes_pb2 diff --git a/ydb/aio/connection.py b/ydb/aio/connection.py index 85c22638..fbfcfaaf 100644 --- a/ydb/aio/connection.py +++ b/ydb/aio/connection.py @@ -1,5 +1,6 @@ import logging import asyncio +import typing from typing import Any, Tuple, Callable, Iterable import collections import grpc @@ -24,8 +25,8 @@ from ydb.settings import BaseRequestSettings from ydb import issues -# Workaround for good IDE and universal runtime -if False: +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: from ydb._grpc.v4 import ydb_topic_v1_pb2_grpc else: from ydb._grpc.common import ydb_topic_v1_pb2_grpc diff --git a/ydb/credentials.py b/ydb/credentials.py index 8e22fe2a..330eefda 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- import abc +import typing + import six from . import tracing, issues, connection from . import settings as settings_impl @@ -9,8 +11,7 @@ import time # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_auth_pb2 from ._grpc.v4 import ydb_auth_v1_pb2_grpc else: diff --git a/ydb/export.py b/ydb/export.py index 419a753b..8e6b446a 100644 --- a/ydb/export.py +++ b/ydb/export.py @@ -1,12 +1,12 @@ import enum +import typing from . import _apis from . import settings_impl as s_impl # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_export_pb2 from ._grpc.v4 import ydb_export_v1_pb2_grpc else: diff --git a/ydb/import_client.py b/ydb/import_client.py index d1ccc99a..d94294ca 100644 --- a/ydb/import_client.py +++ b/ydb/import_client.py @@ -1,12 +1,12 @@ import enum +import typing from . import _apis from . 
import settings_impl as s_impl # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_import_pb2 from ._grpc.v4 import ydb_import_v1_pb2_grpc else: diff --git a/ydb/scripting.py b/ydb/scripting.py index 9fed037a..13132430 100644 --- a/ydb/scripting.py +++ b/ydb/scripting.py @@ -1,6 +1,7 @@ +import typing + # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_scripting_pb2 from ._grpc.v4 import ydb_scripting_v1_pb2_grpc else: From 6f0b65abd6d5414a1a209656932c9cd8116bea42 Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Wed, 22 Feb 2023 17:52:04 +0300 Subject: [PATCH 056/147] topic-writer: fix default flush on close parameter, renaming --- ydb/_topic_writer/topic_writer_asyncio.py | 12 ++++---- .../topic_writer_asyncio_test.py | 30 +++++++++++-------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 669507d8..69372075 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -65,13 +65,13 @@ def __del__(self): self._loop.call_soon(self.close) - async def close(self): + async def close(self, *, flush: bool = True): if self._closed: return self._closed = True - await self._reconnector.close() + await self._reconnector.close(flush) async def write_with_ack( self, @@ -109,13 +109,13 @@ async def write_with_ack_future( For wait with timeout use asyncio.wait_for. """ if isinstance(messages, PublicMessage): - futures = await self._reconnector.write_with_ack([messages]) + futures = await self._reconnector.write_with_ack_future([messages]) return futures[0] if isinstance(messages, list): for m in messages: if not isinstance(m, PublicMessage): raise NotImplementedError() - return await self._reconnector.write_with_ack(messages) + return await self._reconnector.write_with_ack_future(messages) raise NotImplementedError() async def write( @@ -185,7 +185,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): asyncio.create_task(self._connection_loop(), name="connection_loop") ] - async def close(self, flush: bool = True): + async def close(self, flush: bool): if self._closed: return @@ -223,7 +223,7 @@ async def wait_init(self) -> PublicWriterInitInfo: async def wait_stop(self) -> Exception: return await self._stop_reason - async def write_with_ack( + async def write_with_ack_future( self, messages: List[PublicMessage] ) -> List[asyncio.Future]: # todo check internal buffer limit diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 6658adbd..90799e7a 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -278,7 +278,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( seqno=2, created_at=now, ) - await reconnector.write_with_ack([message1, message2]) + await reconnector.write_with_ack_future([message1, message2]) # sent to first stream stream_writer = get_stream_writer() @@ -300,7 +300,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( assert second_sent_msg == expected_messages second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) - await reconnector.close() + await reconnector.close(flush=True) async def 
test_stop_on_unexpected_exception( self, reconnector: WriterAsyncIOReconnector, get_stream_writer @@ -320,7 +320,7 @@ class TestException(Exception): async def wait_stop(): while True: - await reconnector.write_with_ack([message]) + await reconnector.write_with_ack_future([message]) await asyncio.sleep(0.1) await asyncio.wait_for(wait_stop(), 1) @@ -363,7 +363,7 @@ async def test_write_message( data="123", seqno=3, ) - await reconnector.write_with_ack([message]) + await reconnector.write_with_ack_future([message]) sent_messages = await asyncio.wait_for(stream_writer.from_client.get(), 1) assert sent_messages == [InternalMessage(message)] @@ -382,8 +382,8 @@ async def test_auto_seq_no( reconnector = WriterAsyncIOReconnector(default_driver, settings) - await reconnector.write_with_ack([PublicMessage(data="123")]) - await reconnector.write_with_ack([PublicMessage(data="456")]) + await reconnector.write_with_ack_future([PublicMessage(data="123")]) + await reconnector.write_with_ack_future([PublicMessage(data="456")]) stream_writer = get_stream_writer() @@ -398,22 +398,26 @@ async def test_auto_seq_no( ] == sent with pytest.raises(TopicWriterError): - await reconnector.write_with_ack( + await reconnector.write_with_ack_future( [PublicMessage(seqno=last_seq_no + 3, data="123")] ) await reconnector.close(flush=False) async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector): - await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")]) + await reconnector.write_with_ack_future([PublicMessage(seqno=10, data="123")]) with pytest.raises(TopicWriterError): - await reconnector.write_with_ack([PublicMessage(seqno=9, data="123")]) + await reconnector.write_with_ack_future( + [PublicMessage(seqno=9, data="123")] + ) with pytest.raises(TopicWriterError): - await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")]) + await reconnector.write_with_ack_future( + [PublicMessage(seqno=10, data="123")] + ) - await reconnector.write_with_ack([PublicMessage(seqno=11, data="123")]) + await reconnector.write_with_ack_future([PublicMessage(seqno=11, data="123")]) await reconnector.close(flush=False) @@ -426,7 +430,7 @@ async def test_auto_created_at( settings = copy.deepcopy(default_settings) settings.auto_created_at = True reconnector = WriterAsyncIOReconnector(default_driver, settings) - await reconnector.write_with_ack([PublicMessage(seqno=4, data="123")]) + await reconnector.write_with_ack_future([PublicMessage(seqno=4, data="123")]) stream_writer = get_stream_writer() sent = await stream_writer.from_client.get() @@ -451,7 +455,7 @@ def __init__(self): self.futures = [] self.messages_writted = asyncio.Event() - async def write_with_ack(self, messages: typing.List[InternalMessage]): + async def write_with_ack_future(self, messages: typing.List[InternalMessage]): async with self.lock: futures = [asyncio.Future() for _ in messages] self.messages.extend(messages) From 190dd9f5aaf7d2f78ff5519225359616972732d2 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 27 Feb 2023 12:23:06 +0300 Subject: [PATCH 057/147] close grpc streams when stream reader/writer closed --- CHANGELOG.md | 1 + ydb/_grpc/grpcwrapper/common_utils.py | 29 +++++++++++++------ ydb/_topic_common/test_helpers.py | 16 ++++++++++ ydb/_topic_reader/topic_reader_asyncio.py | 1 + ydb/_topic_writer/topic_writer_asyncio.py | 4 +++ .../topic_writer_asyncio_test.py | 17 +++++++++++ 6 files changed, 59 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61f06737..fd306185 100644 --- 
a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* Close grpc streams while closing readers/writers * Add control plane operations for topic api: create, drop ## 3.0.1b4 ## diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 3a2e6c2b..6cfb2a9f 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -79,10 +79,10 @@ def __aiter__(self): return self async def __anext__(self): - try: - return await self._queue.get() - except asyncio.QueueEmpty: + item = await self._queue.get() + if item is None: raise StopAsyncIteration() + return item class AsyncQueueToSyncIteratorAsyncIO: @@ -100,13 +100,10 @@ def __iter__(self): return self def __next__(self): - try: - res = asyncio.run_coroutine_threadsafe( - self._queue.get(), self._loop - ).result() - return res - except asyncio.QueueEmpty: + item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result() + if item is None: raise StopIteration() + return item class SyncIteratorToAsyncIterator: @@ -133,6 +130,10 @@ async def receive(self) -> Any: def write(self, wrap_message: IToProto): ... + @abc.abstractmethod + def close(self): + ... + SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] @@ -142,11 +143,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): from_server_grpc: AsyncIterator convert_server_grpc_to_wrapper: Callable[[Any], Any] _connection_state: str + _stream_call: Optional[ + Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"] + ] def __init__(self, convert_server_grpc_to_wrapper): self.from_client_grpc = asyncio.Queue() self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper self._connection_state = "new" + self._stream_call = None async def start(self, driver: SupportedDriverType, stub, method): if asyncio.iscoroutinefunction(driver.__call__): @@ -155,6 +160,10 @@ async def start(self, driver: SupportedDriverType, stub, method): await self._start_sync_driver(driver, stub, method) self._connection_state = "started" + def close(self): + self.from_client_grpc.put_nowait(None) + self._stream_call.cancel() + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) stream_call = await driver( @@ -162,6 +171,7 @@ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): stub, method, ) + self._stream_call = stream_call self.from_server_grpc = stream_call.__aiter__() async def _start_sync_driver(self, driver: ydb.Driver, stub, method): @@ -172,6 +182,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): stub, method, ) + self._stream_call = stream_call self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__()) async def receive(self) -> Any: diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index bea6fea5..9023f759 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -8,20 +8,36 @@ class StreamMock(IGrpcWrapperAsyncIO): from_server: asyncio.Queue from_client: asyncio.Queue + _closed: bool def __init__(self): self.from_server = asyncio.Queue() self.from_client = asyncio.Queue() + self._closed = False async def receive(self) -> typing.Any: + if self._closed: + raise Exception("read from closed StreamMock") + item = await self.from_server.get() + if item is None: + raise StopAsyncIteration() if isinstance(item, Exception): raise item return item def write(self, wrap_message: IToProto): + if 
self._closed: + raise Exception("write to closed StreamMock") self.from_client.put_nowait(wrap_message) + def close(self): + if self._closed: + return + + self._closed = True + self.from_server.put_nowait(None) + async def wait_condition(f: typing.Callable[[], bool], timeout=1): start = time.monotonic() diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 95bd1008..a3f792de 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -496,6 +496,7 @@ async def close(self): self._closed = True self._set_first_error(TopicReaderStreamClosedError()) self._state_changed.set() + self._stream.close() for task in self._background_tasks: task.cancel() diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 669507d8..d8c3b4d4 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -322,6 +322,7 @@ async def _connection_loop(self): done, pending = await asyncio.wait( [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED ) + stream_writer.close() done.pop().result() except issues.Error as err: # todo log error @@ -417,6 +418,9 @@ def __init__( ): self._token_getter = token_getter + def close(self): + self._stream.close() + @staticmethod async def create( driver: SupportedDriverType, diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 6658adbd..1c96097f 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -158,20 +158,37 @@ class StreamWriterMock: from_client: asyncio.Queue from_server: asyncio.Queue + _closed: bool + def __init__(self): self.last_seqno = 0 self.from_server = asyncio.Queue() self.from_client = asyncio.Queue() + self._closed = False def write(self, messages: typing.List[InternalMessage]): + if self._closed: + raise Exception("write to closed StreamWriterMock") + self.from_client.put_nowait(messages) async def receive(self) -> StreamWriteMessage.WriteResponse: + if self._closed: + raise Exception("read from closed StreamWriterMock") + item = await self.from_server.get() if isinstance(item, Exception): raise item return item + def close(self): + if self._closed: + return + + self.from_server.put_nowait( + Exception("waited message while StreamWriterMock closed") + ) + @pytest.fixture(autouse=True) async def stream_writer_double_queue(self, monkeypatch): class DoubleQueueWriters: From 9da0026546c5c10234d565eaa578b12d687dc285 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 27 Feb 2023 18:44:10 +0300 Subject: [PATCH 058/147] add stop marker fix close path --- tests/topics/test_topic_reader.py | 1 + ydb/_grpc/grpcwrapper/common_utils.py | 12 ++++++++---- ydb/_topic_writer/topic_writer_asyncio.py | 3 +++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 6d87fc0b..ac338fbd 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -9,3 +9,4 @@ async def test_read_message( reader = driver.topic_client.topic_reader(topic_consumer, topic_path) assert await reader.receive_batch() is not None + await reader.close() diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 6cfb2a9f..1e56ad05 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -69,6 +69,9 @@ class 
UnknownGrpcMessageError(issues.Error): pass +_stop_grpc_connection_marker = object() + + class QueueToIteratorAsyncIO: __slots__ = ("_queue",) @@ -80,7 +83,7 @@ def __aiter__(self): async def __anext__(self): item = await self._queue.get() - if item is None: + if item is _stop_grpc_connection_marker: raise StopAsyncIteration() return item @@ -101,7 +104,7 @@ def __iter__(self): def __next__(self): item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result() - if item is None: + if item is _stop_grpc_connection_marker: raise StopIteration() return item @@ -161,8 +164,9 @@ async def start(self, driver: SupportedDriverType, stub, method): self._connection_state = "started" def close(self): - self.from_client_grpc.put_nowait(None) - self._stream_call.cancel() + self.from_client_grpc.put_nowait(_stop_grpc_connection_marker) + if self._stream_call: + self._stream_call.cancel() async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index d8c3b4d4..c0ef2491 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -296,6 +296,7 @@ async def _connection_loop(self): pending = [] # noinspection PyBroadException + stream_writer = None try: stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._get_token @@ -339,6 +340,8 @@ async def _connection_loop(self): self._stop(err) return finally: + if stream_writer: + stream_writer.close() if len(pending) > 0: for task in pending: task.cancel() From c0f6f7b42751f4f1af98e1ed32c734a365a9ecc4 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 27 Feb 2023 13:22:53 +0300 Subject: [PATCH 059/147] fix timeout for sync call add test for sync writer --- tests/topics/test_topic_writer.py | 44 ++++++++ ydb/_grpc/grpcwrapper/common_utils.py | 25 ++++- ydb/_topic_common/common.py | 1 + ydb/_topic_writer/topic_writer.py | 122 +--------------------- ydb/_topic_writer/topic_writer_asyncio.py | 14 +-- ydb/_topic_writer/topic_writer_sync.py | 21 ++-- ydb/topic.py | 8 +- 7 files changed, 96 insertions(+), 139 deletions(-) diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 0bdce33f..799c4d13 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -50,3 +50,47 @@ async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): ) as writer: init_info = await writer.wait_init() assert init_info.last_seqno == last_seqno + + +class TestTopicWriterSync: + def test_send_message(self, driver_sync: ydb.Driver, topic_path): + writer = driver_sync.topic_client.topic_writer( + topic_path, producer_and_message_group_id="test" + ) + writer.write(ydb.TopicWriterMessage(data="123".encode())) + writer.close() + + def test_wait_last_seqno(self, driver_sync: ydb.Driver, topic_path): + with driver_sync.topic_client.topic_writer( + topic_path, + producer_and_message_group_id="test", + auto_seqno=False, + ) as writer: + writer.write_with_ack(ydb.TopicWriterMessage(data="123".encode(), seqno=5)) + + with driver_sync.topic_client.topic_writer( + topic_path, + producer_and_message_group_id="test", + get_last_seqno=True, + ) as writer2: + init_info = writer2.wait_init() + assert init_info.last_seqno == 5 + + def test_auto_flush_on_close(self, driver_sync: ydb.Driver, topic_path): + with driver_sync.topic_client.topic_writer( + 
topic_path, + producer_and_message_group_id="test", + auto_seqno=False, + ) as writer: + last_seqno = 0 + for i in range(10): + last_seqno = i + 1 + writer.write(ydb.TopicWriterMessage(data=f"msg-{i}", seqno=last_seqno)) + + with driver_sync.topic_client.topic_writer( + topic_path, + producer_and_message_group_id="test", + get_last_seqno=True, + ) as writer: + init_info = writer.wait_init() + assert init_info.last_seqno == last_seqno diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 1e56ad05..6c624520 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -2,7 +2,9 @@ import abc import asyncio +import contextvars import datetime +import functools import typing from typing import ( Optional, @@ -118,7 +120,7 @@ def __aiter__(self): async def __anext__(self): try: - res = await asyncio.to_thread(self._sync_iterator.__next__) + res = await to_thread(self._sync_iterator.__next__) return res except StopAsyncIteration: raise StopIteration() @@ -180,7 +182,7 @@ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): async def _start_sync_driver(self, driver: ydb.Driver, stub, method): requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc) - stream_call = await asyncio.to_thread( + stream_call = await to_thread( driver, requests_iterator, stub, @@ -257,6 +259,25 @@ def callback_from_asyncio( return loop.run_in_executor(None, callback) +async def to_thread(func, /, *args, **kwargs): + """Asynchronously run function *func* in a separate thread. + + Any *args and **kwargs supplied for this function are directly passed + to *func*. Also, the current :class:`contextvars.Context` is propagated, + allowing context variables from the main thread to be accessed in the + separate thread. + + Return a coroutine that can be awaited to get the eventual result of *func*. 
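+
+    A usage sketch (``blocking_read`` is a hypothetical blocking callable):
+
+        data = await to_thread(blocking_read, "/tmp/dump.bin")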
+ + copy to_thread from 3.10 + """ + + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + def proto_duration_from_timedelta(t: Optional[datetime.timedelta]) -> ProtoDuration: if t is None: return None diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index 8dcafcb7..e325ca4b 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -4,6 +4,7 @@ from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] +TimeoutType = typing.Union[int, float] def wrap_operation(rpc_state, response_pb, driver=None): diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index cd60d00f..75858324 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -1,4 +1,3 @@ -import concurrent.futures import datetime import enum from dataclasses import dataclass @@ -12,126 +11,7 @@ from .._grpc.grpcwrapper.common_utils import IToProto -class Writer: - @property - def last_seqno(self) -> int: - raise NotImplementedError() - - def __init__(self, db: ydb.Driver): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def close(self): - pass - - MessageType = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] - - def write( - self, - message: Union[MessageType, List[MessageType]], - *args: Optional[MessageType], - timeout: [float, None] = None, - ): - """ - send one or number of messages to server. - it fast put message to internal buffer, without wait message result - return None - - message will send independent of wait/no wait result - - timeout - time for waiting for put message into internal queue. - if 0 or negative - non block calls - if None or not set - infinite wait - It will raise TimeoutError() exception if it can't put message to internal queue by limits during timeout. - """ - raise NotImplementedError() - - def async_write_with_ack( - self, - message: Union[MessageType, List[MessageType]], - *args: Optional[MessageType], - timeout: [float, None] = None, - ) -> concurrent.futures.Future: - """ - send one or number of messages to server. - return feature, which can be waited for check send result: ack/duplicate/error - - Usually it is fast method, but can wait if internal buffer is full. - - timeout - time for waiting for put message into internal queue. - The method can be blocked up to timeout seconds before return future. - - if 0 or negative - non block calls - if None or not set - infinite wait - It will raise TimeoutError() exception if it can't put message to internal queue by limits during timeout. - """ - raise NotImplementedError() - - def write_with_ack( - self, - message: Union[MessageType, List[MessageType]], - *args: Optional[MessageType], - buffer_timeout: [float, None] = None, - ) -> Union["MessageWriteStatus", List["MessageWriteStatus"]]: - """ - IT IS SLOWLY WAY. IT IS BAD CHOISE IN MOST CASES. - It is recommended to use write with optionally flush or async_write_with_ack and receive acks by wait future. - - send one or number of messages to server. - blocked until receive server ack for the message/messages. - - message will send independent of wait/no wait result - - buffer_timeout - time for send message to server and receive ack. 
- if 0 or negative - non block calls - if None or not set - infinite wait - It will raise TimeoutError() exception if it isn't receive ack in timeout - """ - raise NotImplementedError() - - def async_flush(self): - """ - Force send all messages from internal buffer and wait acks from server for all - messages. - - flush starts of flush process, and return Future for wait result. - messages will be flushed independent of future waiting. - """ - raise NotImplementedError() - - def flush(self, timeout: Optional[float] = None) -> concurrent.futures.Future: - """ - Force send all messages from internal buffer and wait acks from server for all - messages. - - timeout - time for waiting for send all messages and receive server ack. - if 0 or negative - non block calls - if None or not set - infinite wait - It will raise TimeoutError() exception if it isn't receive ack in timeout - """ - raise NotImplementedError() - - def async_wait_init(self) -> concurrent.futures.Future: - """ - Return feature, which done when underling connection established - """ - raise NotImplementedError() - - def wait_init(self, timeout: Optional[float] = None): - """ - Wait until underling connection established - - timeout - time for waiting for send all messages and receive server ack. - if 0 or negative - non block calls - if None or not set - infinite wait - It will raise TimeoutError() exception if it isn't receive ack in timeout - """ - raise NotImplementedError() +MessageType = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] @dataclass diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 1b175e8e..4724ab2f 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -7,7 +7,6 @@ from .topic_writer import ( PublicWriterSettings, WriterSettings, - Writer, PublicMessage, PublicWriterInitInfo, InternalMessage, @@ -15,6 +14,7 @@ TopicWriterError, messages_to_proto_requests, PublicWriteResultTypes, + MessageType, ) from .. import ( _apis, @@ -75,8 +75,8 @@ async def close(self, *, flush: bool = True): async def write_with_ack( self, - messages: Union[Writer.MessageType, List[Writer.MessageType]], - *args: Optional[Writer.MessageType], + messages: Union[MessageType, List[MessageType]], + *args: Optional[MessageType], ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]: """ IT IS SLOWLY WAY. IT IS BAD CHOISE IN MOST CASES. @@ -97,8 +97,8 @@ async def write_with_ack( async def write_with_ack_future( self, - messages: Union[Writer.MessageType, List[Writer.MessageType]], - *args: Optional[Writer.MessageType], + messages: Union[MessageType, List[MessageType]], + *args: Optional[MessageType], ) -> Union[asyncio.Future, List[asyncio.Future]]: """ send one or number of messages to server. @@ -120,8 +120,8 @@ async def write_with_ack_future( async def write( self, - messages: Union[Writer.MessageType, List[Writer.MessageType]], - *args: Optional[Writer.MessageType], + messages: Union[MessageType, List[MessageType]], + *args: Optional[MessageType], ): """ send one or number of messages to server. 
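
A hedged usage sketch of these methods after the rename (assumes `driver` is a connected ydb.aio.Driver and the topic already exists; the path and producer id are illustrative):

    writer = driver.topic_client.topic_writer(
        "/local/topic", producer_and_message_group_id="demo"
    )
    await writer.write(ydb.TopicWriterMessage("queued, not awaited"))  # buffer only
    result = await writer.write_with_ack(ydb.TopicWriterMessage("acked"))  # waits for server ack
    await writer.close()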
diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index d8c66213..2c58325e 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -11,11 +11,12 @@ TopicWriterError, PublicWriterInitInfo, PublicMessage, - Writer, PublicWriteResult, + MessageType, ) from .topic_writer_asyncio import WriterAsyncIO +from .._topic_common.common import TimeoutType _shared_event_loop_lock = threading.Lock() _shared_event_loop = None # type: Optional[asyncio.AbstractEventLoop] @@ -78,6 +79,12 @@ async def create_async_writer(): create_async_writer(), self._loop ).result() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + def _call(self, coro, *args, **kwargs): if self._closed: raise TopicWriterError("writer is closed") @@ -87,7 +94,7 @@ def _call(self, coro, *args, **kwargs): def _call_sync(self, coro: Coroutine, timeout, *args, **kwargs): f = self._call(coro, *args, **kwargs) try: - return f.result() + return f.result(timeout=timeout) except TimeoutError: f.cancel() raise @@ -111,7 +118,7 @@ def flush(self, timeout=None): def async_wait_init(self) -> Future[PublicWriterInitInfo]: return self._call(self._async_writer.wait_init()) - def wait_init(self, timeout) -> PublicWriterInitInfo: + def wait_init(self, timeout: Optional[TimeoutType] = None) -> PublicWriterInitInfo: return self._call_sync(self._async_writer.wait_init(), timeout) def write( @@ -124,15 +131,15 @@ def write( def async_write_with_ack( self, - messages: Union[Writer.MessageType, List[Writer.MessageType]], - *args: Optional[Writer.MessageType], + messages: Union[MessageType, List[MessageType]], + *args: Optional[MessageType], ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: return self._call(self._async_writer.write_with_ack(messages, *args)) def write_with_ack( self, - messages: Union[Writer.MessageType, List[Writer.MessageType]], - *args: Optional[Writer.MessageType], + messages: Union[MessageType, List[MessageType]], + *args: Optional[MessageType], timeout: Union[float, None] = None, ) -> Union[PublicWriteResult, List[PublicWriteResult]]: return self._call_sync( diff --git a/ydb/topic.py b/ydb/topic.py index 42d283bc..593c0378 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -19,12 +19,13 @@ ) from ._topic_writer.topic_writer import ( # noqa: F401 - Writer as TopicWriter, PublicWriterSettings as TopicWriterSettings, PublicMessage as TopicWriterMessage, RetryPolicy as TopicWriterRetryPolicy, ) +from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter + from ._topic_common.common import ( wrap_operation as _wrap_operation, create_result_wrapper as _create_result_wrapper, @@ -278,7 +279,10 @@ def topic_writer( get_last_seqno: bool = False, retry_policy: Union["TopicWriterRetryPolicy", None] = None, ) -> TopicWriter: - raise NotImplementedError() + args = locals() + del args["self"] + settings = TopicWriterSettings(**args) + return TopicWriter(self._driver, settings) class TopicClientSettings: From 5166d367a2494b92ca5b6821b30e6ff278ece359 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 31 Jan 2023 12:30:16 +0300 Subject: [PATCH 060/147] use anonymous credentials by default --- CHANGELOG.md | 1 + tests/conftest.py | 7 ------- ydb/aio/driver.py | 38 +------------------------------------- ydb/driver.py | 13 +++++++++---- 4 files changed, 11 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd306185..61ffe36f 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md
@@ -1,3 +1,4 @@
+* Use anonymous credentials by default instead of IAM metadata (use ydb.driver.credentials_from_env_variables to build credentials from environment variables)
 * Close grpc streams while closing readers/writers
 * Add control plane operations for topic api: create, drop
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 674422ec..26580e8d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,4 @@
 import os
-from unittest import mock
 
 import pytest
 import ydb
@@ -7,12 +6,6 @@
 from ydb import issues
 
 
-@pytest.fixture(autouse=True, scope="session")
-def mock_settings_env_vars():
-    with mock.patch.dict(os.environ, {"YDB_ANONYMOUS_CREDENTIALS": "1"}):
-        yield
-
-
 @pytest.fixture(scope="module")
 def docker_compose_file(pytestconfig):
     return os.path.join(str(pytestconfig.rootdir), "docker-compose.yml")
diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py
index 319cb14c..1aa3ad27 100644
--- a/ydb/aio/driver.py
+++ b/ydb/aio/driver.py
@@ -1,43 +1,7 @@
-import os
-
 from . import pool, scheme, table
 import ydb
 from .. import _utilities
-from ydb.driver import get_config
-
-
-def default_credentials(credentials=None):
-    if credentials is not None:
-        return credentials
-
-    service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS")
-    if service_account_key_file is not None:
-        from .iam import ServiceAccountCredentials
-
-        return ServiceAccountCredentials.from_file(service_account_key_file)
-
-    anonymous_credetials = os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1"
-    if anonymous_credetials:
-        return ydb.credentials.AnonymousCredentials()
-
-    metadata_credentials = os.getenv("YDB_METADATA_CREDENTIALS", "0") == "1"
-    if metadata_credentials:
-        from .iam import MetadataUrlCredentials
-
-        return MetadataUrlCredentials()
-
-    access_token = os.getenv("YDB_ACCESS_TOKEN_CREDENTIALS")
-    if access_token is not None:
-        return ydb.credentials.AccessTokenCredentials(access_token)
-
-    # (legacy instantiation)
-    creds = ydb.auth_helpers.construct_credentials_from_environ()
-    if creds is not None:
-        return creds
-
-    from .iam import MetadataUrlCredentials
-
-    return MetadataUrlCredentials()
+from ydb.driver import get_config, default_credentials
 
 
 class DriverConfig(ydb.DriverConfig):
diff --git a/ydb/driver.py b/ydb/driver.py
index 9aa6aab3..0ef723fe 100644
--- a/ydb/driver.py
+++ b/ydb/driver.py
@@ -23,10 +23,17 @@ class RPCCompression:
 def default_credentials(credentials=None, tracer=None):
     tracer = tracer if tracer is not None else tracing.Tracer(None)
     with tracer.trace("Driver.default_credentials") as ctx:
-        if credentials is not None:
+        if credentials is None:
+            ctx.trace({"credentials.anonymous": True})
+            return credentials_impl.AnonymousCredentials()
+        else:
             ctx.trace({"credentials.prepared": True})
             return credentials
 
+
+def credentials_from_env_variables(tracer=None):
+    tracer = tracer if tracer is not None else tracing.Tracer(None)
+    with tracer.trace("Driver.credentials_from_env_variables") as ctx:
         service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS")
         if service_account_key_file is not None:
             ctx.trace({"credentials.service_account_key_file": True})
@@ -51,9 +58,7 @@ def default_credentials(credentials=None, tracer=None):
             ctx.trace({"credentials.access_token": True})
             return credentials_impl.AuthTokenCredentials(access_token)
 
-    import ydb.iam
-
-    return ydb.iam.MetadataUrlCredentials(tracer=tracer)
+    return default_credentials(None, tracer)
 
 
 class DriverConfig(object):
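
With anonymous credentials as the new default, code that previously relied on the implicit environment lookup now opts in explicitly. A sketch in the style of the SDK examples (endpoint and database values are illustrative):

    driver_config = ydb.DriverConfig(
        "grpc://localhost:2136",
        "/local",
        credentials=ydb.driver.credentials_from_env_variables(),
    )
    driver = ydb.Driver(driver_config)
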
From 0d57d0a502b48f9b097e3834e6a67418f3aa8d7a Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Wed, 22 Feb 2023 21:36:11 +0300
Subject: [PATCH 061/147] use common auth algorithm by default; remove
 deprecated auth algorithm
---
 .../cloud_function/utils.py                   | 2 +-
 .../secondary_indexes_builtin.py              | 2 +-
 examples/time-series-serverless/database.py   | 2 +-
 examples/ttl/ttl.py                           | 2 +-
 examples/ttl_readtable/ttl.py                 | 2 +-
 tests/aio/test_connection_pool.py             | 10 -----
 tests/conftest.py                             | 4 --
 ydb/auth_helpers.py                           | 39 -------------------
 8 files changed, 5 insertions(+), 58 deletions(-)

diff --git a/examples/reservations-bot-demo/cloud_function/utils.py b/examples/reservations-bot-demo/cloud_function/utils.py
index 52aa56f9..10e0d53d 100644
--- a/examples/reservations-bot-demo/cloud_function/utils.py
+++ b/examples/reservations-bot-demo/cloud_function/utils.py
@@ -7,7 +7,7 @@ def make_driver_config(endpoint, database, path):
     return ydb.DriverConfig(
         endpoint,
         database,
-        credentials=ydb.construct_credentials_from_environ(),
+        credentials=ydb.credentials_from_env_variables(),
         root_certificates=ydb.load_ydb_root_certificate(),
     )
 
diff --git a/examples/secondary_indexes_builtin/secondary_indexes_builtin.py b/examples/secondary_indexes_builtin/secondary_indexes_builtin.py
index 03949970..877ed84c 100644
--- a/examples/secondary_indexes_builtin/secondary_indexes_builtin.py
+++ b/examples/secondary_indexes_builtin/secondary_indexes_builtin.py
@@ -245,7 +245,7 @@ def callee(session):
 
 def run(endpoint, database, path):
     driver_config = ydb.DriverConfig(
-        endpoint, database, credentials=ydb.construct_credentials_from_environ()
+        endpoint, database, credentials=ydb.credentials_from_env_variables()
     )
     with ydb.Driver(driver_config) as driver:
         try:
diff --git a/examples/time-series-serverless/database.py b/examples/time-series-serverless/database.py
index df3e2f6a..8e8d9181 100644
--- a/examples/time-series-serverless/database.py
+++ b/examples/time-series-serverless/database.py
@@ -14,7 +14,7 @@ def create_driver(self) -> ydb.Driver:
         driver_config = ydb.DriverConfig(
             self.config.endpoint,
             self.config.database,
-            credentials=ydb.construct_credentials_from_environ(),
+            credentials=ydb.credentials_from_env_variables(),
             root_certificates=ydb.load_ydb_root_certificate(),
         )
 
diff --git a/examples/ttl/ttl.py b/examples/ttl/ttl.py
index bdbe0dec..7a4c5cc4 100644
--- a/examples/ttl/ttl.py
+++ b/examples/ttl/ttl.py
@@ -307,7 +307,7 @@ def _run(driver, database, path):
 
 def run(endpoint, database, path):
     driver_config = ydb.DriverConfig(
-        endpoint, database, credentials=ydb.construct_credentials_from_environ()
+        endpoint, database, credentials=ydb.credentials_from_env_variables()
     )
     with ydb.Driver(driver_config) as driver:
         try:
diff --git a/examples/ttl_readtable/ttl.py b/examples/ttl_readtable/ttl.py
index f5fff741..f9ca50aa 100644
--- a/examples/ttl_readtable/ttl.py
+++ b/examples/ttl_readtable/ttl.py
@@ -288,7 +288,7 @@ def _run(driver, session_pool, database, path):
 
 def run(endpoint, database, path):
     driver_config = ydb.DriverConfig(
-        endpoint, database, credentials=ydb.construct_credentials_from_environ()
+        endpoint, database, credentials=ydb.credentials_from_env_variables()
     )
     with ydb.Driver(driver_config) as driver:
         try:
diff --git a/tests/aio/test_connection_pool.py b/tests/aio/test_connection_pool.py
index 221f1e39..12882f38 100644
--- a/tests/aio/test_connection_pool.py
+++ b/tests/aio/test_connection_pool.py
@@ -11,8 +11,6 @@ async def test_async_call(endpoint, database):
     driver_config = ydb.DriverConfig(
         endpoint,
         database,
-        credentials=ydb.construct_credentials_from_environ(),
-        
root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) @@ -26,8 +24,6 @@ async def test_gzip_compression(endpoint, database): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), compression=ydb.RPCCompression.Gzip, ) @@ -53,8 +49,6 @@ async def test_session(endpoint, database): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) @@ -98,8 +92,6 @@ async def test_raises_when_disconnect(endpoint, database, docker_project): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) @@ -124,8 +116,6 @@ async def test_disconnect_by_call(endpoint, database, docker_project): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) diff --git a/tests/conftest.py b/tests/conftest.py index 26580e8d..09c02977 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,8 +80,6 @@ async def driver(endpoint, database, event_loop): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = ydb.aio.Driver(driver_config=driver_config) @@ -97,8 +95,6 @@ async def driver_sync(endpoint, database, event_loop): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = ydb.Driver(driver_config=driver_config) diff --git a/ydb/auth_helpers.py b/ydb/auth_helpers.py index 5d889555..043e3fd4 100644 --- a/ydb/auth_helpers.py +++ b/ydb/auth_helpers.py @@ -16,42 +16,3 @@ def load_ydb_root_certificate(): return read_bytes(path) return None - -def construct_credentials_from_environ(tracer=None): - tracer = tracer if tracer is not None else tracing.Tracer(None) - warnings.warn( - "using construct_credentials_from_environ method for credentials instantiation is deprecated and will be " - "removed in the future major releases. Please instantialize credentials by default or provide correct credentials " - "instance to the Driver." 
- ) - - # dynamically import required authentication libraries - if ( - os.getenv("USE_METADATA_CREDENTIALS") is not None - and int(os.getenv("USE_METADATA_CREDENTIALS")) == 1 - ): - import ydb.iam - - tracing.trace(tracer, {"credentials.metadata": True}) - return ydb.iam.MetadataUrlCredentials() - - if os.getenv("YDB_TOKEN") is not None: - tracing.trace(tracer, {"credentials.access_token": True}) - return credentials.AuthTokenCredentials(os.getenv("YDB_TOKEN")) - - if os.getenv("SA_KEY_FILE") is not None: - - import ydb.iam - - tracing.trace(tracer, {"credentials.sa_key_file": True}) - root_certificates_file = os.getenv("SSL_ROOT_CERTIFICATES_FILE", None) - iam_channel_credentials = {} - if root_certificates_file is not None: - iam_channel_credentials = { - "root_certificates": read_bytes(root_certificates_file) - } - return ydb.iam.ServiceAccountCredentials.from_file( - os.getenv("SA_KEY_FILE"), - iam_channel_credentials=iam_channel_credentials, - iam_endpoint=os.getenv("IAM_ENDPOINT", "iam.api.cloud.yandex.net:443"), - ) From 18a7636d4b0c57d70ecb9fea40fb1fdd64de5841 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Feb 2023 21:54:59 +0300 Subject: [PATCH 062/147] fix linter --- ydb/auth_helpers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ydb/auth_helpers.py b/ydb/auth_helpers.py index 043e3fd4..6399c3cf 100644 --- a/ydb/auth_helpers.py +++ b/ydb/auth_helpers.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- import os -from . import credentials, tracing -import warnings - def read_bytes(f): with open(f, "rb") as fr: @@ -15,4 +12,3 @@ def load_ydb_root_certificate(): if path is not None and os.path.exists(path): return read_bytes(path) return None - From 36d26a46fd95b22609c775266f36f522a0c8ec15 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 28 Feb 2023 17:09:11 +0300 Subject: [PATCH 063/147] requirements to setup.py from requirements.txt add six as dependency --- CHANGELOG.md | 1 + MANIFEST.in | 1 + requirements.txt | 6 ++++-- setup.py | 14 ++++++++------ 4 files changed, 14 insertions(+), 8 deletions(-) create mode 100644 MANIFEST.in diff --git a/CHANGELOG.md b/CHANGELOG.md index 61ffe36f..2244ae9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ * Use anonymous credentials by default instead of iam metadata (use ydb.driver.credentials_from_env_variables for creds by env var) * Close grpc streams while closing readers/writers * Add control plane operations for topic api: create, drop +* Add six package to requirements ## 3.0.1b4 ## * Initial implementation of topic reader diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..f9bd1455 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include requirements.txt diff --git a/requirements.txt b/requirements.txt index 57470a28..da37d9fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ +aiohttp>=3.7.4,<4.0.0 +enum-compat>=0.0.1 grpcio>=1.42.0 packaging -protobuf>=3.13.0,<5.0.0 -aiohttp>=3.7.4,<4.0.0 +protobuf>3.13.0,<5.0.0 +six<2 \ No newline at end of file diff --git a/setup.py b/setup.py index 389711ba..6c6ab5bd 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,13 @@ with open("README.md", "r") as r: long_description = r.read() +with open("requirements.txt") as r: + requirements = [] + for line in r.readlines(): + line = line.strip() + if line != "": + requirements.append(line) + setuptools.setup( name="ydb", version="3.0.1b4", # AUTOVERSION @@ -23,12 +30,7 @@ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", ], - install_requires=( - 
"protobuf>=3.13.0", - "grpcio>=1.5.0", - "enum-compat>=0.0.1", - "packaging" - ), + install_requires=requirements, # requirements.txt options={"bdist_wheel": {"universal": True}}, extras_require={ "yc": ["yandexcloud", ], From 37e930f2092efa7e092b069f12c291433b282f28 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 28 Feb 2023 18:46:49 +0300 Subject: [PATCH 064/147] remove six package --- .../cloud_function/requirements.txt | 1 - kikimr/public/sdk/python/client/__init__.py | 3 +-- test-requirements.txt | 1 - ydb/_sp_impl.py | 2 +- ydb/_utilities.py | 12 +++++------- ydb/aio/credentials.py | 6 +++--- ydb/aio/iam.py | 4 +--- ydb/convert.py | 19 ++++++------------- ydb/credentials.py | 10 +++------- ydb/dbapi/cursor.py | 4 +--- ydb/default_pem.py | 7 +------ ydb/driver.py | 6 +----- ydb/iam/auth.py | 7 ++----- ydb/issues.py | 2 +- ydb/pool.py | 8 +++----- ydb/scheme.py | 4 +--- ydb/table.py | 15 +++++---------- ydb/types.py | 7 +------ 18 files changed, 36 insertions(+), 82 deletions(-) diff --git a/examples/reservations-bot-demo/cloud_function/requirements.txt b/examples/reservations-bot-demo/cloud_function/requirements.txt index 9b243fbb..56665263 100644 --- a/examples/reservations-bot-demo/cloud_function/requirements.txt +++ b/examples/reservations-bot-demo/cloud_function/requirements.txt @@ -11,7 +11,6 @@ pycparser==2.20 pydantic==1.6.2 PyJWT==2.4.0 requests==2.24.0 -six==1.15.0 urllib3==1.26.5 yandexcloud==0.48.0 ydb==0.0.41 diff --git a/kikimr/public/sdk/python/client/__init__.py b/kikimr/public/sdk/python/client/__init__.py index e2050d47..157c103e 100644 --- a/kikimr/public/sdk/python/client/__init__.py +++ b/kikimr/public/sdk/python/client/__init__.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- from ydb import * # noqa import sys -import six import warnings warnings.warn("module kikimr.public.sdk.python.client is deprecated. please use ydb instead") -for name, module in six.iteritems(sys.modules.copy()): +for name, module in sys.modules.copy().items(): if not name.startswith("ydb"): continue diff --git a/test-requirements.txt b/test-requirements.txt index c4a58290..273f0fb6 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -32,7 +32,6 @@ pytest-docker-compose==3.2.1 python-dotenv==0.18.0 PyYAML==5.4.1 requests==2.26.0 -six==1.16.0 texttable==1.6.4 toml==0.10.2 typing-extensions==3.10.0.0 diff --git a/ydb/_sp_impl.py b/ydb/_sp_impl.py index a8529d73..5974a301 100644 --- a/ydb/_sp_impl.py +++ b/ydb/_sp_impl.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import collections from concurrent import futures -from six.moves import queue +import queue import time import threading from . import settings, issues, _utilities, tracing diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 32419b1b..544b154c 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- -import six import codecs from concurrent import futures import functools import hashlib import collections +import urllib.parse from . 
import ydb_version try: @@ -55,8 +55,8 @@ def parse_connection_string(connection_string): # default is grpcs cs = _grpcs_protocol + cs - p = six.moves.urllib.parse.urlparse(connection_string) - b = six.moves.urllib.parse.parse_qs(p.query) + p = urllib.parse.urlparse(connection_string) + b = urllib.parse.parse_qs(p.query) database = b.get("database", []) assert len(database) > 0 @@ -77,11 +77,9 @@ def decorator(*args, **kwargs): def get_query_hash(yql_text): try: - return hashlib.sha256( - six.text_type(yql_text, "utf-8").encode("utf-8") - ).hexdigest() + return hashlib.sha256(str(yql_text, "utf-8").encode("utf-8")).hexdigest() except TypeError: - return hashlib.sha256(six.text_type(yql_text).encode("utf-8")).hexdigest() + return hashlib.sha256(str(yql_text).encode("utf-8")).hexdigest() class LRUCache(object): diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index e9840440..3a4e64a9 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -3,7 +3,6 @@ import abc import asyncio import logging -import six from ydb import issues, credentials logger = logging.getLogger(__name__) @@ -55,8 +54,9 @@ def submit(self, callback): asyncio.ensure_future(self._wrapped_execution(callback)) -@six.add_metaclass(abc.ABCMeta) -class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials): +class AbstractExpiringTokenCredentials( + credentials.AbstractExpiringTokenCredentials, abc.ABC +): def __init__(self): super(AbstractExpiringTokenCredentials, self).__init__() self._tp = _AtMostOneExecution() diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index 51b650f2..7440f5c4 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -3,7 +3,6 @@ import abc import logging -import six from ydb.iam import auth from .credentials import AbstractExpiringTokenCredentials @@ -24,8 +23,7 @@ aiohttp = None -@six.add_metaclass(abc.ABCMeta) -class TokenServiceCredentials(AbstractExpiringTokenCredentials): +class TokenServiceCredentials(AbstractExpiringTokenCredentials, abc.ABC): def __init__(self, iam_endpoint=None, iam_channel_credentials=None): super(TokenServiceCredentials, self).__init__() assert ( diff --git a/ydb/convert.py b/ydb/convert.py index 02c7de0c..70bc638e 100644 --- a/ydb/convert.py +++ b/ydb/convert.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import decimal from google.protobuf import struct_pb2 -import six from . 
import issues, types, _apis @@ -81,9 +80,7 @@ def _pb_to_list(type_pb, value_pb, table_client_settings): def _pb_to_tuple(type_pb, value_pb, table_client_settings): return tuple( _to_native_value(item_type, item_value, table_client_settings) - for item_type, item_value in six.moves.zip( - type_pb.tuple_type.elements, value_pb.items - ) + for item_type, item_value in zip(type_pb.tuple_type.elements, value_pb.items) ) @@ -106,7 +103,7 @@ class _Struct(_DotDict): def _pb_to_struct(type_pb, value_pb, table_client_settings): result = _Struct() - for member, item in six.moves.zip(type_pb.struct_type.members, value_pb.items): + for member, item in zip(type_pb.struct_type.members, value_pb.items): result[member.name] = _to_native_value(member.type, item, table_client_settings) return result @@ -201,9 +198,7 @@ def _list_to_pb(type_pb, value): def _tuple_to_pb(type_pb, value): value_pb = _apis.ydb_value.Value() - for element_type, element_value in six.moves.zip( - type_pb.tuple_type.elements, value - ): + for element_type, element_value in zip(type_pb.tuple_type.elements, value): value_item_proto = value_pb.items.add() value_item_proto.MergeFrom(_from_native_value(element_type, element_value)) return value_pb @@ -289,7 +284,7 @@ def parameters_to_pb(parameters_types, parameters_values): return {} param_values_pb = {} - for name, type_pb in six.iteritems(parameters_types): + for name, type_pb in parameters_types.items(): result = _apis.ydb_value.TypedValue() ttype = type_pb if isinstance(type_pb, types.AbstractTypeBuilder): @@ -330,7 +325,7 @@ def from_message(cls, message, table_client_settings=None): for row_proto in message.rows: row = _Row(message.columns) - for column, value, column_info in six.moves.zip( + for column, value, column_info in zip( message.columns, row_proto.items, column_parsers ): v_type = value.WhichOneof("value") @@ -398,9 +393,7 @@ def __init__(self, columns, proto_row, table_client_settings, parsers): super(_LazyRow, self).__init__() self._columns = columns self._table_client_settings = table_client_settings - for i, (column, row_item) in enumerate( - six.moves.zip(self._columns, proto_row.items) - ): + for i, (column, row_item) in enumerate(zip(self._columns, proto_row.items)): super(_LazyRow, self).__setitem__( column.name, _LazyRowItem(row_item, column.type, table_client_settings, parsers[i]), diff --git a/ydb/credentials.py b/ydb/credentials.py index 330eefda..abc3dfee 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -2,7 +2,6 @@ import abc import typing -import six from . import tracing, issues, connection from . 
import settings as settings_impl import threading @@ -23,15 +22,13 @@ logger = logging.getLogger(__name__) -@six.add_metaclass(abc.ABCMeta) -class AbstractCredentials(object): +class AbstractCredentials(abc.ABC): """ An abstract class that provides auth metadata """ -@six.add_metaclass(abc.ABCMeta) -class Credentials(object): +class Credentials(abc.ABC): def __init__(self, tracer=None): self.tracer = tracer if tracer is not None else tracing.Tracer(None) @@ -88,8 +85,7 @@ def cleanup(self): self._can_schedule = True -@six.add_metaclass(abc.ABCMeta) -class AbstractExpiringTokenCredentials(Credentials): +class AbstractExpiringTokenCredentials(Credentials, abc.ABC): def __init__(self, tracer=None): super(AbstractExpiringTokenCredentials, self).__init__(tracer) self._expires_in = 0 diff --git a/ydb/dbapi/cursor.py b/ydb/dbapi/cursor.py index 71175abf..eb26dc2b 100644 --- a/ydb/dbapi/cursor.py +++ b/ydb/dbapi/cursor.py @@ -5,8 +5,6 @@ import itertools import logging -import six - import ydb from .errors import DatabaseError @@ -42,7 +40,7 @@ def render_datetime(value): def render(value): if value is None: return "NULL" - if isinstance(value, six.string_types): + if isinstance(value, str): return render_str(value) if isinstance(value, datetime.datetime): return render_datetime(value) diff --git a/ydb/default_pem.py b/ydb/default_pem.py index 92286ba2..b8272efd 100644 --- a/ydb/default_pem.py +++ b/ydb/default_pem.py @@ -1,6 +1,3 @@ -import six - - data = """ # Issuer: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA # Subject: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA @@ -4686,6 +4683,4 @@ def load_default_pem(): global data - if six.PY3: - return data.encode("utf-8") - return data + return data.encode("utf-8") diff --git a/ydb/driver.py b/ydb/driver.py index 0ef723fe..e3274687 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -1,15 +1,11 @@ # -*- coding: utf-8 -*- from . import credentials as credentials_impl, table, scheme, pool from . import tracing -import six import os import grpc from . import _utilities -if six.PY2: - Any = None -else: - from typing import Any # noqa +from typing import Any # noqa class RPCCompression: diff --git a/ydb/iam/auth.py b/ydb/iam/auth.py index 06b07e91..c9badbbf 100644 --- a/ydb/iam/auth.py +++ b/ydb/iam/auth.py @@ -3,7 +3,6 @@ import grpc import time import abc -import six from datetime import datetime import json import os @@ -45,8 +44,7 @@ def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout): ) -@six.add_metaclass(abc.ABCMeta) -class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials): +class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials, abc.ABC): def __init__(self, iam_endpoint=None, iam_channel_credentials=None, tracer=None): super(TokenServiceCredentials, self).__init__(tracer) assert ( @@ -84,8 +82,7 @@ def _make_token_request(self): return {"access_token": response.iam_token, "expires_in": expires_in} -@six.add_metaclass(abc.ABCMeta) -class BaseJWTCredentials(object): +class BaseJWTCredentials(abc.ABC): def __init__(self, account_id, access_key_id, private_key): self._account_id = account_id self._jwt_expiration_timeout = 60.0 * 60 diff --git a/ydb/issues.py b/ydb/issues.py index 727aff1b..5a57f4d2 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from google.protobuf import text_format import enum -from six.moves import queue +import queue from . 
import _apis diff --git a/ydb/pool.py b/ydb/pool.py index 73cd1681..007aa94d 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -1,14 +1,13 @@ # -*- coding: utf-8 -*- +import abc import threading import logging from concurrent import futures import collections import random -import six - from . import connection as connection_impl, issues, resolver, _utilities, tracing -from abc import abstractmethod, ABCMeta +from abc import abstractmethod from .connection import Connection @@ -296,8 +295,7 @@ def run(self): self.logger.info("Successfully terminated discovery process") -@six.add_metaclass(ABCMeta) -class IConnectionPool: +class IConnectionPool(abc.ABC): @abstractmethod def __init__(self, driver_config): """ diff --git a/ydb/scheme.py b/ydb/scheme.py index 88eca78c..96a6c25f 100644 --- a/ydb/scheme.py +++ b/ydb/scheme.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import abc import enum -import six from abc import abstractmethod from . import issues, operation, settings as settings_impl, _apis @@ -347,8 +346,7 @@ def _wrap_describe_path_response(rpc_state, response): return _wrap_scheme_entry(message.self) -@six.add_metaclass(abc.ABCMeta) -class ISchemeClient: +class ISchemeClient(abc.ABC): @abstractmethod def __init__(self, driver): pass diff --git a/ydb/table.py b/ydb/table.py index 6f0a4868..d60f138a 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -7,7 +7,6 @@ import random import enum -import six from . import ( issues, convert, @@ -768,8 +767,7 @@ def with_compaction_policy(self, compaction_policy): return self -@six.add_metaclass(abc.ABCMeta) -class AbstractTransactionModeBuilder(object): +class AbstractTransactionModeBuilder(abc.ABC): @property @abc.abstractmethod def name(self): @@ -947,7 +945,7 @@ def retry_operation_impl(callee, retry_settings=None, *args, **kwargs): retry_settings = RetrySettings() if retry_settings is None else retry_settings status = None - for attempt in six.moves.range(retry_settings.max_retries + 1): + for attempt in range(retry_settings.max_retries + 1): try: result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) yield result @@ -1095,8 +1093,7 @@ def _scan_query_request_factory(query, parameters=None, settings=None): ) -@six.add_metaclass(abc.ABCMeta) -class ISession: +class ISession(abc.ABC): @abstractmethod def __init__(self, driver, table_client_settings): pass @@ -1258,8 +1255,7 @@ def describe_table(self, path, settings=None): pass -@six.add_metaclass(abc.ABCMeta) -class ITableClient: +class ITableClient(abc.ABC): def __init__(self, driver, table_client_settings=None): pass @@ -2082,8 +2078,7 @@ def async_describe_table(self, path, settings=None): ) -@six.add_metaclass(abc.ABCMeta) -class ITxContext: +class ITxContext(abc.ABC): @abstractmethod def __init__(self, driver, session_state, session, tx_mode=None): """ diff --git a/ydb/types.py b/ydb/types.py index 598a9013..ce5c4673 100644 --- a/ydb/types.py +++ b/ydb/types.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import abc import enum -import six import json from . 
import _utilities, _apis from datetime import date, datetime, timedelta @@ -12,12 +11,8 @@ _SECONDS_IN_DAY = 60 * 60 * 24 _EPOCH = datetime(1970, 1, 1) -if six.PY3: - _from_bytes = None -else: - def _from_bytes(x, table_client_settings): - return _utilities.from_bytes(x) +_from_bytes = None def _from_date_number(x, table_client_settings): From b03f958c92e50dfe4efe3533012bc0a1c9cef659 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 28 Feb 2023 18:48:30 +0300 Subject: [PATCH 065/147] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2244ae9d..465891a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* Remove six package from code and dependencies (remove support python2) * Use anonymous credentials by default instead of iam metadata (use ydb.driver.credentials_from_env_variables for creds by env var) * Close grpc streams while closing readers/writers * Add control plane operations for topic api: create, drop From 67dc0dd7df03bddf5a1b26653a1f2f7334b1d3f2 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Feb 2023 21:21:32 +0300 Subject: [PATCH 066/147] add sync reader --- tests/topics/test_topic_reader.py | 12 +- ydb/_topic_common/common.py | 37 +++++ ydb/_topic_reader/topic_reader.py | 127 ---------------- ydb/_topic_reader/topic_reader_sync.py | 195 +++++++++++++++++++++++++ ydb/_topic_writer/topic_writer_sync.py | 53 ++----- ydb/topic.py | 42 +++--- 6 files changed, 274 insertions(+), 192 deletions(-) create mode 100644 ydb/_topic_reader/topic_reader_sync.py diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index ac338fbd..8107ac16 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -2,7 +2,7 @@ @pytest.mark.asyncio -class TestTopicWriterAsyncIO: +class TestTopicReaderAsyncIO: async def test_read_message( self, driver, topic_path, topic_with_messages, topic_consumer ): @@ -10,3 +10,13 @@ async def test_read_message( assert await reader.receive_batch() is not None await reader.close() + + +class TestTopicReaderSync: + def test_read_message( + self, driver_sync, topic_path, topic_with_messages, topic_consumer + ): + reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + + assert reader.receive_batch() is not None + reader.close() diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index e325ca4b..5bb10654 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -1,4 +1,8 @@ +import asyncio +import concurrent.futures +import threading import typing +from typing import Optional from .. 
import operation, issues from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType @@ -24,3 +28,36 @@ def wrapper(rpc_state, response_pb, driver=None): return result_type.from_proto(msg) return wrapper + + +_shared_event_loop_lock = threading.Lock() +_shared_event_loop = None # type: Optional[asyncio.AbstractEventLoop] + + +def _get_shared_event_loop() -> asyncio.AbstractEventLoop: + global _shared_event_loop + + if _shared_event_loop is not None: + return _shared_event_loop + + with _shared_event_loop_lock: + if _shared_event_loop is not None: + return _shared_event_loop + + event_loop_set_done = concurrent.futures.Future() + + def start_event_loop(): + event_loop = asyncio.new_event_loop() + event_loop_set_done.set_result(event_loop) + asyncio.set_event_loop(event_loop) + event_loop.run_forever() + + t = threading.Thread( + target=start_event_loop, + name="Common ydb topic writer event loop", + daemon=True, + ) + t.start() + + _shared_event_loop = event_loop_set_done.result() + return _shared_event_loop diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 7bb6d934..4c9e63e1 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -1,4 +1,3 @@ -import concurrent.futures import enum import datetime from dataclasses import dataclass @@ -6,11 +5,9 @@ Union, Optional, List, - Iterable, ) from ..table import RetrySettings -from .datatypes import ICommittable, PublicBatch, PublicMessage from .._topic_common.common import TokenGetterFuncType from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange @@ -26,130 +23,6 @@ def __init__(self, path, *, partitions: Union[None, int, List[int]] = None): self.partitions = partitions -class Reader(object): - def async_sessions_stat(self) -> concurrent.futures.Future: - """ - Receive stat from the server, return feature. - """ - raise NotImplementedError() - - async def sessions_stat(self) -> List["SessionStat"]: - """ - Receive stat from the server - - use async_sessions_stat for set explicit wait timeout - """ - raise NotImplementedError() - - def messages( - self, *, timeout: Union[float, None] = None - ) -> Iterable[PublicMessage]: - """ - todo? - - Block until receive new message - It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. - - if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. - """ - raise NotImplementedError() - - def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage: - """ - Block until receive new message - It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. - - if no new message in timeout seconds (default - infinite): raise TimeoutError() - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. - """ - raise NotImplementedError() - - def async_wait_message(self) -> concurrent.futures.Future: - """ - Return future, which will completed when the reader has least one message in queue. - If reader already has message - future will return completed. - - Possible situation when receive signal about message available, but no messages when try to receive a message. - If message expired between send event and try to retrieve message (for example connection broken). 
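
The _get_shared_event_loop helper added above in ydb/_topic_common/common.py is the heart of the sync-over-async design: a single daemon thread runs one event loop, and every synchronous wrapper schedules its coroutines onto it. A self-contained sketch of that delegation pattern follows; SyncFacade and fetch_value_async are illustrative names, not SDK API:

    import asyncio

    async def fetch_value_async() -> int:
        # Stand-in for an async SDK call such as receive_batch().
        await asyncio.sleep(0.1)
        return 42

    class SyncFacade:
        def __init__(self, loop: asyncio.AbstractEventLoop):
            # The loop is expected to run in a background thread,
            # e.g. the one started by _get_shared_event_loop().
            self._loop = loop

        def fetch_value(self, timeout: float = 5.0) -> int:
            # Schedule the coroutine on the background loop and block the
            # calling thread until it finishes or the deadline expires.
            future = asyncio.run_coroutine_threadsafe(fetch_value_async(), self._loop)
            return future.result(timeout)
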
- """ - raise NotImplementedError() - - def batches( - self, - *, - max_messages: Union[int, None] = None, - max_bytes: Union[int, None] = None, - timeout: Union[float, None] = None, - ) -> Iterable[PublicBatch]: - """ - Block until receive new batch. - It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. - - if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. - """ - raise NotImplementedError() - - def receive_batch( - self, - *, - max_messages: Union[int, None] = None, - max_bytes: Union[int, None], - timeout: Union[float, None] = None, - ) -> Union[PublicBatch, None]: - """ - Get one messages batch from reader - It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. - - if no new message in timeout seconds (default - infinite): raise TimeoutError() - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. - """ - raise NotImplementedError() - - def commit(self, mess: ICommittable): - """ - Put commit message to internal buffer. - - For the method no way check the commit result - (for example if lost connection - commits will not re-send and committed messages will receive again) - """ - raise NotImplementedError() - - def commit_with_ack( - self, mess: ICommittable - ) -> Union["CommitResult", List["CommitResult"]]: - """ - write commit message to a buffer and wait ack from the server. - - if receive in timeout seconds (default - infinite): raise TimeoutError() - """ - raise NotImplementedError() - - def async_commit_with_ack( - self, mess: ICommittable - ) -> Union["CommitResult", List["CommitResult"]]: - """ - write commit message to a buffer and return Future for wait result. - """ - raise NotImplementedError() - - def async_flush(self) -> concurrent.futures.Future: - """ - force send all commit messages from internal buffers to server and return Future for wait server acks. - """ - raise NotImplementedError() - - def flush(self): - """ - force send all commit messages from internal buffers to server and wait acks for all of them. 
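
The interface being deleted here still pins down the commit contract the new sync reader keeps: commit() only enqueues offsets, so processing is at-least-once unless the ack-waiting variants are used. A sketch of the consume loop that contract implies; reader and handle are assumed to be supplied by the caller:

    def consume(reader, handle) -> None:
        # At-least-once processing: commit() does not wait for the server
        # ack, so messages can be redelivered after a broken connection.
        while True:
            batch = reader.receive_batch()
            for message in batch.messages:
                handle(message.data)
            reader.commit(batch)
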
- """ - raise NotImplementedError() - - def close(self): - raise NotImplementedError() - - @dataclass class PublicReaderSettings: consumer: str diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py new file mode 100644 index 00000000..b30b547a --- /dev/null +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -0,0 +1,195 @@ +import asyncio +import concurrent.futures +import typing +from typing import List, Union, Iterable, Optional, Coroutine + +from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType +from ydb._topic_common.common import _get_shared_event_loop +from ydb._topic_reader.datatypes import PublicMessage, PublicBatch, ICommittable +from ydb._topic_reader.topic_reader import ( + PublicReaderSettings, + SessionStat, + CommitResult, +) +from ydb._topic_reader.topic_reader_asyncio import ( + PublicAsyncIOReader, + TopicReaderClosedError, +) + + +class TopicReaderSync: + _loop: asyncio.AbstractEventLoop + _async_reader: PublicAsyncIOReader + _closed: bool + + def __init__( + self, + driver: SupportedDriverType, + settings: PublicReaderSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + ): + self._closed = False + + if eventloop: + self._loop = eventloop + else: + self._loop = _get_shared_event_loop() + + async def create_reader(): + return PublicAsyncIOReader(driver, settings) + + self._async_reader = asyncio.run_coroutine_threadsafe( + create_reader(), self._loop + ).result() + + def __del__(self): + self.close() + + def _call(self, coro): + if self._closed: + raise TopicReaderClosedError() + + return asyncio.run_coroutine_threadsafe(coro, self._loop) + + def _call_sync(self, coro: Coroutine, timeout): + f = self._call(coro) + try: + return f.result(timeout) + except TimeoutError: + f.cancel() + raise + + def async_sessions_stat(self) -> concurrent.futures.Future: + """ + Receive stat from the server, return feature. + """ + raise NotImplementedError() + + async def sessions_stat(self) -> List[SessionStat]: + """ + Receive stat from the server + + use async_sessions_stat for set explicit wait timeout + """ + raise NotImplementedError() + + def messages( + self, *, timeout: Union[float, None] = None + ) -> Iterable[PublicMessage]: + """ + todo? + + Block until receive new message + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage: + """ + Block until receive new message + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def async_wait_message(self) -> concurrent.futures.Future: + """ + Return future, which will completed when the reader has least one message in queue. + If reader already has message - future will return completed. + + Possible situation when receive signal about message available, but no messages when try to receive a message. + If message expired between send event and try to retrieve message (for example connection broken). 
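
The _call_sync helper above turns a missed deadline into cancellation of the scheduled coroutine, so nothing keeps running on the shared loop with no caller waiting. The same pattern in isolation; slow_call is a stand-in, and concurrent.futures.TimeoutError is spelled explicitly because it only became an alias of the builtin TimeoutError in Python 3.11:

    import asyncio
    import concurrent.futures

    async def slow_call() -> str:
        await asyncio.sleep(10)  # stand-in for a long server round trip
        return "done"

    def call_with_deadline(loop: asyncio.AbstractEventLoop, timeout: float) -> str:
        future = asyncio.run_coroutine_threadsafe(slow_call(), loop)
        try:
            return future.result(timeout)
        except concurrent.futures.TimeoutError:
            # Cancel the coroutine on the background loop; otherwise it
            # would keep running with nobody waiting for the result.
            future.cancel()
            raise
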
+ """ + raise NotImplementedError() + + def batches( + self, + *, + max_messages: Union[int, None] = None, + max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Iterable[PublicBatch]: + """ + Block until receive new batch. + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + raise NotImplementedError() + + def receive_batch( + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Union[PublicBatch, None]: + """ + Get one messages batch from reader + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + """ + return self._call_sync( + self._async_reader.receive_batch( + max_messages=max_messages, max_bytes=max_bytes + ), + timeout, + ) + + def commit(self, mess: ICommittable): + """ + Put commit message to internal buffer. + + For the method no way check the commit result + (for example if lost connection - commits will not re-send and committed messages will receive again) + """ + self._call_sync(self._async_reader.commit(mess), None) + + def commit_with_ack( + self, mess: ICommittable + ) -> Union[CommitResult, List[CommitResult]]: + """ + write commit message to a buffer and wait ack from the server. + + if receive in timeout seconds (default - infinite): raise TimeoutError() + """ + raise NotImplementedError() + + def async_commit_with_ack( + self, mess: ICommittable + ) -> Union[CommitResult, List[CommitResult]]: + """ + write commit message to a buffer and return Future for wait result. + """ + raise NotImplementedError() + + def async_flush(self) -> concurrent.futures.Future: + """ + force send all commit messages from internal buffers to server and return Future for wait server acks. + """ + raise NotImplementedError() + + def flush(self): + """ + force send all commit messages from internal buffers to server and wait acks for all of them. 
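
Taken together, the wrapper gives a blocking read path over the asyncio reader. A short usage sketch; the connection string, topic and consumer are placeholders, and topic_reader is the method name at this point in the series (a later patch renames it to reader):

    import ydb

    driver = ydb.Driver(connection_string="grpc://localhost:2135?database=/local")
    driver.wait(timeout=5)

    reader = driver.topic_client.topic_reader("consumer", "/local/topic")
    try:
        # Blocks until a batch is available (or the deadline passes).
        batch = reader.receive_batch(max_messages=100, timeout=5)
        for message in batch.messages:
            print(message.data)
        reader.commit(batch)
    finally:
        reader.close()
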
+ """ + raise NotImplementedError() + + def close(self): + if self._closed: + return + self._closed = True + + # for no call self._call_sync on closed object + asyncio.run_coroutine_threadsafe( + self._async_reader.close(), self._loop + ).result() diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 2c58325e..419edcba 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -2,7 +2,6 @@ import asyncio from concurrent.futures import Future -import threading from typing import Union, List, Optional, Coroutine from .._grpc.grpcwrapper.common_utils import SupportedDriverType @@ -16,40 +15,7 @@ ) from .topic_writer_asyncio import WriterAsyncIO -from .._topic_common.common import TimeoutType - -_shared_event_loop_lock = threading.Lock() -_shared_event_loop = None # type: Optional[asyncio.AbstractEventLoop] - - -def _get_shared_event_loop() -> asyncio.AbstractEventLoop: - global _shared_event_loop - - if _shared_event_loop is not None: - return _shared_event_loop - - with _shared_event_loop_lock: - if _shared_event_loop is not None: - return _shared_event_loop - - event_loop_set_done = Future() - - def start_event_loop(): - global _shared_event_loop - _shared_event_loop = asyncio.new_event_loop() - event_loop_set_done.set_result(None) - asyncio.set_event_loop(_shared_event_loop) - _shared_event_loop.run_forever() - - t = threading.Thread( - target=start_event_loop, - name="Common ydb topic writer event loop", - daemon=True, - ) - t.start() - - event_loop_set_done.result() - return _shared_event_loop +from .._topic_common.common import _get_shared_event_loop, TimeoutType class WriterSync: @@ -62,7 +28,7 @@ def __init__( driver: SupportedDriverType, settings: PublicWriterSettings, *, - eventloop: asyncio.AbstractEventLoop = None, + eventloop: Optional[asyncio.AbstractEventLoop] = None, ): self._closed = False @@ -85,26 +51,29 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def _call(self, coro, *args, **kwargs): + def _call(self, coro): if self._closed: raise TopicWriterError("writer is closed") return asyncio.run_coroutine_threadsafe(coro, self._loop) - def _call_sync(self, coro: Coroutine, timeout, *args, **kwargs): - f = self._call(coro, *args, **kwargs) + def _call_sync(self, coro: Coroutine, timeout): + f = self._call(coro) try: - return f.result(timeout=timeout) + return f.result(timeout) except TimeoutError: f.cancel() raise - def close(self): + def close(self, flush: bool = True): if self._closed: return + self._closed = True + + # for no call self._call_sync on closed object asyncio.run_coroutine_threadsafe( - self._async_writer.close(), self._loop + self._async_writer.close(flush=flush), self._loop ).result() def async_flush(self) -> Future: diff --git a/ydb/topic.py b/ydb/topic.py index 593c0378..9378d100 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -7,13 +7,10 @@ from ._topic_reader.topic_reader import ( PublicReaderSettings as TopicReaderSettings, - Reader as TopicReader, - Selector as TopicSelector, - Events as TopicReaderEvents, - RetryPolicy as TopicReaderRetryPolicy, - StubEvent as TopicReaderStubEvent, ) +from ._topic_reader.topic_reader_sync import TopicReaderSync as TopicReader + from ._topic_reader.topic_reader_asyncio import ( PublicAsyncIOReader as TopicReaderAsyncIO, ) @@ -241,26 +238,27 @@ def drop_topic(self, path: str): def topic_reader( self, - topic: Union[str, TopicSelector, List[Union[str, TopicSelector]]], consumer: str, - 
commit_batch_time: Union[float, None] = 0.1, - commit_batch_count: Union[int, None] = 1000, + topic: str, buffer_size_bytes: int = 50 * 1024 * 1024, - sync_commit: bool = False, # reader.commit(...) will wait commit ack from server - on_commit: Callable[["TopicReaderStubEvent"], None] = None, - on_get_partition_start_offset: Callable[ - ["TopicReaderEvents.OnPartitionGetStartOffsetRequest"], - "TopicReaderEvents.OnPartitionGetStartOffsetResponse", - ] = None, - on_init_partition: Callable[["StubEvent"], None] = None, - on_shutdown_partition: Callable[["StubEvent"], None] = None, - decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, - deserializer: Union[Callable[[bytes], Any], None] = None, - one_attempt_connection_timeout: Union[float, None] = 1, - connection_timeout: Union[float, None] = None, - retry_policy: Union["TopicReaderRetryPolicy", None] = None, + # on_commit: Callable[["Events.OnCommit"], None] = None + # on_get_partition_start_offset: Callable[ + # ["Events.OnPartitionGetStartOffsetRequest"], + # "Events.OnPartitionGetStartOffsetResponse", + # ] = None + # on_partition_session_start: Callable[["StubEvent"], None] = None + # on_partition_session_stop: Callable[["StubEvent"], None] = None + # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? + # decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None + # deserializer: Union[Callable[[bytes], Any], None] = None + # one_attempt_connection_timeout: Union[float, None] = 1 + # connection_timeout: Union[float, None] = None + # retry_policy: Union["RetryPolicy", None] = None ) -> TopicReader: - raise NotImplementedError() + args = locals() + del args["self"] + settings = TopicReaderSettings(**args) + return TopicReader(self._driver, settings) def topic_writer( self, From 2249c6f109a287a62d5f6250a678529af6f79aca Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 28 Feb 2023 18:16:57 +0300 Subject: [PATCH 067/147] typo --- ydb/_topic_common/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index 5bb10654..f2d6ca9b 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -54,7 +54,7 @@ def start_event_loop(): t = threading.Thread( target=start_event_loop, - name="Common ydb topic writer event loop", + name="Common ydb topic event loop", daemon=True, ) t.start() From c8fc34b281b426b52b928d44f787929a265425ba Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 1 Mar 2023 12:39:35 +0300 Subject: [PATCH 068/147] small fixes --- ydb/aio/credentials.py | 4 +--- ydb/aio/iam.py | 2 +- ydb/credentials.py | 2 +- ydb/iam/auth.py | 2 +- ydb/types.py | 8 ++------ 5 files changed, 6 insertions(+), 12 deletions(-) diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index 3a4e64a9..93868b27 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -54,9 +54,7 @@ def submit(self, callback): asyncio.ensure_future(self._wrapped_execution(callback)) -class AbstractExpiringTokenCredentials( - credentials.AbstractExpiringTokenCredentials, abc.ABC -): +class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self): super(AbstractExpiringTokenCredentials, self).__init__() self._tp = _AtMostOneExecution() diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index 7440f5c4..b56c0660 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -23,7 +23,7 @@ aiohttp = None -class TokenServiceCredentials(AbstractExpiringTokenCredentials, abc.ABC): +class 
TokenServiceCredentials(AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None): super(TokenServiceCredentials, self).__init__() assert ( diff --git a/ydb/credentials.py b/ydb/credentials.py index abc3dfee..13b45b20 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -85,7 +85,7 @@ def cleanup(self): self._can_schedule = True -class AbstractExpiringTokenCredentials(Credentials, abc.ABC): +class AbstractExpiringTokenCredentials(Credentials): def __init__(self, tracer=None): super(AbstractExpiringTokenCredentials, self).__init__(tracer) self._expires_in = 0 diff --git a/ydb/iam/auth.py b/ydb/iam/auth.py index c9badbbf..50d98b4b 100644 --- a/ydb/iam/auth.py +++ b/ydb/iam/auth.py @@ -44,7 +44,7 @@ def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout): ) -class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials, abc.ABC): +class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None, tracer=None): super(TokenServiceCredentials, self).__init__(tracer) assert ( diff --git a/ydb/types.py b/ydb/types.py index ce5c4673..49192fce 100644 --- a/ydb/types.py +++ b/ydb/types.py @@ -12,8 +12,6 @@ _SECONDS_IN_DAY = 60 * 60 * 24 _EPOCH = datetime(1970, 1, 1) -_from_bytes = None - def _from_date_number(x, table_client_settings): if ( @@ -39,8 +37,6 @@ def _from_json(x, table_client_settings): and table_client_settings._native_json_in_result_sets ): return json.loads(x) - if _from_bytes is not None: - return _from_bytes(x, table_client_settings) return x @@ -109,7 +105,7 @@ class PrimitiveType(enum.Enum): Float = _apis.primitive_types.FLOAT, "float_value" String = _apis.primitive_types.STRING, "bytes_value" - Utf8 = _apis.primitive_types.UTF8, "text_value", _from_bytes + Utf8 = _apis.primitive_types.UTF8, "text_value" Yson = _apis.primitive_types.YSON, "bytes_value" Json = _apis.primitive_types.JSON, "text_value", _from_json @@ -138,7 +134,7 @@ class PrimitiveType(enum.Enum): _to_interval, ) - DyNumber = _apis.primitive_types.DYNUMBER, "text_value", _from_bytes + DyNumber = _apis.primitive_types.DYNUMBER, "text_value" def __init__(self, idn, proto_field, to_obj=None, from_obj=None): self._idn_ = idn From 99a4fae285256c91232f42f86ddf0f3d8d4c1064 Mon Sep 17 00:00:00 2001 From: robot Date: Wed, 1 Mar 2023 21:11:43 +0000 Subject: [PATCH 069/147] Release: 3.0.1b5 --- CHANGELOG.md | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 465891a2..339be6a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b5 ## * Remove six package from code and dependencies (remove support python2) * Use anonymous credentials by default instead of iam metadata (use ydb.driver.credentials_from_env_variables for creds by env var) * Close grpc streams while closing readers/writers diff --git a/setup.py b/setup.py index 6c6ab5bd..ae7df0fd 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="3.0.1b4", # AUTOVERSION + version="3.0.1b5", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", From 28211869549f3f3f98169503eabe8f5f86ac2715 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 10:20:30 +0300 Subject: [PATCH 070/147] Removed tornado support --- ydb/tornado/__init__.py | 4 -- ydb/tornado/tornado_helpers.py | 67 ---------------------------------- 2 files changed, 71 
deletions(-)
 delete mode 100644 ydb/tornado/__init__.py
 delete mode 100644 ydb/tornado/tornado_helpers.py

diff --git a/ydb/tornado/__init__.py b/ydb/tornado/__init__.py
deleted file mode 100644
index eea7bdc4..00000000
--- a/ydb/tornado/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-try:
-    from .tornado_helpers import *  # noqa
-except ImportError:
-    pass
diff --git a/ydb/tornado/tornado_helpers.py b/ydb/tornado/tornado_helpers.py
deleted file mode 100644
index d63d4aed..00000000
--- a/ydb/tornado/tornado_helpers.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# -*- coding: utf-8 -*-
-try:
-    import tornado.concurrent
-    import tornado.ioloop
-    import tornado.gen
-    from tornado.concurrent import TracebackFuture
-except ImportError:
-    tornado = None
-
-from ydb.table import retry_operation_impl, YdbRetryOperationSleepOpt
-
-
-def as_tornado_future(foreign_future, timeout=None):
-    """
-    Return tornado.concurrent.Future wrapped python concurrent.future (foreign_future).
-    Cancel execution original future after given timeout
-    """
-    result_future = tornado.concurrent.Future()
-    timeout_timer = set()
-    if timeout:
-
-        def on_timeout():
-            timeout_timer.clear()
-            foreign_future.cancel()
-
-        timeout_timer.add(
-            tornado.ioloop.IOLoop.current().call_later(timeout, on_timeout)
-        )
-
-    def copy_to_result_future(foreign_future):
-        try:
-            to_remove = timeout_timer.pop()
-            tornado.ioloop.IOLoop.current().remove_timeout(to_remove)
-        except KeyError:
-            pass
-
-        if result_future.done():
-            return
-
-        if (
-            isinstance(foreign_future, TracebackFuture)
-            and isinstance(result_future, TracebackFuture)
-            and result_future.exc_info() is not None
-        ):
-            result_future.set_exc_info(foreign_future.exc_info())
-        elif foreign_future.cancelled():
-            result_future.set_exception(tornado.gen.TimeoutError())
-        elif foreign_future.exception() is not None:
-            result_future.set_exception(foreign_future.exception())
-        else:
-            result_future.set_result(foreign_future.result())
-
-    tornado.ioloop.IOLoop.current().add_future(foreign_future, copy_to_result_future)
-    return result_future
-
-
-async def retry_operation(callee, retry_settings=None, *args, **kwargs):
-    opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs)
-
-    for next_opt in opt_generator:
-        if isinstance(next_opt, YdbRetryOperationSleepOpt):
-            await tornado.gen.sleep(next_opt.timeout)
-        else:
-            try:
-                return await next_opt.result
-            except Exception as e:
-                next_opt.set_exception(e)

From fc7ea8600cae28819ecc8e1f80d614db5da09e42 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Thu, 2 Mar 2023 10:46:59 +0300
Subject: [PATCH 071/147] public method renames

---
 CHANGELOG.md | 2 ++
 examples/topic/reader_async_example.py | 14 ++++++--------
 examples/topic/reader_example.py | 12 ++++++------
 examples/topic/writer_async_example.py | 8 ++++----
 examples/topic/writer_example.py | 10 +++++-----
 tests/conftest.py | 2 +-
 tests/topics/test_control_plane.py | 5 ++---
 tests/topics/test_topic_reader.py | 4 ++--
 tests/topics/test_topic_writer.py | 20 ++++++++++----------
 ydb/topic.py | 14 ++++++------
 10 files changed, 46 insertions(+), 45 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 339be6a4..760269a3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,5 @@
+* BREAKING CHANGES: change names of public methods in topic client
+
 ## 3.0.1b5 ##
 * Remove six package from code and dependencies (remove support python2)
 * Use anonymous credentials by default instead of iam metadata (use ydb.driver.credentials_from_env_variables for creds by env var)
diff --git 
a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py index e702903f..96142921 100644 --- a/examples/topic/reader_async_example.py +++ b/examples/topic/reader_async_example.py @@ -10,14 +10,12 @@ async def connect(): connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials(), ) - reader = ydb.TopicClientAsyncIO(db).topic_reader( - "/local/topic", consumer="consumer" - ) + reader = ydb.TopicClientAsyncIO(db).reader("/local/topic", consumer="consumer") return reader async def create_reader_and_close_with_context_manager(db: ydb.aio.Driver): - with ydb.TopicClientAsyncIO(db).topic_reader( + with ydb.TopicClientAsyncIO(db).reader( "/database/topic/path", consumer="consumer" ) as reader: async for message in reader.messages(): @@ -91,7 +89,7 @@ async def get_one_batch_from_external_loop_async(reader: ydb.TopicReaderAsyncIO) async def auto_deserialize_message(db: ydb.aio.Driver): # async, batch work similar to this - async with ydb.TopicClientAsyncIO(db).topic_reader( + async with ydb.TopicClientAsyncIO(db).reader( "/database/topic/path", consumer="asd", deserializer=json.loads ) as reader: async for message in reader.messages(): @@ -133,7 +131,7 @@ def process_batch(batch): async def connect_and_read_few_topics(db: ydb.aio.Driver): - with ydb.TopicClientAsyncIO(db).topic_reader( + with ydb.TopicClientAsyncIO(db).reader( [ "/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3), @@ -156,7 +154,7 @@ def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: print(event.topic) print(event.offset) - async with ydb.TopicClientAsyncIO(db).topic_reader( + async with ydb.TopicClientAsyncIO(db).reader( "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit ) as reader: async for message in reader.messages(): @@ -173,7 +171,7 @@ async def on_get_partition_start_offset( resp.start_offset = 123 return resp - async with ydb.TopicClient(db).topic_reader( + async with ydb.TopicClient(db).reader( "/local/test", consumer="consumer", on_get_partition_start_offset=on_get_partition_start_offset, diff --git a/examples/topic/reader_example.py b/examples/topic/reader_example.py index 7cea2a35..183c51d6 100644 --- a/examples/topic/reader_example.py +++ b/examples/topic/reader_example.py @@ -9,12 +9,12 @@ def connect(): connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials(), ) - reader = ydb.TopicClient(db).topic_reader("/local/topic", consumer="consumer") + reader = ydb.TopicClient(db).reader("/local/topic", consumer="consumer") return reader def create_reader_and_close_with_context_manager(db: ydb.Driver): - with ydb.TopicClient(db).topic_reader( + with ydb.TopicClient(db).reader( "/database/topic/path", consumer="consumer", buffer_size_bytes=123 ) as reader: for message in reader: @@ -81,7 +81,7 @@ def get_one_batch_from_external_loop(reader: ydb.TopicReader): def auto_deserialize_message(db: ydb.Driver): # async, batch work similar to this - reader = ydb.TopicClient(db).topic_reader( + reader = ydb.TopicClient(db).reader( "/database/topic/path", consumer="asd", deserializer=json.loads ) for message in reader.messages(): @@ -123,7 +123,7 @@ def process_batch(batch): def connect_and_read_few_topics(db: ydb.Driver): - with ydb.TopicClient(db).topic_reader( + with ydb.TopicClient(db).reader( [ "/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3), @@ -146,7 +146,7 @@ def on_commit(event: 
ydb.TopicReaderEvents.OnCommit) -> None: print(event.topic) print(event.offset) - with ydb.TopicClient(db).topic_reader( + with ydb.TopicClient(db).reader( "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit ) as reader: for message in reader: @@ -164,7 +164,7 @@ def on_get_partition_start_offset( resp.start_offset = 123 return resp - with ydb.TopicClient(db).topic_reader( + with ydb.TopicClient(db).reader( "/local/test", consumer="consumer", on_get_partition_start_offset=on_get_partition_start_offset, diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index 29c79b08..30fbecab 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -8,7 +8,7 @@ async def create_writer(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).topic_writer( + async with ydb.TopicClientAsyncIO(db).writer( "/database/topic/path", producer_and_message_group_id="producer-id", ) as writer: @@ -16,7 +16,7 @@ async def create_writer(db: ydb.aio.Driver): async def connect_and_wait(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).topic_writer( + async with ydb.TopicClientAsyncIO(db).writer( "/database/topic/path", producer_and_message_group_id="producer-id", ) as writer: @@ -24,7 +24,7 @@ async def connect_and_wait(db: ydb.aio.Driver): async def connect_without_context_manager(db: ydb.aio.Driver): - writer = ydb.TopicClientAsyncIO(db).topic_writer( + writer = ydb.TopicClientAsyncIO(db).writer( "/database/topic/path", producer_and_message_group_id="producer-id", ) @@ -81,7 +81,7 @@ async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): async def send_json_message(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).topic_writer( + async with ydb.TopicClientAsyncIO(db).writer( "/database/path/topic", serializer=json.dumps ) as writer: writer.write({"a": 123}) diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 27387e11..10f7db21 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -13,7 +13,7 @@ async def connect(): connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials(), ) - writer = ydb.TopicClientAsyncIO(db).topic_writer( + writer = ydb.TopicClientAsyncIO(db).writer( "/local/topic", producer_and_message_group_id="producer-id", ) @@ -21,7 +21,7 @@ async def connect(): def create_writer(db: ydb.Driver): - with ydb.TopicClient(db).topic_writer( + with ydb.TopicClient(db).writer( "/database/topic/path", producer_and_message_group_id="producer-id", ) as writer: @@ -29,7 +29,7 @@ def create_writer(db: ydb.Driver): def connect_and_wait(db: ydb.Driver): - with ydb.TopicClient(db).topic_writer( + with ydb.TopicClient(db).writer( "/database/topic/path", producer_and_message_group_id="producer-id", ) as writer: @@ -37,7 +37,7 @@ def connect_and_wait(db: ydb.Driver): def connect_without_context_manager(db: ydb.Driver): - writer = ydb.TopicClient(db).topic_writer( + writer = ydb.TopicClient(db).writer( "/database/topic/path", producer_and_message_group_id="producer-id", ) @@ -98,7 +98,7 @@ def send_messages_with_wait_ack(writer: ydb.TopicWriter): def send_json_message(db: ydb.Driver): - with ydb.TopicClient(db).topic_writer( + with ydb.TopicClient(db).writer( "/database/path/topic", serializer=json.dumps ) as writer: writer.write({"a": 123}) diff --git a/tests/conftest.py b/tests/conftest.py index 09c02977..e94a83dc 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -131,7 +131,7 @@ async def topic_path(driver, topic_consumer, database) -> str: @pytest.fixture() @pytest.mark.asyncio() async def topic_with_messages(driver, topic_path): - writer = driver.topic_client.topic_writer( + writer = driver.topic_client.writer( topic_path, producer_and_message_group_id="fixture-producer-id" ) await writer.write_with_ack( diff --git a/tests/topics/test_control_plane.py b/tests/topics/test_control_plane.py index 8e1d6f23..2446ddcf 100644 --- a/tests/topics/test_control_plane.py +++ b/tests/topics/test_control_plane.py @@ -27,7 +27,7 @@ async def test_drop_topic(self, driver, topic_path): await client.drop_topic(topic_path) async def test_describe_topic(self, driver, topic_path: str, topic_consumer): - res = await driver.topic_client.describe(topic_path) + res = await driver.topic_client.describe_topic(topic_path) assert res.self.name == os.path.basename(topic_path) @@ -61,8 +61,7 @@ def test_drop_topic(self, driver_sync, topic_path): client.drop_topic(topic_path) def test_describe_topic(self, driver_sync, topic_path: str, topic_consumer): - res = driver_sync.topic_client.describe(topic_path) - res.partition_count_limit + res = driver_sync.topic_client.describe_topic(topic_path) assert res.self.name == os.path.basename(topic_path) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 8107ac16..21675eb2 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -6,7 +6,7 @@ class TestTopicReaderAsyncIO: async def test_read_message( self, driver, topic_path, topic_with_messages, topic_consumer ): - reader = driver.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_consumer, topic_path) assert await reader.receive_batch() is not None await reader.close() @@ -16,7 +16,7 @@ class TestTopicReaderSync: def test_read_message( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): - reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_consumer, topic_path) assert reader.receive_batch() is not None reader.close() diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 799c4d13..5ae976b8 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -6,14 +6,14 @@ @pytest.mark.asyncio class TestTopicWriterAsyncIO: async def test_send_message(self, driver: ydb.aio.Driver, topic_path): - writer = driver.topic_client.topic_writer( + writer = driver.topic_client.writer( topic_path, producer_and_message_group_id="test" ) await writer.write(ydb.TopicWriterMessage(data="123".encode())) await writer.close() async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): - async with driver.topic_client.topic_writer( + async with driver.topic_client.writer( topic_path, producer_and_message_group_id="test", auto_seqno=False, @@ -22,7 +22,7 @@ async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): ydb.TopicWriterMessage(data="123".encode(), seqno=5) ) - async with driver.topic_client.topic_writer( + async with driver.topic_client.writer( topic_path, producer_and_message_group_id="test", get_last_seqno=True, @@ -31,7 +31,7 @@ async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): assert init_info.last_seqno == 5 async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): - async with driver.topic_client.topic_writer( + async with 
driver.topic_client.writer( topic_path, producer_and_message_group_id="test", auto_seqno=False, @@ -43,7 +43,7 @@ async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): ydb.TopicWriterMessage(data=f"msg-{i}", seqno=last_seqno) ) - async with driver.topic_client.topic_writer( + async with driver.topic_client.writer( topic_path, producer_and_message_group_id="test", get_last_seqno=True, @@ -54,21 +54,21 @@ async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): class TestTopicWriterSync: def test_send_message(self, driver_sync: ydb.Driver, topic_path): - writer = driver_sync.topic_client.topic_writer( + writer = driver_sync.topic_client.writer( topic_path, producer_and_message_group_id="test" ) writer.write(ydb.TopicWriterMessage(data="123".encode())) writer.close() def test_wait_last_seqno(self, driver_sync: ydb.Driver, topic_path): - with driver_sync.topic_client.topic_writer( + with driver_sync.topic_client.writer( topic_path, producer_and_message_group_id="test", auto_seqno=False, ) as writer: writer.write_with_ack(ydb.TopicWriterMessage(data="123".encode(), seqno=5)) - with driver_sync.topic_client.topic_writer( + with driver_sync.topic_client.writer( topic_path, producer_and_message_group_id="test", get_last_seqno=True, @@ -77,7 +77,7 @@ def test_wait_last_seqno(self, driver_sync: ydb.Driver, topic_path): assert init_info.last_seqno == 5 def test_auto_flush_on_close(self, driver_sync: ydb.Driver, topic_path): - with driver_sync.topic_client.topic_writer( + with driver_sync.topic_client.writer( topic_path, producer_and_message_group_id="test", auto_seqno=False, @@ -87,7 +87,7 @@ def test_auto_flush_on_close(self, driver_sync: ydb.Driver, topic_path): last_seqno = i + 1 writer.write(ydb.TopicWriterMessage(data=f"msg-{i}", seqno=last_seqno)) - with driver_sync.topic_client.topic_writer( + with driver_sync.topic_client.writer( topic_path, producer_and_message_group_id="test", get_last_seqno=True, diff --git a/ydb/topic.py b/ydb/topic.py index 9378d100..3c70b061 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -92,7 +92,7 @@ async def create_topic( _wrap_operation, ) - async def describe( + async def describe_topic( self, path: str, include_stats: bool = False ) -> TopicDescription: args = locals().copy() @@ -115,7 +115,7 @@ async def drop_topic(self, path: str): _wrap_operation, ) - def topic_reader( + def reader( self, consumer: str, topic: str, @@ -139,7 +139,7 @@ def topic_reader( settings = TopicReaderSettings(**args) return TopicReaderAsyncIO(self._driver, settings) - def topic_writer( + def writer( self, topic, *, @@ -215,7 +215,9 @@ def create_topic( _wrap_operation, ) - def describe(self, path: str, include_stats: bool = False) -> TopicDescription: + def describe_topic( + self, path: str, include_stats: bool = False + ) -> TopicDescription: args = locals().copy() del args["self"] req = _ydb_topic_public_types.DescribeTopicRequestParams(**args) @@ -236,7 +238,7 @@ def drop_topic(self, path: str): _wrap_operation, ) - def topic_reader( + def reader( self, consumer: str, topic: str, @@ -260,7 +262,7 @@ def topic_reader( settings = TopicReaderSettings(**args) return TopicReader(self._driver, settings) - def topic_writer( + def writer( self, topic, producer_and_message_group_id: str, From cbfb86abf53b4c6a889049d036d198f06d436fab Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 12:53:49 +0300 Subject: [PATCH 072/147] rename producer_and_message_group_id to producer_id and make it optional --- CHANGELOG.md | 2 + 
examples/topic/writer_async_example.py | 6 +- examples/topic/writer_example.py | 8 +-- tests/conftest.py | 19 ++++++- tests/topics/test_topic_writer.py | 57 +++++++++++++------ ydb/_topic_writer/topic_writer.py | 41 +++++++------ .../topic_writer_asyncio_test.py | 4 +- ydb/topic.py | 23 ++------ 8 files changed, 95 insertions(+), 65 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 760269a3..3043520c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,6 @@ * BROKEN CHANGES: change names of public method in topic client +* BROKEN CHANGES: rename parameter producer_and_message_group_id to producer_id +* producer_id is optional now ## 3.0.1b5 ## * Remove six package from code and dependencies (remove support python2) diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index 30fbecab..548ef1aa 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -10,7 +10,7 @@ async def create_writer(db: ydb.aio.Driver): async with ydb.TopicClientAsyncIO(db).writer( "/database/topic/path", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) as writer: await writer.write(TopicWriterMessage("asd")) @@ -18,7 +18,7 @@ async def create_writer(db: ydb.aio.Driver): async def connect_and_wait(db: ydb.aio.Driver): async with ydb.TopicClientAsyncIO(db).writer( "/database/topic/path", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) as writer: writer.wait_init() @@ -26,7 +26,7 @@ async def connect_and_wait(db: ydb.aio.Driver): async def connect_without_context_manager(db: ydb.aio.Driver): writer = ydb.TopicClientAsyncIO(db).writer( "/database/topic/path", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) try: pass # some code diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 10f7db21..e95107d1 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -15,7 +15,7 @@ async def connect(): ) writer = ydb.TopicClientAsyncIO(db).writer( "/local/topic", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) await writer.write(TopicWriterMessage("asd")) @@ -23,7 +23,7 @@ async def connect(): def create_writer(db: ydb.Driver): with ydb.TopicClient(db).writer( "/database/topic/path", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) as writer: writer.write(TopicWriterMessage("asd")) @@ -31,7 +31,7 @@ def create_writer(db: ydb.Driver): def connect_and_wait(db: ydb.Driver): with ydb.TopicClient(db).writer( "/database/topic/path", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) as writer: writer.wait() @@ -39,7 +39,7 @@ def connect_and_wait(db: ydb.Driver): def connect_without_context_manager(db: ydb.Driver): writer = ydb.TopicClient(db).writer( "/database/topic/path", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) try: pass # some code diff --git a/tests/conftest.py b/tests/conftest.py index e94a83dc..6fa1f174 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -131,11 +131,24 @@ async def topic_path(driver, topic_consumer, database) -> str: @pytest.fixture() @pytest.mark.asyncio() async def topic_with_messages(driver, topic_path): - writer = driver.topic_client.writer( - topic_path, producer_and_message_group_id="fixture-producer-id" - ) + writer = driver.topic_client.writer(topic_path, producer_id="fixture-producer-id") await writer.write_with_ack( 
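         # write_with_ack blocks until the server has acknowledged the messages below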
ydb.TopicWriterMessage(data="123".encode()), ydb.TopicWriterMessage(data="456".encode()), ) await writer.close() + + +@pytest.fixture() +@pytest.mark.asyncio() +async def topic_reader(driver, topic_consumer, topic_path) -> ydb.TopicReaderAsyncIO: + reader = driver.topic_client.reader(topic=topic_path, consumer=topic_consumer) + yield reader + await reader.close() + + +@pytest.fixture() +def topic_reader_sync(driver_sync, topic_consumer, topic_path) -> ydb.TopicReader: + reader = driver_sync.topic_client.reader(topic=topic_path, consumer=topic_consumer) + yield reader + reader.close() diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 5ae976b8..e7db0e23 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -6,16 +6,14 @@ @pytest.mark.asyncio class TestTopicWriterAsyncIO: async def test_send_message(self, driver: ydb.aio.Driver, topic_path): - writer = driver.topic_client.writer( - topic_path, producer_and_message_group_id="test" - ) + writer = driver.topic_client.writer(topic_path, producer_id="test") await writer.write(ydb.TopicWriterMessage(data="123".encode())) await writer.close() async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): async with driver.topic_client.writer( topic_path, - producer_and_message_group_id="test", + producer_id="test", auto_seqno=False, ) as writer: await writer.write_with_ack( @@ -24,16 +22,28 @@ async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): async with driver.topic_client.writer( topic_path, - producer_and_message_group_id="test", - get_last_seqno=True, + producer_id="test", ) as writer2: init_info = await writer2.wait_init() assert init_info.last_seqno == 5 + async def test_random_producer_id( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with driver.topic_client.writer(topic_path) as writer: + await writer.write(ydb.TopicWriterMessage(data="123".encode())) + async with driver.topic_client.writer(topic_path) as writer: + await writer.write(ydb.TopicWriterMessage(data="123".encode())) + + batch1 = await topic_reader.receive_batch() + batch2 = await topic_reader.receive_batch() + + assert batch1.messages[0].producer_id != batch2.messages[0].producer_id + async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): async with driver.topic_client.writer( topic_path, - producer_and_message_group_id="test", + producer_id="test", auto_seqno=False, ) as writer: last_seqno = 0 @@ -45,8 +55,7 @@ async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): async with driver.topic_client.writer( topic_path, - producer_and_message_group_id="test", - get_last_seqno=True, + producer_id="test", ) as writer: init_info = await writer.wait_init() assert init_info.last_seqno == last_seqno @@ -54,24 +63,21 @@ async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): class TestTopicWriterSync: def test_send_message(self, driver_sync: ydb.Driver, topic_path): - writer = driver_sync.topic_client.writer( - topic_path, producer_and_message_group_id="test" - ) + writer = driver_sync.topic_client.writer(topic_path, producer_id="test") writer.write(ydb.TopicWriterMessage(data="123".encode())) writer.close() def test_wait_last_seqno(self, driver_sync: ydb.Driver, topic_path): with driver_sync.topic_client.writer( topic_path, - producer_and_message_group_id="test", + producer_id="test", auto_seqno=False, ) as writer: 
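             # auto_seqno is disabled above, so the message is written with an explicit seqno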
writer.write_with_ack(ydb.TopicWriterMessage(data="123".encode(), seqno=5))

         with driver_sync.topic_client.writer(
             topic_path,
-            producer_and_message_group_id="test",
-            get_last_seqno=True,
+            producer_id="test",
         ) as writer2:
             init_info = writer2.wait_init()
             assert init_info.last_seqno == 5
@@ -79,7 +85,7 @@ def test_wait_last_seqno(self, driver_sync: ydb.Driver, topic_path):
     def test_auto_flush_on_close(self, driver_sync: ydb.Driver, topic_path):
         with driver_sync.topic_client.writer(
             topic_path,
-            producer_and_message_group_id="test",
+            producer_id="test",
             auto_seqno=False,
         ) as writer:
             last_seqno = 0
@@ -89,8 +95,23 @@ def test_auto_flush_on_close(self, driver_sync: ydb.Driver, topic_path):

         with driver_sync.topic_client.writer(
             topic_path,
-            producer_and_message_group_id="test",
-            get_last_seqno=True,
+            producer_id="test",
         ) as writer:
             init_info = writer.wait_init()
             assert init_info.last_seqno == last_seqno
+
+    def test_random_producer_id(
+        self,
+        driver_sync: ydb.Driver,
+        topic_path,
+        topic_reader_sync: ydb.TopicReader,
+    ):
+        with driver_sync.topic_client.writer(topic_path) as writer:
+            writer.write(ydb.TopicWriterMessage(data="123".encode()))
+        with driver_sync.topic_client.writer(topic_path) as writer:
+            writer.write(ydb.TopicWriterMessage(data="123".encode()))
+
+        batch1 = topic_reader_sync.receive_batch()
+        batch2 = topic_reader_sync.receive_batch()
+
+        assert batch1.messages[0].producer_id != batch2.messages[0].producer_id
diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py
index 75858324..aa147558 100644
--- a/ydb/_topic_writer/topic_writer.py
+++ b/ydb/_topic_writer/topic_writer.py
@@ -1,8 +1,9 @@
 import datetime
 import enum
+import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Union, TextIO, BinaryIO, Optional, Callable, Mapping, Any, Dict
+from typing import List, Union, TextIO, BinaryIO, Optional, Any, Dict

 import typing

@@ -16,21 +17,31 @@

 @dataclass
 class PublicWriterSettings:
+    """
+    Settings for topic writer.
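+    producer_id defaults to a random uuid4 hex string when not set (see __post_init__ below).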
+ + order of fields IS NOT stable, use keywords only + """ + topic: str - producer_and_message_group_id: str + producer_id: Optional[str] = None session_metadata: Optional[Dict[str, str]] = None - encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None - serializer: Union[Callable[[Any], bytes], None] = None - send_buffer_count: Optional[int] = 10000 - send_buffer_bytes: Optional[int] = 100 * 1024 * 1024 partition_id: Optional[int] = None - codec: Optional[int] = None - codec_autoselect: bool = True auto_seqno: bool = True auto_created_at: bool = True - get_last_seqno: bool = False - retry_policy: Optional["RetryPolicy"] = None - update_token_interval: Union[int, float] = 3600 + # get_last_seqno: bool = False + # encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None + # serializer: Union[Callable[[Any], bytes], None] = None + # send_buffer_count: Optional[int] = 10000 + # send_buffer_bytes: Optional[int] = 100 * 1024 * 1024 + # codec: Optional[int] = None + # codec_autoselect: bool = True + # retry_policy: Optional["RetryPolicy"] = None + # update_token_interval: Union[int, float] = 3600 + + def __post_init__(self): + if self.producer_id is None: + self.producer_id = uuid.uuid4().hex @dataclass @@ -55,18 +66,16 @@ def __init__(self, settings: PublicWriterSettings): def create_init_request(self) -> StreamWriteMessage.InitRequest: return StreamWriteMessage.InitRequest( path=self.topic, - producer_id=self.producer_and_message_group_id, + producer_id=self.producer_id, write_session_meta=self.session_metadata, partitioning=self.get_partitioning(), - get_last_seq_no=self.get_last_seqno, + get_last_seq_no=True, ) def get_partitioning(self) -> StreamWriteMessage.PartitioningType: if self.partition_id is not None: return StreamWriteMessage.PartitioningPartitionID(self.partition_id) - return StreamWriteMessage.PartitioningMessageGroupID( - self.producer_and_message_group_id - ) + return StreamWriteMessage.PartitioningMessageGroupID(self.producer_id) class SendMode(Enum): diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 7f19a4dd..32bb3de5 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -238,7 +238,7 @@ def default_settings(self) -> WriterSettings: return WriterSettings( PublicWriterSettings( topic="/local/topic", - producer_and_message_group_id="test-producer", + producer_id="test-producer", auto_seqno=False, auto_created_at=False, ) @@ -487,7 +487,7 @@ async def close(self): def default_settings(self) -> PublicWriterSettings: return PublicWriterSettings( topic="/local/topic", - producer_and_message_group_id="producer-id", + producer_id="producer-id", ) @pytest.fixture(autouse=True) diff --git a/ydb/topic.py b/ydb/topic.py index 3c70b061..45b8b073 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,5 +1,5 @@ import datetime -from typing import List, Callable, Union, Mapping, Any, Optional, Dict +from typing import List, Union, Mapping, Optional, Dict from . 
import aio, Credentials, _apis

@@ -143,19 +143,11 @@ def writer(
         self,
         topic,
         *,
-        producer_and_message_group_id: str,
+        producer_id: Optional[str] = None,  # default - random
         session_metadata: Mapping[str, str] = None,
-        encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
-        serializer: Union[Callable[[Any], bytes], None] = None,
-        send_buffer_count: Union[int, None] = 10000,
-        send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024,
         partition_id: Union[int, None] = None,
-        codec: Union[int, None] = None,
-        codec_autoselect: bool = True,
         auto_seqno: bool = True,
         auto_created_at: bool = True,
-        get_last_seqno: bool = False,
-        retry_policy: Union["TopicWriterRetryPolicy", None] = None,
     ) -> TopicWriterAsyncIO:
         args = locals()
         del args["self"]
@@ -265,19 +257,12 @@ def reader(
     def writer(
         self,
         topic,
-        producer_and_message_group_id: str,
+        *,
+        producer_id: Optional[str] = None,  # default - random
         session_metadata: Mapping[str, str] = None,
-        encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
-        serializer: Union[Callable[[Any], bytes], None] = None,
-        send_buffer_count: Union[int, None] = 10000,
-        send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024,
         partition_id: Union[int, None] = None,
-        codec: Union[int, None] = None,
-        codec_autoselect: bool = True,
         auto_seqno: bool = True,
         auto_created_at: bool = True,
-        get_last_seqno: bool = False,
-        retry_policy: Union["TopicWriterRetryPolicy", None] = None,
     ) -> TopicWriter:
         args = locals()
         del args["self"]

From d41b6f417cb2529f6785dc696387b47e792c3551 Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Wed, 15 Feb 2023 18:55:24 +0300
Subject: [PATCH 073/147] add commit for reader

---
 .github/workflows/style.yaml              |   1 +
 tests/conftest.py                         |   4 +
 tests/topics/test_topic_reader.py         |  23 +
 ydb/_grpc/grpcwrapper/ydb_topic.py        | 112 ++++-
 ydb/_grpc/grpcwrapper/ydb_topic_test.py   |  27 ++
 ydb/_topic_common/test_helpers.py         |  27 +-
 ydb/_topic_reader/datatypes.py            | 183 +++++++-
 ydb/_topic_reader/datatypes_test.py       | 315 ++++++++++++++
 ydb/_topic_reader/topic_reader_asyncio.py | 180 ++++++--
 .../topic_reader_asyncio_test.py          | 408 ++++++++++++++----
 ydb/_topic_reader/topic_reader_sync.py    |  16 +-
 ydb/_topic_writer/topic_writer_asyncio.py |   9 +-
 ydb/_utilities.py                         |  15 +
 13 files changed, 1153 insertions(+), 167 deletions(-)
 create mode 100644 ydb/_grpc/grpcwrapper/ydb_topic_test.py
 create mode 100644 ydb/_topic_reader/datatypes_test.py

diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml
index 8723d8f2..c280042b 100644
--- a/.github/workflows/style.yaml
+++ b/.github/workflows/style.yaml
@@ -2,6 +2,8 @@ name: Style checks

 on:
   push:
+    branches:
+      - main
   pull_request:

 jobs:
diff --git a/tests/conftest.py b/tests/conftest.py
index 6fa1f174..62f486cb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -136,6 +136,10 @@ async def topic_with_messages(driver, topic_path):
         ydb.TopicWriterMessage(data="123".encode()),
         ydb.TopicWriterMessage(data="456".encode()),
     )
+    await writer.write_with_ack(
+        ydb.TopicWriterMessage(data="789".encode()),
+        ydb.TopicWriterMessage(data="0".encode()),
+    )
     await writer.close()
diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py
index 21675eb2..734b64c7 100644
--- a/tests/topics/test_topic_reader.py
+++ b/tests/topics/test_topic_reader.py
@@ -11,6 +11,18 @@ async def test_read_message(
         assert await reader.receive_batch() is not None
         await reader.close()

+    async def test_read_and_commit_message(
+        self, driver, topic_path, topic_with_messages, topic_consumer
+    ):
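+        # commit the first batch, then read with a fresh reader: it must receive the next, uncommitted message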
+        reader = driver.topic_client.reader(topic_consumer, topic_path)
+        batch = await reader.receive_batch()
+        await reader.commit_with_ack(batch)
+
+        reader = driver.topic_client.reader(topic_consumer, topic_path)
+        batch2 = await reader.receive_batch()
+        assert batch.messages[0] != batch2.messages[0]
+

 class TestTopicReaderSync:
     def test_read_message(
@@ -20,3 +32,14 @@ def test_read_message(

         assert reader.receive_batch() is not None
         reader.close()
+
+    def test_read_and_commit_message(
+        self, driver_sync, topic_path, topic_with_messages, topic_consumer
+    ):
+        reader = driver_sync.topic_client.reader(topic_consumer, topic_path)
+        batch = reader.receive_batch()
+        reader.commit_with_ack(batch)
+
+        reader = driver_sync.topic_client.reader(topic_consumer, topic_path)
+        batch2 = reader.receive_batch()
+        assert batch.messages[0] != batch2.messages[0]
diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py
index e6a5a8e3..ad8a8e72 100644
--- a/ydb/_grpc/grpcwrapper/ydb_topic.py
+++ b/ydb/_grpc/grpcwrapper/ydb_topic.py
@@ -67,10 +67,23 @@ def to_public(self) -> List[ydb_topic_public_types.PublicCodec]:
         return list(map(Codec.to_public, self.codecs))


-@dataclass
-class OffsetsRange(IFromProto):
-    start: int
-    end: int
+@dataclass(order=True)
+class OffsetsRange(IFromProto, IToProto):
+    """
+    Half-open interval: includes offsets [start, end).
+    """
+
+    __slots__ = ("start", "end")
+
+    start: int  # first offset of the range
+    end: int  # first offset after the range (not included)
+
+    def __post_init__(self):
+        if self.end < self.start:
+            raise ValueError(
+                "offset end must be not less than start. Got start=%s end=%s"
+                % (self.start, self.end)
+            )

     @staticmethod
     def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
@@ -79,6 +92,20 @@ def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
             end=msg.end,
         )

+    def to_proto(self) -> ydb_topic_pb2.OffsetsRange:
+        return ydb_topic_pb2.OffsetsRange(
+            start=self.start,
+            end=self.end,
+        )
+
+    def is_intersected_with(self, other: "OffsetsRange") -> bool:
+        return (
+            self.start <= other.start < self.end
+            or self.start < other.end <= self.end
+            or other.start <= self.start < other.end
+            or other.start < self.end <= other.end
+        )
+

 @dataclass
 class UpdateTokenRequest(IToProto):
@@ -527,23 +554,67 @@ def from_proto(
         )

     @dataclass
-    class CommitOffsetRequest:
+    class CommitOffsetRequest(IToProto):
         commit_offsets: List["PartitionCommitOffset"]

+        def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest:
+            res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest(
+                commit_offsets=list(
+                    map(
+                        StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset.to_proto,
+                        self.commit_offsets,
+                    )
+                ),
+            )
+            return res
+
     @dataclass
-    class PartitionCommitOffset:
+    class PartitionCommitOffset(IToProto):
         partition_session_id: int
         offsets: List["OffsetsRange"]

+        def to_proto(
+            self,
+        ) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset:
+            res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset(
+                partition_session_id=self.partition_session_id,
+                offsets=list(map(OffsetsRange.to_proto, self.offsets)),
+            )
+            return res
+
     @dataclass
-    class CommitOffsetResponse:
-        partitions_committed_offsets: List["PartitionCommittedOffset"]
+    class CommitOffsetResponse(IFromProto):
+        partitions_committed_offsets: List[
+            "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset"
+        ]
+
+        @staticmethod
+        def from_proto(
+            msg:
ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse, + ) -> "StreamReadMessage.CommitOffsetResponse": + return StreamReadMessage.CommitOffsetResponse( + partitions_committed_offsets=list( + map( + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset.from_proto, + msg.partitions_committed_offsets, + ) + ) + ) @dataclass - class PartitionCommittedOffset: + class PartitionCommittedOffset(IFromProto): partition_session_id: int committed_offset: int + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset, + ) -> "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset": + return StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=msg.partition_session_id, + committed_offset=msg.committed_offset, + ) + @dataclass class PartitionSessionStatusRequest: partition_session_id: int @@ -576,16 +647,18 @@ def from_proto( @dataclass class StartPartitionSessionResponse(IToProto): partition_session_id: int - read_offset: int - commit_offset: int + read_offset: Optional[int] + commit_offset: Optional[int] def to_proto( self, ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() res.partition_session_id = self.partition_session_id - res.read_offset = self.read_offset - res.commit_offset = self.commit_offset + if self.read_offset is not None: + res.read_offset = self.read_offset + if self.commit_offset is not None: + res.commit_offset = self.commit_offset return res @dataclass @@ -609,6 +682,8 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: res = ydb_topic_pb2.StreamReadMessage.FromClient() if isinstance(self.client_message, StreamReadMessage.ReadRequest): res.read_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.CommitOffsetRequest): + res.commit_offset_request.CopyFrom(self.client_message.to_proto()) elif isinstance(self.client_message, StreamReadMessage.InitRequest): res.init_request.CopyFrom(self.client_message.to_proto()) elif isinstance( @@ -618,7 +693,9 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: self.client_message.to_proto() ) else: - raise NotImplementedError() + raise NotImplementedError( + "Unknown message type: %s" % type(self.client_message) + ) return res @dataclass @@ -639,6 +716,13 @@ def from_proto( msg.read_response ), ) + elif mess_type == "commit_offset_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.CommitOffsetResponse.from_proto( + msg.commit_offset_response + ), + ) elif mess_type == "init_response": return StreamReadMessage.FromServer( server_status=server_status, diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_test.py b/ydb/_grpc/grpcwrapper/ydb_topic_test.py new file mode 100644 index 00000000..bff9b43d --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_topic_test.py @@ -0,0 +1,27 @@ +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange + + +def test_offsets_range_intersected(): + # not intersected + for test in [(0, 1, 1, 2), (1, 2, 3, 5)]: + assert not OffsetsRange(test[0], test[1]).is_intersected_with( + OffsetsRange(test[2], test[3]) + ) + assert not OffsetsRange(test[2], test[3]).is_intersected_with( + OffsetsRange(test[0], test[1]) + ) + + # intersected + for test in [ + (1, 2, 1, 2), + (1, 10, 1, 2), + (1, 10, 2, 3), + (1, 10, 5, 15), + (10, 20, 5, 15), + ]: + assert OffsetsRange(test[0], 
test[1]).is_intersected_with(
+            OffsetsRange(test[2], test[3])
+        )
+        assert OffsetsRange(test[2], test[3]).is_intersected_with(
+            OffsetsRange(test[0], test[1])
+        )
diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py
index 9023f759..60166d0d 100644
--- a/ydb/_topic_common/test_helpers.py
+++ b/ydb/_topic_common/test_helpers.py
@@ -39,7 +39,21 @@ def close(self):
         self.from_server.put_nowait(None)


-async def wait_condition(f: typing.Callable[[], bool], timeout=1):
+class WaitConditionException(Exception):
+    pass
+
+
+async def wait_condition(
+    f: typing.Callable[[], bool],
+    timeout: typing.Optional[typing.Union[float, int]] = None,
+):
+    """
+    The default timeout is 1 second.
+    If timeout is 0, only the iteration counter limits the wait. This is useful when a test
+    needs a fast timeout for a condition, without waiting out the full default.
+    """
+    if timeout is None:
+        timeout = 1
+
     start = time.monotonic()
     counter = 0
     while (time.monotonic() - start < timeout) or counter < 1000:
@@ -48,8 +62,13 @@
             return
         await asyncio.sleep(0)

-    raise Exception("Bad condition in test")
+    raise WaitConditionException("Bad condition in test")


-async def wait_for_fast(fut):
-    return await asyncio.wait_for(fut, 1)
+async def wait_for_fast(
+    awaitable: typing.Awaitable,
+    timeout: typing.Optional[typing.Union[float, int]] = None,
+):
+    fut = asyncio.ensure_future(awaitable)
+    await wait_condition(lambda: fut.done(), timeout)
+    return fut.result()
diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py
index 9b2ab31a..06b8d690 100644
--- a/ydb/_topic_reader/datatypes.py
+++ b/ydb/_topic_reader/datatypes.py
@@ -1,20 +1,26 @@
+from __future__ import annotations
+
 import abc
+import asyncio
+import bisect
 import enum
-from dataclasses import dataclass
+from collections import deque
+from dataclasses import dataclass, field
 import datetime
-from typing import Mapping, Union, Any, List, Dict
+from typing import Mapping, Union, Any, List, Dict, Deque, Optional
+
+from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange
+from ydb._topic_reader import topic_reader_asyncio


 class ICommittable(abc.ABC):
-    @property
     @abc.abstractmethod
-    def start_offset(self) -> int:
-        pass
+    def _commit_get_partition_session(self) -> PartitionSession:
+        ...

-    @property
     @abc.abstractmethod
-    def end_offset(self) -> int:
-        pass
+    def _commit_get_offsets_range(self) -> OffsetsRange:
+        ...


 class ISessionAlive(abc.ABC):
@@ -36,15 +42,15 @@ class PublicMessage(ICommittable, ISessionAlive):
     data: Union[
         bytes, Any
     ]  # set as original decompressed bytes or deserialized object if deserializer set in reader
-    _partition_session: "PartitionSession"
+    _partition_session: PartitionSession
+    _commit_start_offset: int
+    _commit_end_offset: int

-    @property
-    def start_offset(self) -> int:
-        raise NotImplementedError()
+    def _commit_get_partition_session(self) -> PartitionSession:
+        return self._partition_session

-    @property
-    def end_offset(self) -> int:
-        raise NotImplementedError()
+    def _commit_get_offsets_range(self) -> OffsetsRange:
+        return OffsetsRange(self._commit_start_offset, self._commit_end_offset)

     # ISessionAlive implementation
     @property
@@ -58,15 +64,147 @@ class PartitionSession:
     state: "PartitionSession.State"
     topic_path: str
     partition_id: int
+    committed_offset: int  # last offset acked by the server; messages up to committed_offset - 1 are already processed
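+    # reader_reconnector_id/reader_stream_id bind this session to the reader that
+    # created it: commits from another reader or from an expired stream are rejected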
+ reader_reconnector_id: int + reader_stream_id: int + _next_message_start_commit_offset: int = field(init=False) + _send_commit_window_start: int = field(init=False) + + # todo: check if deque is optimal + _pending_commits: Deque[OffsetsRange] = field( + init=False, default_factory=lambda: deque() + ) + + # todo: check if deque is optimal + _ack_waiters: Deque["PartitionSession.CommitAckWaiter"] = field( + init=False, default_factory=lambda: deque() + ) + + _state_changed: asyncio.Event = field( + init=False, default_factory=lambda: asyncio.Event(), compare=False + ) + _loop: Optional[asyncio.AbstractEventLoop] = field( + init=False + ) # may be None in tests + + def __post_init__(self): + self._next_message_start_commit_offset = self.committed_offset + self._send_commit_window_start = self.committed_offset + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = None + + def add_commit( + self, new_commit: OffsetsRange + ) -> "PartitionSession.CommitAckWaiter": + self._ensure_not_closed() + + self._add_to_commits(new_commit) + return self._add_waiter(new_commit.end) + + def _add_to_commits(self, new_commit: OffsetsRange): + index = bisect.bisect_left(self._pending_commits, new_commit) + + prev_commit = self._pending_commits[index - 1] if index > 0 else None + commit = ( + self._pending_commits[index] if index < len(self._pending_commits) else None + ) + + for c in (prev_commit, commit): + if c is not None and new_commit.is_intersected_with(c): + raise ValueError( + "new commit intersected with existed. New range: %s, existed: %s" + % (new_commit, c) + ) + + if commit is not None and commit.start == new_commit.end: + commit.start = new_commit.start + elif prev_commit is not None and prev_commit.end == new_commit.start: + prev_commit.end = new_commit.end + else: + self._pending_commits.insert(index, new_commit) + + def _add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": + waiter = PartitionSession.CommitAckWaiter(end_offset, self._create_future()) + + # fast way + if len(self._ack_waiters) > 0 and self._ack_waiters[-1].end_offset < end_offset: + self._ack_waiters.append(waiter) + else: + bisect.insort(self._ack_waiters, waiter) + + return waiter + + def _create_future(self) -> asyncio.Future: + if self._loop: + return self._loop.create_future() + else: + return asyncio.Future() + + def pop_commit_range(self) -> Optional[OffsetsRange]: + self._ensure_not_closed() + + if len(self._pending_commits) == 0: + return None + + if self._pending_commits[0].start != self._send_commit_window_start: + return None + + res = self._pending_commits.popleft() + while ( + len(self._pending_commits) > 0 and self._pending_commits[0].start == res.end + ): + commit = self._pending_commits.popleft() + res.end = commit.end + + self._send_commit_window_start = res.end + + return res + + def ack_notify(self, offset: int): + self._ensure_not_closed() + + self.committed_offset = offset + + if len(self._ack_waiters) == 0: + # todo log warning + # must be never receive ack for not sended request + return + + while len(self._ack_waiters) > 0: + if self._ack_waiters[0].end_offset <= offset: + waiter = self._ack_waiters.popleft() + waiter.future.set_result(None) + else: + break + + def close(self): + try: + self._ensure_not_closed() + except topic_reader_asyncio.TopicReaderCommitToExpiredPartition: + return - def stop(self): self.state = PartitionSession.State.Stopped + exception = topic_reader_asyncio.TopicReaderCommitToExpiredPartition() + for waiter in 
self._ack_waiters: + waiter.future.set_exception(exception) + + def _ensure_not_closed(self): + if self.state == PartitionSession.State.Stopped: + raise topic_reader_asyncio.TopicReaderCommitToExpiredPartition() class State(enum.Enum): Active = 1 GracefulShutdown = 2 Stopped = 3 + @dataclass(order=True) + class CommitAckWaiter: + end_offset: int + future: asyncio.Future = field(compare=False) + @dataclass class PublicBatch(ICommittable, ISessionAlive): @@ -75,13 +213,14 @@ class PublicBatch(ICommittable, ISessionAlive): _partition_session: PartitionSession _bytes_size: int - @property - def start_offset(self) -> int: - raise NotImplementedError() + def _commit_get_partition_session(self) -> PartitionSession: + return self.messages[0]._commit_get_partition_session() - @property - def end_offset(self) -> int: - raise NotImplementedError() + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange( + self.messages[0]._commit_get_offsets_range().start, + self.messages[-1]._commit_get_offsets_range().end, + ) # ISessionAlive implementation @property diff --git a/ydb/_topic_reader/datatypes_test.py b/ydb/_topic_reader/datatypes_test.py new file mode 100644 index 00000000..6ead9a88 --- /dev/null +++ b/ydb/_topic_reader/datatypes_test.py @@ -0,0 +1,315 @@ +import asyncio +import bisect +import copy +import functools +from collections import deque +from typing import List, Optional, Type, Union + +import pytest + +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange +from ydb._topic_common.test_helpers import wait_condition +from ydb._topic_reader import topic_reader_asyncio +from ydb._topic_reader.datatypes import PartitionSession + + +@pytest.mark.asyncio +class TestPartitionSession: + session_comitted_offset = 10 + + @pytest.fixture + def session(self) -> PartitionSession: + return PartitionSession( + id=1, + state=PartitionSession.State.Active, + topic_path="", + partition_id=1, + committed_offset=self.session_comitted_offset, + reader_reconnector_id=1, + reader_stream_id=1, + ) + + @pytest.mark.parametrize( + "offsets_waited,notify_offset,offsets_notified,offsets_waited_rest", + [ + ([1], 1, [1], []), + ([1], 10, [1], []), + ([1, 2, 3], 10, [1, 2, 3], []), + ([1, 2, 10, 20], 10, [1, 2, 10], [20]), + ([10, 20], 1, [], [10, 20]), + ], + ) + async def test_ack_notify( + self, + session, + offsets_waited: List[int], + notify_offset: int, + offsets_notified: List[int], + offsets_waited_rest: List[int], + ): + notified = [] + + for offset in offsets_waited: + fut = asyncio.Future() + + def add_notify(future, notified_offset): + notified.append(notified_offset) + + fut.add_done_callback(functools.partial(add_notify, notified_offset=offset)) + waiter = PartitionSession.CommitAckWaiter(offset, fut) + session._ack_waiters.append(waiter) + + session.ack_notify(notify_offset) + assert session._ack_waiters == deque( + [ + PartitionSession.CommitAckWaiter(offset, asyncio.Future()) + for offset in offsets_waited_rest + ] + ) + + await wait_condition(lambda: len(notified) == len(offsets_notified)) + + notified.sort() + assert notified == offsets_notified + assert session.committed_offset == notify_offset + + def test_add_commit(self, session): + commit = OffsetsRange( + self.session_comitted_offset, self.session_comitted_offset + 5 + ) + waiter = session.add_commit(commit) + assert waiter.end_offset == commit.end + + @pytest.mark.parametrize( + "original,add,result", + [ + ( + [], + OffsetsRange(1, 10), + [OffsetsRange(1, 10)], + ), + ( + [OffsetsRange(1, 10)], + OffsetsRange(15, 20), + 
[OffsetsRange(1, 10), OffsetsRange(15, 20)], + ), + ( + [OffsetsRange(15, 20)], + OffsetsRange(1, 10), + [OffsetsRange(1, 10), OffsetsRange(15, 20)], + ), + ( + [OffsetsRange(1, 10)], + OffsetsRange(10, 20), + [OffsetsRange(1, 20)], + ), + ( + [OffsetsRange(10, 20)], + OffsetsRange(1, 10), + [OffsetsRange(1, 20)], + ), + ( + [OffsetsRange(1, 2), OffsetsRange(3, 4)], + OffsetsRange(2, 3), + [OffsetsRange(1, 2), OffsetsRange(2, 4)], + ), + ( + [OffsetsRange(1, 10)], + OffsetsRange(5, 6), + ValueError, + ), + ], + ) + def test_add_to_commits( + self, + session, + original: List[OffsetsRange], + add: OffsetsRange, + result: Union[List[OffsetsRange], Type[Exception]], + ): + session._pending_commits = copy.deepcopy(original) + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + session._add_to_commits(add) + else: + session._add_to_commits(add) + assert session._pending_commits == result + + # noinspection PyTypeChecker + @pytest.mark.parametrize( + "original,add,result", + [ + ( + [], + 5, + [PartitionSession.CommitAckWaiter(5, None)], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 6, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(6, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 4, + [ + PartitionSession.CommitAckWaiter(4, None), + PartitionSession.CommitAckWaiter(5, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 0, + [ + PartitionSession.CommitAckWaiter(0, None), + PartitionSession.CommitAckWaiter(5, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 100, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + 50, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(50, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(7, None), + ], + 6, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(6, None), + PartitionSession.CommitAckWaiter(7, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + 6, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(6, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + 99, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(99, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ], + ) + def test_add_waiter( + self, + session, + original: List[PartitionSession.CommitAckWaiter], + add: int, + result: List[PartitionSession.CommitAckWaiter], + ): + session._ack_waiters = copy.deepcopy(original) + res = session._add_waiter(add) + assert result == session._ack_waiters + + index = bisect.bisect_left(session._ack_waiters, res) + assert res is session._ack_waiters[index] + + def test_close_notify_waiters(self, session): + waiter = session._add_waiter(session.committed_offset + 1) + session.close() + + with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): + waiter.future.result() + + def test_close_twice(self, session): + session.close() + session.close() + + @pytest.mark.parametrize( + "commits,result,rest", + [ + 
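+            # (pending commit ranges, range expected from pop_commit_range, ranges left pending)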
([], None, []), + ( + [OffsetsRange(session_comitted_offset + 1, 20)], + None, + [OffsetsRange(session_comitted_offset + 1, 20)], + ), + ( + [OffsetsRange(session_comitted_offset, session_comitted_offset + 1)], + OffsetsRange(session_comitted_offset, session_comitted_offset + 1), + [], + ), + ( + [ + OffsetsRange(session_comitted_offset, session_comitted_offset + 1), + OffsetsRange( + session_comitted_offset + 1, session_comitted_offset + 2 + ), + ], + OffsetsRange(session_comitted_offset, session_comitted_offset + 2), + [], + ), + ( + [ + OffsetsRange(session_comitted_offset, session_comitted_offset + 1), + OffsetsRange( + session_comitted_offset + 1, session_comitted_offset + 2 + ), + OffsetsRange( + session_comitted_offset + 10, session_comitted_offset + 20 + ), + ], + OffsetsRange(session_comitted_offset, session_comitted_offset + 2), + [ + OffsetsRange( + session_comitted_offset + 10, session_comitted_offset + 20 + ) + ], + ), + ], + ) + def test_get_commit_range( + self, + session, + commits: List[OffsetsRange], + result: Optional[OffsetsRange], + rest: List[OffsetsRange], + ): + send_commit_window_start = session._send_commit_window_start + + session._pending_commits = deque(commits) + res = session.pop_commit_range() + assert res == result + assert session._pending_commits == deque(rest) + + if res is None: + assert session._send_commit_window_start == send_commit_window_start + else: + assert session._send_commit_window_start != send_commit_window_start + assert session._send_commit_window_start == res.end diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index a3f792de..cc0839f7 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -6,12 +6,12 @@ from collections import deque from typing import Optional, Set, Dict - from .. import _apis, issues, RetrySettings +from .._utilities import AtomicCounter from ..aio import Driver from ..issues import Error as YdbError, _process_response -from .datatypes import PartitionSession, PublicMessage, PublicBatch, ICommittable -from .topic_reader import PublicReaderSettings, CommitResult, SessionStat +from . import datatypes +from . import topic_reader from .._topic_common.common import ( TokenGetterFuncType, ) @@ -28,6 +28,17 @@ class TopicReaderError(YdbError): pass +class TopicReaderCommitToExpiredPartition(TopicReaderError): + """ + Commit message when partition read session are dropped. + It is ok - the message/batch will not commit to server and will receive in other read session + (with this or other reader). 
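+    The error is raised at commit time; reading can continue with a new partition session.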
+ """ + + def __init__(self, message: str = "Topic reader partition session is closed"): + super().__init__(message) + + class TopicReaderStreamClosedError(TopicReaderError): def __init__(self): super().__init__("Topic reader stream is closed") @@ -43,7 +54,7 @@ class PublicAsyncIOReader: _closed: bool _reconnector: ReaderReconnector - def __init__(self, driver: Driver, settings: PublicReaderSettings): + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._loop = asyncio.get_running_loop() self._closed = False self._reconnector = ReaderReconnector(driver, settings) @@ -58,7 +69,7 @@ def __del__(self): if not self._closed: self._loop.create_task(self.close(), name="close reader") - async def sessions_stat(self) -> typing.List["SessionStat"]: + async def sessions_stat(self) -> typing.List["topic_reader.SessionStat"]: """ Receive stat from the server @@ -68,7 +79,7 @@ async def sessions_stat(self) -> typing.List["SessionStat"]: def messages( self, *, timeout: typing.Union[float, None] = None - ) -> typing.AsyncIterable["PublicMessage"]: + ) -> typing.AsyncIterable[topic_reader.PublicMessage]: """ Block until receive new message @@ -76,7 +87,7 @@ def messages( """ raise NotImplementedError() - async def receive_message(self) -> typing.Union["PublicMessage", None]: + async def receive_message(self) -> typing.Union[topic_reader.PublicMessage, None]: """ Block until receive new message @@ -90,7 +101,7 @@ def batches( max_messages: typing.Union[int, None] = None, max_bytes: typing.Union[int, None] = None, timeout: typing.Union[float, None] = None, - ) -> typing.AsyncIterable["PublicBatch"]: + ) -> typing.AsyncIterable[datatypes.PublicBatch]: """ Block until receive new batch. All messages in a batch from same partition. @@ -104,7 +115,7 @@ async def receive_batch( *, max_messages: typing.Union[int, None] = None, max_bytes: typing.Union[int, None] = None, - ) -> typing.Union["PublicBatch", None]: + ) -> typing.Union[topic_reader.PublicBatch, None]: """ Get one messages batch from reader. All messages in a batch from same partition. @@ -114,7 +125,9 @@ async def receive_batch( await self._reconnector.wait_message() return self._reconnector.receive_batch_nowait() - async def commit_on_exit(self, mess: ICommittable) -> typing.AsyncContextManager: + async def commit_on_exit( + self, mess: datatypes.ICommittable + ) -> typing.AsyncContextManager: """ commit the mess match/message if exit from context manager without exceptions @@ -122,24 +135,27 @@ async def commit_on_exit(self, mess: ICommittable) -> typing.AsyncContextManager """ raise NotImplementedError() - def commit(self, mess: ICommittable): + def commit( + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): """ Write commit message to a buffer. For the method no way check the commit result (for example if lost connection - commits will not re-send and committed messages will receive again) """ - raise NotImplementedError() + self._reconnector.commit(batch) async def commit_with_ack( - self, mess: ICommittable - ) -> typing.Union[CommitResult, typing.List[CommitResult]]: + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): """ write commit message to a buffer and wait ack from the server. use asyncio.wait_for for wait with timeout. 
""" - raise NotImplementedError() + waiter = self._reconnector.commit(batch) + await waiter.future async def flush(self): """ @@ -158,7 +174,10 @@ async def close(self): class ReaderReconnector: - _settings: PublicReaderSettings + _static_reader_reconnector_counter = AtomicCounter() + + _id: int + _settings: topic_reader.PublicReaderSettings _driver: Driver _background_tasks: Set[Task] @@ -166,7 +185,9 @@ class ReaderReconnector: _stream_reader: Optional["ReaderStream"] _first_error: asyncio.Future[YdbError] - def __init__(self, driver: Driver, settings: PublicReaderSettings): + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + self._id = self._static_reader_reconnector_counter.inc_and_get() + self._settings = settings self._driver = driver self._background_tasks = set() @@ -182,7 +203,7 @@ async def _connection_loop(self): while True: try: self._stream_reader = await ReaderStream.create( - self._driver, self._settings + self._id, self._driver, self._settings ) attempt = 0 self._state_changed.set() @@ -216,6 +237,11 @@ async def wait_message(self): def receive_batch_nowait(self): return self._stream_reader.receive_batch_nowait() + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + return self._stream_reader.commit(batch) + async def close(self): await self._stream_reader.close() for task in self._background_tasks: @@ -233,20 +259,28 @@ def _set_first_error(self, err: issues.Error): class ReaderStream: + _static_id_counter = AtomicCounter() + + _id: int + _reader_reconnector_id: int _token_getter: Optional[TokenGetterFuncType] _session_id: str _stream: Optional[IGrpcWrapperAsyncIO] _started: bool _background_tasks: Set[asyncio.Task] - _partition_sessions: Dict[int, PartitionSession] + _partition_sessions: Dict[int, datatypes.PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only _state_changed: asyncio.Event _closed: bool - _message_batches: typing.Deque[PublicBatch] + _message_batches: typing.Deque[datatypes.PublicBatch] _first_error: asyncio.Future[YdbError] - def __init__(self, settings: PublicReaderSettings): + def __init__( + self, reader_reconnector_id: int, settings: topic_reader.PublicReaderSettings + ): + self._id = ReaderStream._static_id_counter.inc_and_get() + self._reader_reconnector_id = reader_reconnector_id self._token_getter = settings._token_getter self._session_id = "not initialized" self._stream = None @@ -262,8 +296,9 @@ def __init__(self, settings: PublicReaderSettings): @staticmethod async def create( + reader_reconnector_id: int, driver: SupportedDriverType, - settings: PublicReaderSettings, + settings: topic_reader.PublicReaderSettings, ) -> "ReaderStream": stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) @@ -271,7 +306,7 @@ async def create( driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead ) - reader = ReaderStream(settings) + reader = ReaderStream(reader_reconnector_id, settings) await reader._start(stream, settings._init_message()) return reader @@ -321,6 +356,45 @@ def receive_batch_nowait(self): except IndexError: return None + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + partition_session = batch._commit_get_partition_session() + + if ( + partition_session.reader_reconnector_id + != partition_session.reader_reconnector_id + ): + raise TopicReaderError("reader can commit only self-produced messages") + + if partition_session.reader_stream_id != 
self._id: + raise TopicReaderCommitToExpiredPartition( + "commit messages after reconnect to server" + ) + + if partition_session.id not in self._partition_sessions: + raise TopicReaderCommitToExpiredPartition( + "commit messages after server stop the partition read session" + ) + + waiter = partition_session.add_commit(batch._commit_get_offsets_range()) + + send_range = partition_session.pop_commit_range() + if send_range: + client_message = StreamReadMessage.CommitOffsetRequest( + commit_offsets=[ + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=partition_session.id, + offsets=[send_range], + ) + ] + ) + self._stream.write( + StreamReadMessage.FromClient(client_message=client_message) + ) + + return waiter + async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): try: self._stream.write( @@ -335,11 +409,15 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): _process_response(message.server_status) if isinstance(message.server_message, StreamReadMessage.ReadResponse): self._on_read_response(message.server_message) + elif isinstance( + message.server_message, StreamReadMessage.CommitOffsetResponse + ): + self._on_commit_response(message.server_message) elif isinstance( message.server_message, StreamReadMessage.StartPartitionSessionRequest, ): - self._on_start_partition_session_start(message.server_message) + self._on_start_partition_session(message.server_message) elif isinstance( message.server_message, StreamReadMessage.StopPartitionSessionRequest, @@ -356,7 +434,7 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): self._set_first_error(e) raise e - def _on_start_partition_session_start( + def _on_start_partition_session( self, message: StreamReadMessage.StartPartitionSessionRequest ): try: @@ -371,18 +449,21 @@ def _on_start_partition_session_start( self._partition_sessions[ message.partition_session.partition_session_id - ] = PartitionSession( + ] = datatypes.PartitionSession( id=message.partition_session.partition_session_id, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, topic_path=message.partition_session.path, partition_id=message.partition_session.partition_id, + committed_offset=message.committed_offset, + reader_reconnector_id=self._reader_reconnector_id, + reader_stream_id=self._id, ) self._stream.write( StreamReadMessage.FromClient( client_message=StreamReadMessage.StartPartitionSessionResponse( partition_session_id=message.partition_session.partition_session_id, - read_offset=0, - commit_offset=0, + read_offset=None, + commit_offset=None, ) ), ) @@ -399,7 +480,7 @@ def _on_partition_session_stop( return del self._partition_sessions[message.partition_session_id] - partition.stop() + partition.close() if message.graceful: self._stream.write( @@ -415,6 +496,16 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse): self._message_batches.extend(batches) self._buffer_consume_bytes(message.bytes_size) + def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): + for partition_offset in message.partitions_committed_offsets: + try: + session = self._partition_sessions[ + partition_offset.partition_session_id + ] + except KeyError: + continue + session.ack_notify(partition_offset.committed_offset) + def _buffer_consume_bytes(self, bytes_size): self._buffer_size_bytes -= bytes_size @@ -430,7 +521,7 @@ def _buffer_release_bytes(self, bytes_size): def _read_response_to_batches( self, message: StreamReadMessage.ReadResponse - ) 
-> typing.List[PublicBatch]: + ) -> typing.List[datatypes.PublicBatch]: batches = [] batch_count = 0 @@ -452,7 +543,7 @@ def _read_response_to_batches( for server_batch in partition_data.batches: messages = [] for message_data in server_batch.message_data: - mess = PublicMessage( + mess = datatypes.PublicMessage( seqno=message_data.seq_no, created_at=message_data.created_at, message_group_id=message_data.message_group_id, @@ -462,15 +553,23 @@ def _read_response_to_batches( producer_id=server_batch.producer_id, data=message_data.data, _partition_session=partition_session, + _commit_start_offset=partition_session._next_message_start_commit_offset, + _commit_end_offset=message_data.offset + 1, ) messages.append(mess) - batch = PublicBatch( - session_metadata=server_batch.write_session_meta, - messages=messages, - _partition_session=partition_session, - _bytes_size=bytes_per_batch, - ) - batches.append(batch) + + partition_session._next_message_start_commit_offset = ( + mess._commit_end_offset + ) + + if len(messages) > 0: + batch = datatypes.PublicBatch( + session_metadata=server_batch.write_session_meta, + messages=messages, + _partition_session=partition_session, + _bytes_size=bytes_per_batch, + ) + batches.append(batch) batches[-1]._bytes_size += additional_bytes_to_last_batch return batches @@ -498,6 +597,9 @@ async def close(self): self._state_changed.set() self._stream.close() + for session in self._partition_sessions.values(): + session.close() + for task in self._background_tasks: task.cancel() diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index f761a315..c73be69f 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1,16 +1,21 @@ import asyncio import datetime import typing +from collections import deque +from dataclasses import dataclass +from typing import List, Optional from unittest import mock import pytest from ydb import issues +from . 
import datatypes, topic_reader_asyncio from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings -from .topic_reader_asyncio import ReaderStream, PartitionSession, ReaderReconnector +from .topic_reader_asyncio import ReaderStream, ReaderReconnector from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec, OffsetsRange +from .._topic_common import test_helpers from .._topic_common.test_helpers import StreamMock, wait_condition, wait_for_fast # Workaround for good IDE and universal for runtime @@ -47,38 +52,67 @@ def write(self, message: StreamReadMessage.FromClient): @pytest.mark.asyncio class TestReaderStream: default_batch_size = 1 + partition_session_id = 2 + partition_session_committed_offset = 10 + second_partition_session_id = 12 + second_partition_session_offset = 50 + default_reader_reconnector_id = 4 @pytest.fixture() def stream(self): return StreamMock() @pytest.fixture() - def partition_session(self, default_reader_settings): - return PartitionSession( + def partition_session( + self, default_reader_settings, stream_reader_started: ReaderStream + ) -> datatypes.PartitionSession: + partition_session = datatypes.PartitionSession( id=2, topic_path=default_reader_settings.topic, partition_id=4, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, + committed_offset=self.partition_session_committed_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader_started._id, ) + assert partition_session.id not in stream_reader_started._partition_sessions + stream_reader_started._partition_sessions[ + partition_session.id + ] = partition_session + + return stream_reader_started._partition_sessions[partition_session.id] + @pytest.fixture() - def second_partition_session(self, default_reader_settings): - return PartitionSession( + def second_partition_session( + self, default_reader_settings, stream_reader_started: ReaderStream + ): + partition_session = datatypes.PartitionSession( id=12, topic_path=default_reader_settings.topic, partition_id=10, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, + committed_offset=self.second_partition_session_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader_started._id, ) + assert partition_session.id not in stream_reader_started._partition_sessions + stream_reader_started._partition_sessions[ + partition_session.id + ] = partition_session + + return stream_reader_started._partition_sessions[partition_session.id] + @pytest.fixture() async def stream_reader_started( self, stream, default_reader_settings, - partition_session, - second_partition_session, ) -> ReaderStream: - reader = ReaderStream(default_reader_settings) + reader = ReaderStream( + self.default_reader_reconnector_id, default_reader_settings + ) init_message = object() # noinspection PyTypeChecker @@ -99,54 +133,8 @@ async def stream_reader_started( read_request = await wait_for_fast(stream.from_client.get()) assert isinstance(read_request.client_message, StreamReadMessage.ReadRequest) - stream.from_server.put_nowait( - StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession( - partition_session_id=partition_session.id, - 
path=partition_session.topic_path, - partition_id=partition_session.partition_id, - ), - committed_offset=0, - partition_offsets=OffsetsRange( - start=0, - end=0, - ), - ), - ) - ) await start - start_partition_resp = await wait_for_fast(stream.from_client.get()) - assert isinstance( - start_partition_resp.client_message, - StreamReadMessage.StartPartitionSessionResponse, - ) - - stream.from_server.put_nowait( - StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession( - partition_session_id=second_partition_session.id, - path=second_partition_session.topic_path, - partition_id=second_partition_session.partition_id, - ), - committed_offset=0, - partition_offsets=OffsetsRange( - start=0, - end=0, - ), - ), - ) - ) - start_partition_resp = await wait_for_fast(stream.from_client.get()) - assert isinstance( - start_partition_resp.client_message, - StreamReadMessage.StartPartitionSessionResponse, - ) - await asyncio.sleep(0) with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() @@ -170,17 +158,26 @@ async def stream_reader_finish_with_error( await stream_reader_started.close() @staticmethod - def create_message(partition_session: PartitionSession, seqno: int): + def create_message( + partition_session: datatypes.PartitionSession, seqno: int, offset_delta: int + ): return PublicMessage( seqno=seqno, created_at=datetime.datetime(2023, 2, 3, 14, 15), message_group_id="test-message-group", session_metadata={}, - offset=seqno + 1, + offset=partition_session._next_message_start_commit_offset + + offset_delta + - 1, written_at=datetime.datetime(2023, 2, 3, 14, 16), producer_id="test-producer-id", data=bytes(), _partition_session=partition_session, + _commit_start_offset=partition_session._next_message_start_commit_offset + + offset_delta + - 1, + _commit_end_offset=partition_session._next_message_start_commit_offset + + offset_delta, ) async def send_message(self, stream_reader, message: PublicMessage): @@ -236,6 +233,231 @@ class TestError(Exception): with pytest.raises(TestError): stream_reader_finish_with_error.receive_batch_nowait() + @pytest.mark.parametrize( + "pending_ranges,commit,send_range,rest_ranges", + [ + ( + [], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + [], + ), + ( + [], + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + None, + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ) + ], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 5, + partition_session_committed_offset + 10, + ) + ], + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + None, + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + OffsetsRange( + partition_session_committed_offset + 5, + partition_session_committed_offset + 10, + ), + ], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ) + ], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + 
partition_session_committed_offset + 2, + ), + [], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + OffsetsRange( + partition_session_committed_offset + 2, + partition_session_committed_offset + 3, + ), + ], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 3, + ), + [], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + OffsetsRange( + partition_session_committed_offset + 2, + partition_session_committed_offset + 3, + ), + OffsetsRange( + partition_session_committed_offset + 4, + partition_session_committed_offset + 5, + ), + ], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 3, + ), + [ + OffsetsRange( + partition_session_committed_offset + 4, + partition_session_committed_offset + 5, + ) + ], + ), + ], + ) + async def test_send_commit_messages( + self, + stream, + stream_reader: ReaderStream, + partition_session, + pending_ranges: List[OffsetsRange], + commit: OffsetsRange, + send_range: Optional[OffsetsRange], + rest_ranges: List[OffsetsRange], + ): + @dataclass + class Commitable(datatypes.ICommittable): + start: int + end: int + + def _commit_get_partition_session(self) -> datatypes.PartitionSession: + return partition_session + + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange(self.start, self.end) + + partition_session._pending_commits = deque(pending_ranges) + + stream_reader.commit(Commitable(commit.start, commit.end)) + + async def wait_message(): + return await wait_for_fast(stream.from_client.get(), timeout=0) + + if send_range: + msg = await wait_message() # type: StreamReadMessage.FromClient + assert msg.client_message == StreamReadMessage.CommitOffsetRequest( + commit_offsets=[ + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=partition_session.id, + offsets=[send_range], + ) + ] + ) + else: + with pytest.raises(test_helpers.WaitConditionException): + await wait_message() + + assert partition_session._pending_commits == deque(rest_ranges) + + async def test_commit_ack_received( + self, stream_reader, stream, partition_session, second_partition_session + ): + offset1 = self.partition_session_committed_offset + 1 + waiter1 = partition_session._add_waiter(offset1) + + offset2 = self.second_partition_session_offset + 2 + waiter2 = second_partition_session._add_waiter(offset2) + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.CommitOffsetResponse( + partitions_committed_offsets=[ + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=partition_session.id, + committed_offset=offset1, + ), + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=second_partition_session.id, + committed_offset=offset2, + ), + ] + ), + ) + ) + + await wait_for_fast(waiter1.future) + await wait_for_fast(waiter2.future) + + async def test_close_ack_waiters_when_close_stream_reader( + self, stream_reader_started: ReaderStream, partition_session + ): + waiter = partition_session._add_waiter( + 
self.partition_session_committed_offset + 1 + ) + await wait_for_fast(stream_reader_started.close()) + + with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): + waiter.future.result() + + async def test_commit_ranges_for_received_messages( + self, stream, stream_reader_started: ReaderStream, partition_session + ): + m1 = self.create_message(partition_session, 1, 1) + m2 = self.create_message(partition_session, 2, 10) + m2._commit_start_offset = m1.offset + 1 + + await self.send_message(stream_reader_started, m1) + await self.send_message(stream_reader_started, m2) + + await stream_reader_started.wait_messages() + received = stream_reader_started.receive_batch_nowait().messages + assert received == [m1] + + await stream_reader_started.wait_messages() + received = stream_reader_started.receive_batch_nowait().messages + assert received == [m2] + async def test_error_from_status_code( self, stream, stream_reader_finish_with_error ): @@ -257,7 +479,9 @@ async def test_error_from_status_code( stream_reader_finish_with_error.receive_batch_nowait() async def test_init_reader(self, stream, default_reader_settings): - reader = ReaderStream(default_reader_settings) + reader = ReaderStream( + self.default_reader_reconnector_id, default_reader_settings + ) init_message = StreamReadMessage.InitRequest( consumer="test-consumer", topics_read_settings=[ @@ -309,6 +533,7 @@ def session_count(): test_partition_id = partition_session.partition_id + 1 test_partition_session_id = partition_session.id + 1 test_topic_path = default_reader_settings.topic + "-asd" + test_partition_committed_offset = 18 stream.from_server.put_nowait( StreamReadMessage.FromServer( @@ -319,7 +544,7 @@ def session_count(): path=test_topic_path, partition_id=test_partition_id, ), - committed_offset=0, + committed_offset=test_partition_committed_offset, partition_offsets=OffsetsRange( start=0, end=0, @@ -331,19 +556,22 @@ def session_count(): assert response == StreamReadMessage.FromClient( client_message=StreamReadMessage.StartPartitionSessionResponse( partition_session_id=test_partition_session_id, - read_offset=0, - commit_offset=0, + read_offset=None, + commit_offset=None, ) ) assert len(stream_reader._partition_sessions) == initial_session_count + 1 assert stream_reader._partition_sessions[ test_partition_session_id - ] == PartitionSession( + ] == datatypes.PartitionSession( id=test_partition_session_id, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, topic_path=test_topic_path, partition_id=test_partition_id, + committed_offset=test_partition_committed_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader._id, ) async def test_partition_stop_force(self, stream, stream_reader, partition_session): @@ -414,7 +642,11 @@ def session_count(): stream.from_client.get_nowait() async def test_receive_message_from_server( - self, stream_reader, stream, partition_session, second_partition_session + self, + stream_reader, + stream, + partition_session: datatypes.PartitionSession, + second_partition_session, ): def reader_batch_count(): return len(stream_reader._message_batches) @@ -430,6 +662,8 @@ def reader_batch_count(): session_meta = {"a": "b"} message_group_id = "test-message-group-id" + expected_message_offset = partition_session.committed_offset + stream.from_server.put_nowait( StreamReadMessage.FromServer( server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), @@ -442,7 +676,7 @@ def reader_batch_count(): 
StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=1, + offset=expected_message_offset, seq_no=2, created_at=created_at, data=data, @@ -475,11 +709,13 @@ def reader_batch_count(): created_at=created_at, message_group_id=message_group_id, session_metadata=session_meta, - offset=1, + offset=expected_message_offset, written_at=written_at, producer_id=producer_id, data=data, _partition_session=partition_session, + _commit_start_offset=expected_message_offset, + _commit_end_offset=expected_message_offset + 1, ) ], _partition_session=partition_session, @@ -505,6 +741,11 @@ async def test_read_batches( message_group_id = "test-message-group-id" message_group_id2 = "test-message-group-id-2" + partition1_mess1_expected_offset = partition_session.committed_offset + partition2_mess1_expected_offset = second_partition_session.committed_offset + partition2_mess2_expected_offset = second_partition_session.committed_offset + 1 + partition2_mess3_expected_offset = second_partition_session.committed_offset + 2 + batches = stream_reader._read_response_to_batches( StreamReadMessage.ReadResponse( bytes_size=3, @@ -515,7 +756,7 @@ async def test_read_batches( StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=2, + offset=partition1_mess1_expected_offset, seq_no=3, created_at=created_at, data=data, @@ -536,7 +777,7 @@ async def test_read_batches( StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=1, + offset=partition2_mess1_expected_offset, seq_no=2, created_at=created_at2, data=data, @@ -552,7 +793,7 @@ async def test_read_batches( StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=2, + offset=partition2_mess2_expected_offset, seq_no=3, created_at=created_at3, data=data2, @@ -560,7 +801,7 @@ async def test_read_batches( message_group_id=message_group_id, ), StreamReadMessage.ReadResponse.MessageData( - offset=4, + offset=partition2_mess3_expected_offset, seq_no=5, created_at=created_at4, data=data, @@ -591,11 +832,13 @@ async def test_read_batches( created_at=created_at, message_group_id=message_group_id, session_metadata=session_meta, - offset=2, + offset=partition1_mess1_expected_offset, written_at=written_at, producer_id=producer_id, data=data, _partition_session=partition_session, + _commit_start_offset=partition1_mess1_expected_offset, + _commit_end_offset=partition1_mess1_expected_offset + 1, ) ], _partition_session=partition_session, @@ -609,11 +852,13 @@ async def test_read_batches( created_at=created_at2, message_group_id=message_group_id, session_metadata=session_meta, - offset=1, + offset=partition2_mess1_expected_offset, written_at=written_at2, producer_id=producer_id, data=data, _partition_session=second_partition_session, + _commit_start_offset=partition2_mess1_expected_offset, + _commit_end_offset=partition2_mess1_expected_offset + 1, ) ], _partition_session=second_partition_session, @@ -627,22 +872,26 @@ async def test_read_batches( created_at=created_at3, message_group_id=message_group_id, session_metadata=session_meta2, - offset=2, + offset=partition2_mess2_expected_offset, written_at=written_at2, producer_id=producer_id2, data=data2, _partition_session=second_partition_session, + _commit_start_offset=partition2_mess2_expected_offset, + _commit_end_offset=partition2_mess2_expected_offset + 1, ), PublicMessage( seqno=5, created_at=created_at4, 
message_group_id=message_group_id2,
                            session_metadata=session_meta2,
-                            offset=4,
+                            offset=partition2_mess3_expected_offset,
                            written_at=written_at2,
                            producer_id=producer_id,
                            data=data,
                            _partition_session=second_partition_session,
+                            _commit_start_offset=partition2_mess3_expected_offset,
+                            _commit_end_offset=partition2_mess3_expected_offset
+                            + 1,
                        ),
                    ],
                    _partition_session=second_partition_session,
@@ -652,17 +901,17 @@ async def test_read_batches(
     async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
         assert stream_reader.receive_batch_nowait() is None

-        mess1 = self.create_message(partition_session, 1)
+        mess1 = self.create_message(partition_session, 1, 1)
         await self.send_message(stream_reader, mess1)

-        mess2 = self.create_message(partition_session, 2)
+        mess2 = self.create_message(partition_session, 2, 1)
         await self.send_message(stream_reader, mess2)

         initial_buffer_size = stream_reader._buffer_size_bytes

         received = stream_reader.receive_batch_nowait()
         assert received == PublicBatch(
-            mess1.session_metadata,
+            session_metadata=mess1.session_metadata,
             messages=[mess1],
             _partition_session=mess1._partition_session,
             _bytes_size=self.default_batch_size,
@@ -721,6 +970,7 @@ async def wait_messages():
         stream_index = 0

         async def stream_create(
+            reader_reconnector_id: int,
             driver: SupportedDriverType,
             settings: PublicReaderSettings,
         ):
@@ -735,7 +985,7 @@ async def stream_create(
         with mock.patch.object(ReaderStream, "create", stream_create):
             reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", ""))
-            await reconnector.wait_message()
+            await wait_for_fast(reconnector.wait_message())

         reader_stream_mock_with_error.wait_error.assert_any_await()
         reader_stream_mock_with_error.wait_messages.assert_any_await()
diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py
index b30b547a..9652cb84 100644
--- a/ydb/_topic_reader/topic_reader_sync.py
+++ b/ydb/_topic_reader/topic_reader_sync.py
@@ -46,13 +46,19 @@ async def create_reader():
     def __del__(self):
         self.close()

-    def _call(self, coro):
+    def _call(self, coro) -> concurrent.futures.Future:
+        """
+        Call async function and return a future for waiting on the result
+        """
        if self._closed:
            raise TopicReaderClosedError()

        return asyncio.run_coroutine_threadsafe(coro, self._loop)

    def _call_sync(self, coro: Coroutine, timeout):
+        """
+        Call async function, wait and return result
+        """
        f = self._call(coro)
        try:
            return f.result(timeout)
@@ -162,15 +168,13 @@ def commit_with_ack(
         if receive in timeout seconds (default - infinite): raise TimeoutError()
         """
-        raise NotImplementedError()
+        return self._call_sync(self._async_reader.commit_with_ack(mess), None)

-    def async_commit_with_ack(
-        self, mess: ICommittable
-    ) -> Union[CommitResult, List[CommitResult]]:
+    def async_commit_with_ack(self, mess: ICommittable) -> concurrent.futures.Future:
        """
        write a commit message to the buffer and return a Future for waiting on the result.
""" - raise NotImplementedError() + return self._call(self._async_reader.commit_with_ack(mess), None) def async_flush(self) -> concurrent.futures.Future: """ diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 4724ab2f..b46a13b8 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -151,6 +151,7 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: _closed: bool + _loop: asyncio.AbstractEventLoop _credentials: Union[ydb.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int @@ -169,10 +170,12 @@ class WriterAsyncIOReconnector: def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._closed = False + self._loop = asyncio.get_running_loop() self._driver = driver self._credentials = driver._credentials self._init_message = settings.create_init_request() - self._init_info = asyncio.Future() + self._new_messages = asyncio.Queue() + self._init_info = self._loop.create_future() self._stream_connected = asyncio.Event() self._settings = settings @@ -180,7 +183,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._messages = deque() self._messages_future = deque() self._new_messages = asyncio.Queue() - self._stop_reason = asyncio.Future() + self._stop_reason = self._loop.create_future() self._background_tasks = [ asyncio.create_task(self._connection_loop(), name="connection_loop") ] @@ -233,7 +236,7 @@ async def write_with_ack_future( await self.wait_init() internal_messages = self._prepare_internal_messages(messages) - messages_future = [asyncio.Future() for _ in internal_messages] + messages_future = [self._loop.create_future() for _ in internal_messages] self._messages.extend(internal_messages) self._messages_future.extend(messages_future) diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 544b154c..0b72a198 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import threading import codecs from concurrent import futures import functools @@ -157,3 +158,17 @@ def next(self): def __next__(self): return self._next() + + +class AtomicCounter: + _lock: threading.Lock + _value: int + + def __init__(self, initial_value: int = 0): + self._lock = threading.Lock() + self._value = initial_value + + def inc_and_get(self) -> int: + with self._lock: + self._value += 1 + return self._value From af833fff9a8043f4ae9928cbb4bf604d10d65e7e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 1 Mar 2023 18:13:00 +0300 Subject: [PATCH 074/147] remove logic for reorder commits --- ydb/_topic_common/test_helpers.py | 4 +- ydb/_topic_reader/datatypes.py | 76 ++---- ydb/_topic_reader/datatypes_test.py | 254 ++++++------------ ydb/_topic_reader/topic_reader_asyncio.py | 8 +- .../topic_reader_asyncio_test.py | 146 ++-------- 5 files changed, 125 insertions(+), 363 deletions(-) diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index 60166d0d..d70cd9f1 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -39,7 +39,7 @@ def close(self): self.from_server.put_nowait(None) -class WaitConditionException(Exception): +class WaitConditionError(Exception): pass @@ -62,7 +62,7 @@ async def wait_condition( return await asyncio.sleep(0) - raise WaitConditionException("Bad condition in test") + raise WaitConditionError("Bad condition in test") async def wait_for_fast( diff --git a/ydb/_topic_reader/datatypes.py 
b/ydb/_topic_reader/datatypes.py index 06b8d690..6ca7681c 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -68,12 +68,6 @@ class PartitionSession: reader_reconnector_id: int reader_stream_id: int _next_message_start_commit_offset: int = field(init=False) - _send_commit_window_start: int = field(init=False) - - # todo: check if deque is optimal - _pending_commits: Deque[OffsetsRange] = field( - init=False, default_factory=lambda: deque() - ) # todo: check if deque is optimal _ack_waiters: Deque["PartitionSession.CommitAckWaiter"] = field( @@ -89,45 +83,17 @@ class PartitionSession: def __post_init__(self): self._next_message_start_commit_offset = self.committed_offset - self._send_commit_window_start = self.committed_offset try: self._loop = asyncio.get_running_loop() except RuntimeError: self._loop = None - def add_commit( - self, new_commit: OffsetsRange - ) -> "PartitionSession.CommitAckWaiter": - self._ensure_not_closed() - - self._add_to_commits(new_commit) - return self._add_waiter(new_commit.end) - - def _add_to_commits(self, new_commit: OffsetsRange): - index = bisect.bisect_left(self._pending_commits, new_commit) - - prev_commit = self._pending_commits[index - 1] if index > 0 else None - commit = ( - self._pending_commits[index] if index < len(self._pending_commits) else None - ) - - for c in (prev_commit, commit): - if c is not None and new_commit.is_intersected_with(c): - raise ValueError( - "new commit intersected with existed. New range: %s, existed: %s" - % (new_commit, c) - ) - - if commit is not None and commit.start == new_commit.end: - commit.start = new_commit.start - elif prev_commit is not None and prev_commit.end == new_commit.start: - prev_commit.end = new_commit.end - else: - self._pending_commits.insert(index, new_commit) - - def _add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": + def add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": waiter = PartitionSession.CommitAckWaiter(end_offset, self._create_future()) + if end_offset <= self.committed_offset: + waiter._finish_ok() + return waiter # fast way if len(self._ack_waiters) > 0 and self._ack_waiters[-1].end_offset < end_offset: @@ -143,26 +109,6 @@ def _create_future(self) -> asyncio.Future: else: return asyncio.Future() - def pop_commit_range(self) -> Optional[OffsetsRange]: - self._ensure_not_closed() - - if len(self._pending_commits) == 0: - return None - - if self._pending_commits[0].start != self._send_commit_window_start: - return None - - res = self._pending_commits.popleft() - while ( - len(self._pending_commits) > 0 and self._pending_commits[0].start == res.end - ): - commit = self._pending_commits.popleft() - res.end = commit.end - - self._send_commit_window_start = res.end - - return res - def ack_notify(self, offset: int): self._ensure_not_closed() @@ -176,7 +122,7 @@ def ack_notify(self, offset: int): while len(self._ack_waiters) > 0: if self._ack_waiters[0].end_offset <= offset: waiter = self._ack_waiters.popleft() - waiter.future.set_result(None) + waiter._finish_ok() else: break @@ -189,7 +135,7 @@ def close(self): self.state = PartitionSession.State.Stopped exception = topic_reader_asyncio.TopicReaderCommitToExpiredPartition() for waiter in self._ack_waiters: - waiter.future.set_exception(exception) + waiter._finish_error(exception) def _ensure_not_closed(self): if self.state == PartitionSession.State.Stopped: @@ -204,6 +150,16 @@ class State(enum.Enum): class CommitAckWaiter: end_offset: int future: asyncio.Future = 
field(compare=False) + _done: bool = field(default=False, init=False) + _exception: Optional[Exception] = field(default=None, init=False) + + def _finish_ok(self): + self._done = True + self.future.set_result(None) + + def _finish_error(self, error: Exception): + self._exception = error + self.future.set_exception(error) @dataclass diff --git a/ydb/_topic_reader/datatypes_test.py b/ydb/_topic_reader/datatypes_test.py index 6ead9a88..2ec1229f 100644 --- a/ydb/_topic_reader/datatypes_test.py +++ b/ydb/_topic_reader/datatypes_test.py @@ -1,13 +1,11 @@ import asyncio -import bisect import copy import functools from collections import deque -from typing import List, Optional, Type, Union +from typing import List import pytest -from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange from ydb._topic_common.test_helpers import wait_condition from ydb._topic_reader import topic_reader_asyncio from ydb._topic_reader.datatypes import PartitionSession @@ -73,155 +71,127 @@ def add_notify(future, notified_offset): assert notified == offsets_notified assert session.committed_offset == notify_offset - def test_add_commit(self, session): - commit = OffsetsRange( - self.session_comitted_offset, self.session_comitted_offset + 5 - ) - waiter = session.add_commit(commit) - assert waiter.end_offset == commit.end - + # noinspection PyTypeChecker @pytest.mark.parametrize( - "original,add,result", + "original,add,is_done,result", [ ( [], - OffsetsRange(1, 10), - [OffsetsRange(1, 10)], - ), - ( - [OffsetsRange(1, 10)], - OffsetsRange(15, 20), - [OffsetsRange(1, 10), OffsetsRange(15, 20)], - ), - ( - [OffsetsRange(15, 20)], - OffsetsRange(1, 10), - [OffsetsRange(1, 10), OffsetsRange(15, 20)], - ), - ( - [OffsetsRange(1, 10)], - OffsetsRange(10, 20), - [OffsetsRange(1, 20)], - ), - ( - [OffsetsRange(10, 20)], - OffsetsRange(1, 10), - [OffsetsRange(1, 20)], + session_comitted_offset - 5, + True, + [], ), ( - [OffsetsRange(1, 2), OffsetsRange(3, 4)], - OffsetsRange(2, 3), - [OffsetsRange(1, 2), OffsetsRange(2, 4)], + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 0, + True, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + ], ), - ( - [OffsetsRange(1, 10)], - OffsetsRange(5, 6), - ValueError, - ), - ], - ) - def test_add_to_commits( - self, - session, - original: List[OffsetsRange], - add: OffsetsRange, - result: Union[List[OffsetsRange], Type[Exception]], - ): - session._pending_commits = copy.deepcopy(original) - if isinstance(result, type) and issubclass(result, Exception): - with pytest.raises(result): - session._add_to_commits(add) - else: - session._add_to_commits(add) - assert session._pending_commits == result - - # noinspection PyTypeChecker - @pytest.mark.parametrize( - "original,add,result", - [ ( [], - 5, - [PartitionSession.CommitAckWaiter(5, None)], - ), - ( - [PartitionSession.CommitAckWaiter(5, None)], - 6, + session_comitted_offset + 5, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(6, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), ], ), ( - [PartitionSession.CommitAckWaiter(5, None)], - 4, + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 6, + False, [ - PartitionSession.CommitAckWaiter(4, None), - PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), ], ), ( - 
[PartitionSession.CommitAckWaiter(5, None)], - 0, + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 4, + False, [ - PartitionSession.CommitAckWaiter(0, None), - PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 4, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), ], ), ( - [PartitionSession.CommitAckWaiter(5, None)], - 100, + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 100, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], - 50, + session_comitted_offset + 50, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(50, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 50, None + ), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(7, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 7, None), ], - 6, + session_comitted_offset + 6, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(6, None), - PartitionSession.CommitAckWaiter(7, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 7, None), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], - 6, + session_comitted_offset + 6, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(6, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], - 99, + session_comitted_offset + 99, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(99, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 99, None + ), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ], @@ -231,17 +201,16 @@ def test_add_waiter( session, original: 
List[PartitionSession.CommitAckWaiter], add: int, + is_done: bool, result: List[PartitionSession.CommitAckWaiter], ): session._ack_waiters = copy.deepcopy(original) - res = session._add_waiter(add) + res = session.add_waiter(add) assert result == session._ack_waiters - - index = bisect.bisect_left(session._ack_waiters, res) - assert res is session._ack_waiters[index] + assert res.future.done() == is_done def test_close_notify_waiters(self, session): - waiter = session._add_waiter(session.committed_offset + 1) + waiter = session.add_waiter(session.committed_offset + 1) session.close() with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): @@ -250,66 +219,3 @@ def test_close_notify_waiters(self, session): def test_close_twice(self, session): session.close() session.close() - - @pytest.mark.parametrize( - "commits,result,rest", - [ - ([], None, []), - ( - [OffsetsRange(session_comitted_offset + 1, 20)], - None, - [OffsetsRange(session_comitted_offset + 1, 20)], - ), - ( - [OffsetsRange(session_comitted_offset, session_comitted_offset + 1)], - OffsetsRange(session_comitted_offset, session_comitted_offset + 1), - [], - ), - ( - [ - OffsetsRange(session_comitted_offset, session_comitted_offset + 1), - OffsetsRange( - session_comitted_offset + 1, session_comitted_offset + 2 - ), - ], - OffsetsRange(session_comitted_offset, session_comitted_offset + 2), - [], - ), - ( - [ - OffsetsRange(session_comitted_offset, session_comitted_offset + 1), - OffsetsRange( - session_comitted_offset + 1, session_comitted_offset + 2 - ), - OffsetsRange( - session_comitted_offset + 10, session_comitted_offset + 20 - ), - ], - OffsetsRange(session_comitted_offset, session_comitted_offset + 2), - [ - OffsetsRange( - session_comitted_offset + 10, session_comitted_offset + 20 - ) - ], - ), - ], - ) - def test_get_commit_range( - self, - session, - commits: List[OffsetsRange], - result: Optional[OffsetsRange], - rest: List[OffsetsRange], - ): - send_commit_window_start = session._send_commit_window_start - - session._pending_commits = deque(commits) - res = session.pop_commit_range() - assert res == result - assert session._pending_commits == deque(rest) - - if res is None: - assert session._send_commit_window_start == send_commit_window_start - else: - assert session._send_commit_window_start != send_commit_window_start - assert session._send_commit_window_start == res.end diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index cc0839f7..835fc786 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -377,15 +377,15 @@ def commit( "commit messages after server stop the partition read session" ) - waiter = partition_session.add_commit(batch._commit_get_offsets_range()) + commit_range = batch._commit_get_offsets_range() + waiter = partition_session.add_waiter(commit_range.end) - send_range = partition_session.pop_commit_range() - if send_range: + if not waiter.future.done(): client_message = StreamReadMessage.CommitOffsetRequest( commit_offsets=[ StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( partition_session_id=partition_session.id, - offsets=[send_range], + offsets=[commit_range], ) ] ) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index c73be69f..e4609ea0 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1,9 +1,7 @@ import asyncio import datetime import 
typing -from collections import deque from dataclasses import dataclass -from typing import List, Optional from unittest import mock import pytest @@ -15,8 +13,12 @@ from .topic_reader_asyncio import ReaderStream, ReaderReconnector from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec, OffsetsRange -from .._topic_common import test_helpers -from .._topic_common.test_helpers import StreamMock, wait_condition, wait_for_fast +from .._topic_common.test_helpers import ( + StreamMock, + wait_condition, + wait_for_fast, + WaitConditionError, +) # Workaround for good IDE and universal for runtime if typing.TYPE_CHECKING: @@ -234,124 +236,21 @@ class TestError(Exception): stream_reader_finish_with_error.receive_batch_nowait() @pytest.mark.parametrize( - "pending_ranges,commit,send_range,rest_ranges", + "commit,send_range", [ ( - [], OffsetsRange( partition_session_committed_offset, partition_session_committed_offset + 1, ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - [], - ), - ( - [], - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - None, - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ) - ], - ), - ( - [ - OffsetsRange( - partition_session_committed_offset + 5, - partition_session_committed_offset + 10, - ) - ], - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - None, - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - OffsetsRange( - partition_session_committed_offset + 5, - partition_session_committed_offset + 10, - ), - ], - ), - ( - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ) - ], - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 2, - ), - [], + True, ), ( - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - OffsetsRange( - partition_session_committed_offset + 2, - partition_session_committed_offset + 3, - ), - ], OffsetsRange( + partition_session_committed_offset - 1, partition_session_committed_offset, - partition_session_committed_offset + 1, ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 3, - ), - [], - ), - ( - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - OffsetsRange( - partition_session_committed_offset + 2, - partition_session_committed_offset + 3, - ), - OffsetsRange( - partition_session_committed_offset + 4, - partition_session_committed_offset + 5, - ), - ], - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 3, - ), - [ - OffsetsRange( - partition_session_committed_offset + 4, - partition_session_committed_offset + 5, - ) - ], + False, ), ], ) @@ -360,10 +259,8 @@ async def test_send_commit_messages( stream, stream_reader: ReaderStream, partition_session, - pending_ranges: List[OffsetsRange], commit: OffsetsRange, - send_range: Optional[OffsetsRange], - rest_ranges: List[OffsetsRange], + 
send_range: bool, ): @dataclass class Commitable(datatypes.ICommittable): @@ -376,9 +273,9 @@ def _commit_get_partition_session(self) -> datatypes.PartitionSession: def _commit_get_offsets_range(self) -> OffsetsRange: return OffsetsRange(self.start, self.end) - partition_session._pending_commits = deque(pending_ranges) + start_ack_waiters = partition_session._ack_waiters.copy() - stream_reader.commit(Commitable(commit.start, commit.end)) + waiter = stream_reader.commit(Commitable(commit.start, commit.end)) async def wait_message(): return await wait_for_fast(stream.from_client.get(), timeout=0) @@ -389,24 +286,27 @@ async def wait_message(): commit_offsets=[ StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( partition_session_id=partition_session.id, - offsets=[send_range], + offsets=[commit], ) ] ) + assert partition_session._ack_waiters[-1].end_offset == commit.end else: - with pytest.raises(test_helpers.WaitConditionException): - await wait_message() + assert waiter.future.done() - assert partition_session._pending_commits == deque(rest_ranges) + with pytest.raises(WaitConditionError): + msg = await wait_message() + pass + assert start_ack_waiters == partition_session._ack_waiters async def test_commit_ack_received( self, stream_reader, stream, partition_session, second_partition_session ): offset1 = self.partition_session_committed_offset + 1 - waiter1 = partition_session._add_waiter(offset1) + waiter1 = partition_session.add_waiter(offset1) offset2 = self.second_partition_session_offset + 2 - waiter2 = second_partition_session._add_waiter(offset2) + waiter2 = second_partition_session.add_waiter(offset2) stream.from_server.put_nowait( StreamReadMessage.FromServer( @@ -432,7 +332,7 @@ async def test_commit_ack_received( async def test_close_ack_waiters_when_close_stream_reader( self, stream_reader_started: ReaderStream, partition_session ): - waiter = partition_session._add_waiter( + waiter = partition_session.add_waiter( self.partition_session_committed_offset + 1 ) await wait_for_fast(stream_reader_started.close()) From d13d85d318099bc781302e06624d0de4073d969c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 1 Mar 2023 18:51:16 +0300 Subject: [PATCH 075/147] fix style --- .github/workflows/style.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index c280042b..8723d8f2 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -2,7 +2,6 @@ name: Style checks on: push: - - main pull_request: jobs: From 9e0e44d8d59114f91a36c12e7617bad747fea034 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 17:35:01 +0300 Subject: [PATCH 076/147] fix typos --- tests/topics/test_topic_reader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 734b64c7..a874c743 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -15,11 +15,11 @@ async def test_read_and_commit_message( self, driver, topic_path, topic_with_messages, topic_consumer ): - reader = driver.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_consumer, topic_path) batch = await reader.receive_batch() await reader.commit_with_ack(batch) - reader = driver.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_consumer, topic_path) batch2 = await reader.receive_batch() assert batch.messages[0] != 
batch2.messages[0] @@ -36,10 +36,10 @@ def test_read_message( def test_read_and_commit_message( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): - reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_consumer, topic_path) batch = reader.receive_batch() reader.commit_with_ack(batch) - reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_consumer, topic_path) batch2 = reader.receive_batch() assert batch.messages[0] != batch2.messages[0] From 9b746fc4b1a08ea51e51103383216d5ed1ea1504 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 18:21:34 +0300 Subject: [PATCH 077/147] typo while check reconnector_id and style --- ydb/_topic_common/test_helpers.py | 4 +++- ydb/_topic_reader/topic_reader_asyncio.py | 12 +++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index d70cd9f1..96a812ab 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -54,9 +54,11 @@ async def wait_condition( if timeout is None: timeout = 1 + minimal_loop_count_for_wait = 1000 + start = time.monotonic() counter = 0 - while (time.monotonic() - start < timeout) or counter < 1000: + while (time.monotonic() - start < timeout) or counter < minimal_loop_count_for_wait: counter += 1 if f(): return diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 835fc786..303f4c91 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -361,10 +361,7 @@ def commit( ) -> datatypes.PartitionSession.CommitAckWaiter: partition_session = batch._commit_get_partition_session() - if ( - partition_session.reader_reconnector_id - != partition_session.reader_reconnector_id - ): + if partition_session.reader_reconnector_id != self._reader_reconnector_id: raise TopicReaderError("reader can commit only self-produced messages") if partition_session.reader_stream_id != self._id: @@ -498,11 +495,8 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse): def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): for partition_offset in message.partitions_committed_offsets: - try: - session = self._partition_sessions[ - partition_offset.partition_session_id - ] - except KeyError: + session = self._partition_sessions.get(partition_offset.partition_session_id) + if session is None: continue session.ack_notify(partition_offset.committed_offset) From 37e117c2acec5e5b2cb60eb696f24c297042897e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 19:10:47 +0300 Subject: [PATCH 078/147] style --- ydb/_topic_reader/topic_reader_asyncio.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 303f4c91..44125e54 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -495,7 +495,9 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse): def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): for partition_offset in message.partitions_committed_offsets: - session = self._partition_sessions.get(partition_offset.partition_session_id) + session = self._partition_sessions.get( + partition_offset.partition_session_id + ) if session is 
None: continue session.ack_notify(partition_offset.committed_offset) From 41b33569991232dfe8f29459fa8c2d32575f94c4 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Mar 2023 16:00:27 +0300 Subject: [PATCH 079/147] style --- ydb/_topic_reader/topic_reader_asyncio.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 44125e54..fa940136 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -470,8 +470,9 @@ def _on_start_partition_session( def _on_partition_session_stop( self, message: StreamReadMessage.StopPartitionSessionRequest ): - partition = self._partition_sessions.get(message.partition_session_id) - if partition is None: + try: + partition = self._partition_sessions.get(message.partition_session_id) + except KeyError: # may if receive stop partition with graceful=false after response on stop partition # with graceful=true and remove partition from internal dictionary return From 2334f3591095a407a37307213e2d4905e3cc530e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Mar 2023 16:17:08 +0300 Subject: [PATCH 080/147] typo --- ydb/_topic_reader/topic_reader_asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index fa940136..ab0981f6 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -471,7 +471,7 @@ def _on_partition_session_stop( self, message: StreamReadMessage.StopPartitionSessionRequest ): try: - partition = self._partition_sessions.get(message.partition_session_id) + partition = self._partition_sessions[message.partition_session_id] except KeyError: # may if receive stop partition with graceful=false after response on stop partition # with graceful=true and remove partition from internal dictionary From 6c2ff6a762d4b316dcfde3c190ed69a43d5eb80d Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Sun, 5 Mar 2023 21:38:41 +0100 Subject: [PATCH 081/147] change gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 55c4ea54..fd366ee5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ __pycache__ ydb.egg-info/ /.idea +/.vscode /tox /venv /ydb_certs From 19a2f4450aa99d5c8fee9075ec92694ed20ae8b7 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Sun, 5 Mar 2023 22:06:29 +0100 Subject: [PATCH 082/147] install ydb sources as editable requirements for tests fix proto requirements for ydb --- requirements.txt | 2 +- test-requirements.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index da37d9fa..039605b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ aiohttp>=3.7.4,<4.0.0 enum-compat>=0.0.1 grpcio>=1.42.0 packaging -protobuf>3.13.0,<5.0.0 +protobuf>=3.13.0,<5.0.0 six<2 \ No newline at end of file diff --git a/test-requirements.txt b/test-requirements.txt index 273f0fb6..af6c3a53 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -48,3 +48,4 @@ cython freezegun==1.2.2 grpcio-tools pytest-cov +-e . 
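Note on the reader commit changes above (see patch 074, "remove logic for reorder commits"): after dropping the range-merging code, each commit registers a single CommitAckWaiter. The waiter resolves immediately when the requested end offset is already committed (the fast path in add_waiter); otherwise ack_notify resolves pending waiters in offset order once the server confirms the commit. The sketch below is a condensed, runnable model of that behavior; SessionModel and AckWaiter are simplified standalone names for illustration, not the SDK's real classes.

import asyncio
from collections import deque
from dataclasses import dataclass, field


@dataclass
class AckWaiter:
    end_offset: int
    future: asyncio.Future = field(compare=False)


class SessionModel:
    # Condensed model of the PartitionSession waiter logic from the patches above.
    def __init__(self, committed_offset: int):
        self.committed_offset = committed_offset
        self._ack_waiters = deque()  # kept ordered by end_offset

    def add_waiter(self, end_offset: int) -> AckWaiter:
        waiter = AckWaiter(end_offset, asyncio.get_running_loop().create_future())
        if end_offset <= self.committed_offset:
            waiter.future.set_result(None)  # fast path: offset already committed
            return waiter
        self._ack_waiters.append(waiter)  # the real code inserts with bisect
        return waiter

    def ack_notify(self, offset: int):
        # Server confirmed commits up to `offset`: resolve all satisfied waiters.
        self.committed_offset = offset
        while self._ack_waiters and self._ack_waiters[0].end_offset <= offset:
            self._ack_waiters.popleft().future.set_result(None)


async def main():
    session = SessionModel(committed_offset=10)
    done = session.add_waiter(5)  # already committed, resolved at once
    pending = session.add_waiter(15)  # waits for a server ack
    session.ack_notify(20)
    await pending.future
    assert done.future.done() and pending.future.done()


asyncio.run(main())

Keeping the waiter deque ordered lets ack_notify stop at the first offset the server has not confirmed yet, which is why the commit-reordering logic could be removed.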
From e1ea1f457cc10578bd467c676d83943e56ac6a37 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Mar 2023 00:40:51 +0100 Subject: [PATCH 083/147] remove difficult variants of send message fix examples add wait container timeout - for tests on m1 --- examples/topic/writer_async_example.py | 14 +++---- examples/topic/writer_example.py | 16 ++++---- tests/conftest.py | 14 ++++--- tests/topics/test_topic_writer.py | 48 +++++++++++++++++++++++ ydb/_topic_reader/topic_reader_asyncio.py | 5 ++- ydb/_topic_writer/topic_writer.py | 14 +++++-- ydb/_topic_writer/topic_writer_asyncio.py | 24 ++++++------ 7 files changed, 98 insertions(+), 37 deletions(-) diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index 548ef1aa..2337847b 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -37,15 +37,15 @@ async def connect_without_context_manager(db: ydb.aio.Driver): async def send_messages(writer: ydb.TopicWriterAsyncIO): # simple str/bytes without additional metadata await writer.write("mess") # send text - await writer.write(bytes([1, 2, 3])) # send bytes - await writer.write("mess-1", "mess-2") # send two messages + await writer.write(bytes([1, 2, 3])) # send single message with bytes 1,2,3 + await writer.write(["mess-1", "mess-2"]) # send two messages # full forms await writer.write(ydb.TopicWriterMessage("mess")) # send text await writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes - await writer.write( + await writer.write([ ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2") - ) # send few messages by one call + ]) # send few messages by one call # with meta await writer.write( @@ -71,12 +71,12 @@ async def send_messages_with_manual_seqno(writer: ydb.TopicWriter): async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): # future wait - await writer.write_with_result( + await writer.write_with_result([ ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) - ) + ]) # send with flush - await writer.write("1", "2", "3") + await writer.write(["1", "2", "3"]) await writer.flush() diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index e95107d1..3ae9a4ee 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -44,21 +44,21 @@ def connect_without_context_manager(db: ydb.Driver): try: pass # some code finally: - await writer.close() + writer.close() def send_messages(writer: ydb.TopicWriter): # simple str/bytes without additional metadata writer.write("mess") # send text - writer.write(bytes([1, 2, 3])) # send bytes - writer.write("mess-1", "mess-2") # send two messages + writer.write(bytes([1, 2, 3])) # send single message with bytes 1,2,3 + writer.write(["mess-1", "mess-2"]) # send two messages # full forms writer.write(ydb.TopicWriterMessage("mess")) # send text writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes - writer.write( + writer.write([ ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2") - ) # send few messages by one call + ]) # send few messages by one call # with meta writer.write(ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns())) @@ -87,13 +87,13 @@ def send_messages_with_wait_ack(writer: ydb.TopicWriter): ).result() # implicit, by sync call - writer.write_with_ack( + writer.write_with_ack([ ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) - ) + ]) # write_with_ack # send with flush - 
writer.write("1", "2", "3") + writer.write(["1", "2", "3"]) writer.flush() diff --git a/tests/conftest.py b/tests/conftest.py index 62f486cb..99123660 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ def docker_compose_file(pytestconfig): def wait_container_ready(driver): - driver.wait(timeout=10) + driver.wait(timeout=30) with ydb.SessionPool(driver) as pool: @@ -133,12 +133,16 @@ async def topic_path(driver, topic_consumer, database) -> str: async def topic_with_messages(driver, topic_path): writer = driver.topic_client.writer(topic_path, producer_id="fixture-producer-id") await writer.write_with_ack( - ydb.TopicWriterMessage(data="123".encode()), - ydb.TopicWriterMessage(data="456".encode()), + [ + ydb.TopicWriterMessage(data="123".encode()), + ydb.TopicWriterMessage(data="456".encode()), + ] ) await writer.write_with_ack( - ydb.TopicWriterMessage(data="789".encode()), - ydb.TopicWriterMessage(data="0".encode()), + [ + ydb.TopicWriterMessage(data="789".encode()), + ydb.TopicWriterMessage(data="0".encode()), + ] ) await writer.close() diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index e7db0e23..9e8b0dfe 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -60,6 +60,30 @@ async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): init_info = await writer.wait_init() assert init_info.last_seqno == last_seqno + async def test_write_multi_message_with_ack( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with driver.topic_client.writer(topic_path) as writer: + await writer.write_with_ack( + [ + ydb.TopicWriterMessage(data="123".encode()), + ydb.TopicWriterMessage(data="456".encode()), + ] + ) + + batch = await topic_reader.receive_batch() + + assert batch.messages[0].offset == 0 + assert batch.messages[0].seqno == 1 + assert batch.messages[0].data == "123".encode() + + # remove second recieve batch when implement batching + # https://github.com/ydb-platform/ydb-python-sdk/issues/142 + batch = await topic_reader.receive_batch() + assert batch.messages[0].offset == 1 + assert batch.messages[0].seqno == 2 + assert batch.messages[0].data == "456".encode() + class TestTopicWriterSync: def test_send_message(self, driver_sync: ydb.Driver, topic_path): @@ -115,3 +139,27 @@ def test_random_producer_id( batch2 = topic_reader_sync.receive_batch() assert batch1.messages[0].producer_id != batch2.messages[0].producer_id + + def test_write_multi_message_with_ack( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader + ): + with driver_sync.topic_client.writer(topic_path) as writer: + writer.write_with_ack( + [ + ydb.TopicWriterMessage(data="123".encode()), + ydb.TopicWriterMessage(data="456".encode()), + ] + ) + + batch = topic_reader_sync.receive_batch() + + assert batch.messages[0].offset == 0 + assert batch.messages[0].seqno == 1 + assert batch.messages[0].data == "123".encode() + + # remove second recieve batch when implement batching + # https://github.com/ydb-platform/ydb-python-sdk/issues/142 + batch = topic_reader_sync.receive_batch() + assert batch.messages[0].offset == 1 + assert batch.messages[0].seqno == 2 + assert batch.messages[0].data == "456".encode() diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index ab0981f6..3e0e362e 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -115,7 +115,7 @@ async def 
receive_batch(
        *,
        max_messages: typing.Union[int, None] = None,
        max_bytes: typing.Union[int, None] = None,
-    ) -> typing.Union[topic_reader.PublicBatch, None]:
+    ) -> typing.Union[datatypes.PublicBatch, None]:
        """
        Get one messages batch from reader.
        All messages in a batch are from the same partition.
@@ -243,7 +243,8 @@ def commit(
         return self._stream_reader.commit(batch)

     async def close(self):
-        await self._stream_reader.close()
+        if self._stream_reader:
+            await self._stream_reader.close()

         for task in self._background_tasks:
             task.cancel()
diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py
index aa147558..a2c3d0d7 100644
--- a/ydb/_topic_writer/topic_writer.py
+++ b/ydb/_topic_writer/topic_writer.py
@@ -3,7 +3,7 @@ import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Union, TextIO, BinaryIO, Optional, Any, Dict
+from typing import List, Union, Optional, Any, Dict

 import typing

@@ -92,9 +92,9 @@ class PublicWriterInitInfo:
 class PublicMessage:
     seqno: Optional[int]
     created_at: Optional[datetime.datetime]
-    data: Union[str, bytes, TextIO, BinaryIO]
+    data: "PublicMessage.SimpleMessageSourceType"

-    SimpleMessageSourceType = Union[str, bytes, TextIO, BinaryIO]
+    SimpleMessageSourceType = Union[str, bytes]  # Will be extended

     def __init__(
         self,
@@ -107,6 +107,14 @@ def __init__(
         self.created_at = created_at
         self.data = data

+    @staticmethod
+    def _create_message(
+        data: Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"]
+    ) -> "PublicMessage":
+        if isinstance(data, PublicMessage):
+            return data
+        return PublicMessage(data=data)
+

 class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto):
     def __init__(self, mess: PublicMessage):
diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
index b46a13b8..67b1be69 100644
--- a/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -1,7 +1,7 @@
 import asyncio
 import datetime
 from collections import deque
-from typing import Deque, AsyncIterator, Union, List, Optional
+from typing import Deque, AsyncIterator, Union, List

 import ydb
 from .topic_writer import (
@@ -76,7 +76,6 @@ async def close(self, *, flush: bool = True):
     async def write_with_ack(
         self,
         messages: Union[MessageType, List[MessageType]],
-        *args: Optional[MessageType],
     ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]:
        """
        IT IS A SLOW WAY. IT IS A BAD CHOICE IN MOST CASES.
        write one or several messages to the server and wait for acks of all of them.

        For wait with timeout use asyncio.wait_for.
        """
-        futures = await self.write_with_ack_future(messages, *args)
+        futures = await self.write_with_ack_future(messages)

         if not isinstance(futures, list):
             futures = [futures]
@@ -98,7 +97,6 @@ async def write_with_ack(
     async def write_with_ack_future(
         self,
         messages: Union[MessageType, List[MessageType]],
-        *args: Optional[MessageType],
     ) -> Union[asyncio.Future, List[asyncio.Future]]:
        """
        send one or several messages to the server.
        return a future for each message.

        For wait with timeout use asyncio.wait_for.
""" + input_single_message = not isinstance(messages, list) if isinstance(messages, PublicMessage): - futures = await self._reconnector.write_with_ack_future([messages]) - return futures[0] + messages = [PublicMessage._create_message(messages)] if isinstance(messages, list): - for m in messages: - if not isinstance(m, PublicMessage): - raise NotImplementedError() - return await self._reconnector.write_with_ack_future(messages) - raise NotImplementedError() + for index, m in enumerate(messages): + messages[index] = PublicMessage._create_message(m) + + futures = await self._reconnector.write_with_ack_future(messages) + if input_single_message: + return futures[0] + else: + return futures async def write( self, messages: Union[MessageType, List[MessageType]], - *args: Optional[MessageType], ): """ send one or number of messages to server. From be267cee01f602747fbdaa1e4ff37ac015972cef Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Mar 2023 00:43:45 +0100 Subject: [PATCH 084/147] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3043520c..33c89bd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* BROKEN CHANGES: remove writer.write(mess1, mess2) variant, use list instead: writer.write([mess1, mess2]) * BROKEN CHANGES: change names of public method in topic client * BROKEN CHANGES: rename parameter producer_and_message_group_id to producer_id * producer_id is optional now From 48d1b366b754ab699ec6ec12e9fd6005a2e40f49 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Mar 2023 00:44:27 +0100 Subject: [PATCH 085/147] fix linter --- examples/topic/writer_async_example.py | 15 +++++++++------ examples/topic/writer_example.py | 15 +++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index 2337847b..c5144685 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -43,9 +43,9 @@ async def send_messages(writer: ydb.TopicWriterAsyncIO): # full forms await writer.write(ydb.TopicWriterMessage("mess")) # send text await writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes - await writer.write([ - ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2") - ]) # send few messages by one call + await writer.write( + [ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2")] + ) # send few messages by one call # with meta await writer.write( @@ -71,9 +71,12 @@ async def send_messages_with_manual_seqno(writer: ydb.TopicWriter): async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): # future wait - await writer.write_with_result([ - ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) - ]) + await writer.write_with_result( + [ + ydb.TopicWriterMessage("mess", seqno=1), + ydb.TopicWriterMessage("mess", seqno=2), + ] + ) # send with flush await writer.write(["1", "2", "3"]) diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 3ae9a4ee..1465dba5 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -56,9 +56,9 @@ def send_messages(writer: ydb.TopicWriter): # full forms writer.write(ydb.TopicWriterMessage("mess")) # send text writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes - writer.write([ - ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2") - ]) # send few messages by one call + writer.write( + 
[ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2")] + ) # send few messages by one call # with meta writer.write(ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns())) @@ -87,9 +87,12 @@ def send_messages_with_wait_ack(writer: ydb.TopicWriter): ).result() # implicit, by sync call - writer.write_with_ack([ - ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) - ]) + writer.write_with_ack( + [ + ydb.TopicWriterMessage("mess", seqno=1), + ydb.TopicWriterMessage("mess", seqno=2), + ] + ) # write_with_ack # send with flush From d7bffd83872baabba5771d81959492b77521e7dd Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Mar 2023 00:54:46 +0100 Subject: [PATCH 086/147] remove unsupported args in sync writer --- ydb/_topic_writer/topic_writer_sync.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 419edcba..dc7b7fbd 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -92,25 +92,22 @@ def wait_init(self, timeout: Optional[TimeoutType] = None) -> PublicWriterInitIn def write( self, - message: Union[PublicMessage, List[PublicMessage]], - *args: Optional[PublicMessage], + messages: Union[PublicMessage, List[PublicMessage]], timeout: Union[float, None] = None, ): - self._call_sync(self._async_writer.write(message, *args), timeout=timeout) + self._call_sync(self._async_writer.write(messages), timeout=timeout) def async_write_with_ack( self, messages: Union[MessageType, List[MessageType]], - *args: Optional[MessageType], ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: - return self._call(self._async_writer.write_with_ack(messages, *args)) + return self._call(self._async_writer.write_with_ack(messages)) def write_with_ack( self, messages: Union[MessageType, List[MessageType]], - *args: Optional[MessageType], timeout: Union[float, None] = None, ) -> Union[PublicWriteResult, List[PublicWriteResult]]: return self._call_sync( - self._async_writer.write_with_ack(messages, *args), timeout=timeout + self._async_writer.write_with_ack(messages), timeout=timeout ) From b9a9a4bb2a816309a96d401d49e1b0a5f14a41ac Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Mar 2023 00:55:45 +0100 Subject: [PATCH 087/147] fix examples --- examples/topic/reader_async_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py index 96142921..fb7f4c26 100644 --- a/examples/topic/reader_async_example.py +++ b/examples/topic/reader_async_example.py @@ -123,8 +123,8 @@ def process_batch(batch): # no reason work with expired batch # go read next - good batch return - await _process(message) - await reader.commit(batch) + _process(message) + reader.commit(batch) async for batch in reader.batches(): process_batch(batch) From c12557652b321761f95b1daf0032119de6073589 Mon Sep 17 00:00:00 2001 From: robot Date: Mon, 6 Mar 2023 00:04:06 +0000 Subject: [PATCH 088/147] Release: 3.0.1b6 --- CHANGELOG.md | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33c89bd1..decbb671 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b6 ## * BROKEN CHANGES: remove writer.write(mess1, mess2) variant, use list instead: writer.write([mess1, mess2]) * BROKEN CHANGES: change names of public method in topic client * BROKEN 
CHANGES: rename parameter producer_and_message_group_id to producer_id diff --git a/setup.py b/setup.py index ae7df0fd..e585796e 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="3.0.1b5", # AUTOVERSION + version="3.0.1b6", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", From 390274802b50f7ef62aa48ae7fae51be88689eed Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Mar 2023 23:11:30 +0100 Subject: [PATCH 089/147] implement writer codecs --- tests/conftest.py | 4 +- tests/topics/test_topic_writer.py | 28 +++ .../grpcwrapper/ydb_topic_public_types.py | 11 +- ydb/_topic_writer/topic_writer.py | 16 +- ydb/_topic_writer/topic_writer_asyncio.py | 181 +++++++++++++++++- .../topic_writer_asyncio_test.py | 127 +++++++++++- ydb/_topic_writer/topic_writer_sync.py | 3 +- ydb/topic.py | 99 +++++++++- 8 files changed, 437 insertions(+), 32 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 99123660..e7a21970 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -131,7 +131,9 @@ async def topic_path(driver, topic_consumer, database) -> str: @pytest.fixture() @pytest.mark.asyncio() async def topic_with_messages(driver, topic_path): - writer = driver.topic_client.writer(topic_path, producer_id="fixture-producer-id") + writer = driver.topic_client.writer( + topic_path, producer_id="fixture-producer-id", codec=ydb.TopicCodec.RAW + ) await writer.write_with_ack( [ ydb.TopicWriterMessage(data="123".encode()), diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 9e8b0dfe..c53ce0db 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -84,6 +84,20 @@ async def test_write_multi_message_with_ack( assert batch.messages[0].seqno == 2 assert batch.messages[0].data == "456".encode() + @pytest.mark.parametrize( + "codec", + [ + ydb.TopicCodec.RAW, + ydb.TopicCodec.GZIP, + None, + ], + ) + async def test_write_encoded(self, driver: ydb.Driver, topic_path: str, codec): + async with driver.topic_client.writer(topic_path, codec=codec) as writer: + writer.write("a" * 1000) + writer.write("b" * 1000) + writer.write("c" * 1000) + + class TestTopicWriterSync: def test_send_message(self, driver_sync: ydb.Driver, topic_path): @@ -163,3 +177,17 @@ def test_write_multi_message_with_ack( assert batch.messages[0].offset == 1 assert batch.messages[0].seqno == 2 assert batch.messages[0].data == "456".encode() + + @pytest.mark.parametrize( + "codec", + [ + ydb.TopicCodec.RAW, + ydb.TopicCodec.GZIP, + None, + ], + ) + def test_write_encoded(self, driver_sync: ydb.Driver, topic_path: str, codec): + with driver_sync.topic_client.writer(topic_path, codec=codec) as writer: + writer.write("a" * 1000) + writer.write("b" * 1000) + writer.write("c" * 1000) diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py index 6d922137..4582f19a 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py @@ -31,11 +31,18 @@ class CreateTopicRequestParams: class PublicCodec(int): + """ + A codec value may be any int number. + + The values below are only the well-known predefined values, + but the protocol supports custom codecs. 
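# Editor's sketch (not part of the patch): since PublicCodec subclasses int,
# a custom codec is just a number outside the predefined range. The id 10001
# below mirrors the value used by test_custom_encoder later in this series;
# pairing it with an encoder callable becomes possible once the `encoders`
# writer setting lands (patch 090).
CUSTOM_CODEC = PublicCodec(10001)
custom_encoders = {CUSTOM_CODEC: lambda data: bytes(reversed(data))}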
+ """ + UNSPECIFIED = 0 RAW = 1 GZIP = 2 - LZOP = 3 - ZSTD = 4 + LZOP = 3 # Has not supported codec in standard library + ZSTD = 4 # Has not supported codec in standard library class PublicMeteringMode(IntEnum): diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index a2c3d0d7..2c9807d9 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -1,3 +1,4 @@ +import concurrent.futures import datetime import enum import uuid @@ -10,7 +11,7 @@ import ydb.aio from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage from .._grpc.grpcwrapper.common_utils import IToProto - +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec MessageType = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] @@ -29,6 +30,10 @@ class PublicWriterSettings: partition_id: Optional[int] = None auto_seqno: bool = True auto_created_at: bool = True + codec: Optional[PublicCodec] = None # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None # default shared client executor pool # get_last_seqno: bool = False # encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None # serializer: Union[Callable[[Any], bytes], None] = None @@ -85,8 +90,9 @@ class SendMode(Enum): @dataclass class PublicWriterInitInfo: - __slots__ = "last_seqno" + __slots__ = ("last_seqno", "supported_codecs") last_seqno: Optional[int] + supported_codecs: List[PublicCodec] class PublicMessage: @@ -117,15 +123,17 @@ def _create_message( class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto): + codec: PublicCodec + def __init__(self, mess: PublicMessage): - StreamWriteMessage.WriteRequest.MessageData.__init__( - self, + super().__init__( seq_no=mess.seqno, created_at=mess.created_at, data=mess.data, uncompressed_size=len(mess.data), partitioning=None, ) + self.codec = PublicCodec.RAW def get_bytes(self) -> bytes: if self.data is None: diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 67b1be69..f658efe3 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -1,7 +1,10 @@ import asyncio +import concurrent.futures import datetime +import gzip +import typing from collections import deque -from typing import Deque, AsyncIterator, Union, List +from typing import Deque, AsyncIterator, Union, List, Optional, Dict, Callable import ydb from .topic_writer import ( @@ -22,6 +25,7 @@ check_retriable_error, RetrySettings, ) +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._topic_common.common import ( TokenGetterFuncType, ) @@ -41,6 +45,7 @@ class WriterAsyncIO: _loop: asyncio.AbstractEventLoop _reconnector: "WriterAsyncIOReconnector" _closed: bool + _compressor_thread_pool: concurrent.futures.Executor @property def last_seqno(self) -> int: @@ -107,11 +112,11 @@ async def write_with_ack_future( For wait with timeout use asyncio.wait_for. 
""" input_single_message = not isinstance(messages, list) - if isinstance(messages, PublicMessage): - messages = [PublicMessage._create_message(messages)] if isinstance(messages, list): for index, m in enumerate(messages): messages[index] = PublicMessage._create_message(m) + else: + messages = [PublicMessage._create_message(messages)] futures = await self._reconnector.write_with_ack_future(messages) if input_single_message: @@ -152,7 +157,7 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: _closed: bool _loop: asyncio.AbstractEventLoop - _credentials: Union[ydb.Credentials, None] + _credentials: Union[ydb.credentials.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int _token_get_function: TokenGetterFuncType @@ -160,8 +165,18 @@ class WriterAsyncIOReconnector: _init_info: asyncio.Future _stream_connected: asyncio.Event _settings: WriterSettings + _codec: PublicCodec + _codec_functions: Dict[PublicCodec, Callable[[bytes], bytes]] + _encode_executor: Optional[concurrent.futures.Executor] + _codec_selector_batch_num: int + _codec_selector_last_codec: Optional[PublicCodec] + _codec_selector_check_batches_interval: int _last_known_seq_no: int + if typing.TYPE_CHECKING: + _messages_for_encode: asyncio.Queue[List[InternalMessage]] + else: + _messages_for_encode: asyncio.Queue _messages: Deque[InternalMessage] _messages_future: Deque[asyncio.Future] _new_messages: asyncio.Queue @@ -179,13 +194,34 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._stream_connected = asyncio.Event() self._settings = settings + self._codec_functions = { + PublicCodec.RAW: lambda data: data, + PublicCodec.GZIP: gzip.compress, + } + self._encode_executor = settings.encoder_executor + + self._codec_selector_batch_num = 0 + self._codec_selector_last_codec = None + self._codec_selector_check_batches_interval = 10000 + + self._codec = self._settings.codec + if self._codec and self._codec not in self._codec_functions: + known_codecs = [key for key in self._codec_functions] + known_codecs.sort() + raise ValueError( + "Unknown codec for writer: %s, supported codecs: %s" + % (self._codec, known_codecs) + ) + self._last_known_seq_no = 0 + self._messages_for_encode = asyncio.Queue() self._messages = deque() self._messages_future = deque() self._new_messages = asyncio.Queue() self._stop_reason = self._loop.create_future() self._background_tasks = [ - asyncio.create_task(self._connection_loop(), name="connection_loop") + asyncio.create_task(self._connection_loop(), name="connection_loop"), + asyncio.create_task(self._encode_loop(), name="encode_loop"), ] async def close(self, flush: bool): @@ -238,15 +274,23 @@ async def write_with_ack_future( internal_messages = self._prepare_internal_messages(messages) messages_future = [self._loop.create_future() for _ in internal_messages] - self._messages.extend(internal_messages) self._messages_future.extend(messages_future) - for m in internal_messages: - self._new_messages.put_nowait(m) + if self._codec == PublicCodec.RAW: + self._add_messages_to_send_queue(internal_messages) + else: + self._messages_for_encode.put_nowait(internal_messages) return messages_future - def _prepare_internal_messages(self, messages: List[PublicMessage]): + def _add_messages_to_send_queue(self, internal_messages: List[InternalMessage]): + self._messages.extend(internal_messages) + for m in internal_messages: + self._new_messages.put_nowait(m) + + def _prepare_internal_messages( + self, messages: List[PublicMessage] + ) -> 
List[InternalMessage]: if self._settings.auto_created_at: now = datetime.datetime.now() else: @@ -307,7 +351,10 @@ async def _connection_loop(self): try: self._last_known_seq_no = stream_writer.last_seqno self._init_info.set_result( - PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) + PublicWriterInitInfo( + last_seqno=stream_writer.last_seqno, + supported_codecs=stream_writer.supported_codecs, + ) ) except asyncio.InvalidStateError: pass @@ -350,6 +397,118 @@ async def _connection_loop(self): task.cancel() await asyncio.wait(pending) + async def _encode_loop(self): + while True: + messages = await self._messages_for_encode.get() + while not self._messages_for_encode.empty(): + messages.extend(self._messages_for_encode.get_nowait()) + + batch_codec = await self._codec_selector(messages) + await self._encode_data_inplace(batch_codec, messages) + self._add_messages_to_send_queue(messages) + + async def _encode_data_inplace( + self, codec: PublicCodec, messages: List[InternalMessage] + ): + if codec == PublicCodec.RAW: + return + + eventloop = asyncio.get_running_loop() + encode_waiters = [] + encoder_function = self._codec_functions[codec] + + for message in messages: + encoded_data_futures = eventloop.run_in_executor( + self._encode_executor, encoder_function, message.get_bytes() + ) + encode_waiters.append(encoded_data_futures) + + encoded_datas = await asyncio.gather(*encode_waiters) + + for index, data in enumerate(encoded_datas): + message = messages[index] + message.codec = codec + message.data = data + + async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: + if self._codec is not None: + return self._codec + + if self._codec_selector_last_codec is None: + available_codecs = await self._get_available_codecs() + + if self._codec_selector_batch_num < len(available_codecs): + codec_index = self._codec_selector_batch_num % len(available_codecs) + codec = available_codecs[codec_index] + else: + codec = await self._codec_selector_by_check_compress(messages) + self._codec_selector_last_codec = codec + else: + if ( + self._codec_selector_batch_num + % self._codec_selector_check_batches_interval + == 0 + ): + self._codec_selector_last_codec = ( + await self._codec_selector_by_check_compress(messages) + ) + codec = self._codec_selector_last_codec + self._codec_selector_batch_num += 1 + return codec + + async def _get_available_codecs(self) -> List[PublicCodec]: + info = await self.wait_init() + topic_supported_codecs = info.supported_codecs + if not topic_supported_codecs: + topic_supported_codecs = [PublicCodec.RAW, PublicCodec.GZIP] + + res = [] + for codec in topic_supported_codecs: + if codec in self._codec_functions: + res.append(codec) + + if not res: + raise TopicWriterError("Writer does not support topic's codecs") + + res.sort() + + return res + + async def _codec_selector_by_check_compress( + self, messages: List[InternalMessage] + ) -> PublicCodec: + test_messages = messages + if len(test_messages) > 10: + test_messages = test_messages[:10] + + available_codecs = await self._get_available_codecs() + if len(available_codecs) == 1: + return available_codecs[0] + + def get_compressed_size(codec) -> int: + s = 0 + f = self._codec_functions[codec] + + for m in test_messages: + encoded = f(m.get_bytes()) + s += len(encoded) + + return s + + def select_codec() -> PublicCodec: + min_codec = available_codecs[0] + min_size = get_compressed_size(min_codec) + for codec in available_codecs[1:]: + size = get_compressed_size(codec) + if size < min_size: + 
min_codec = codec + min_size = size + return min_codec + + loop = asyncio.get_running_loop() + codec = await loop.run_in_executor(self._encode_executor, select_codec) + return codec + async def _read_loop(self, writer: "WriterAsyncIOStream"): while True: resp = await writer.receive() @@ -412,6 +571,7 @@ class WriterAsyncIOStream: # todo slots last_seqno: int + supported_codecs: Optional[List[PublicCodec]] _stream: IGrpcWrapperAsyncIO _token_getter: TokenGetterFuncType @@ -466,6 +626,7 @@ async def _start( raise TopicWriterError("Unexpected answer for init request: %s" % resp) self.last_seqno = resp.last_seq_no + self.supported_codecs = [PublicCodec(codec) for codec in resp.supported_codecs] self._stream = stream diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 32bb3de5..abb1f427 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -4,14 +4,15 @@ import copy import dataclasses import datetime +import gzip import typing from queue import Queue, Empty +from typing import List from unittest import mock import freezegun import pytest - from .. import aio from .. import StatusCode, issues from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage @@ -25,6 +26,7 @@ PublicWriteResult, TopicWriterError, ) +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._topic_common.test_helpers import StreamMock from .topic_writer_asyncio import ( @@ -151,9 +153,11 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): @pytest.mark.asyncio class TestWriterAsyncIOReconnector: init_last_seqno = 0 + time_for_mocks = 1678046714.639387 class StreamWriterMock: last_seqno: int + supported_codecs: List[PublicCodec] from_client: asyncio.Queue from_server: asyncio.Queue @@ -165,6 +169,7 @@ def __init__(self): self.from_server = asyncio.Queue() self.from_client = asyncio.Queue() self._closed = False + self.supported_codecs = [] def write(self, messages: typing.List[InternalMessage]): if self._closed: @@ -184,10 +189,7 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: def close(self): if self._closed: return - - self.from_server.put_nowait( - Exception("waited message while StreamWriterMock closed") - ) + self._closed = True @pytest.fixture(autouse=True) async def stream_writer_double_queue(self, monkeypatch): @@ -241,6 +243,7 @@ def default_settings(self) -> WriterSettings: producer_id="test-producer", auto_seqno=False, auto_created_at=False, + codec=PublicCodec.RAW, ) ) @@ -347,7 +350,9 @@ async def wait_stop(): async def test_wait_init(self, default_driver, default_settings, get_stream_writer): init_seqno = 100 - expected_init_info = PublicWriterInitInfo(init_seqno) + expected_init_info = PublicWriterInitInfo( + last_seqno=init_seqno, supported_codecs=[] + ) with mock.patch.object( TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno ): @@ -457,6 +462,116 @@ async def test_auto_created_at( ] == sent await reconnector.close(flush=False) + @pytest.mark.parametrize( + "codec,write_datas,expected_codecs,expected_datas", + [ + ( + PublicCodec.RAW, + [b"123"], + [PublicCodec.RAW], + [b"123"], + ), + ( + PublicCodec.GZIP, + [b"123"], + [PublicCodec.GZIP], + [gzip.compress(b"123", mtime=time_for_mocks)], + ), + ( + None, + [b"123", b"456", b"789", b"0" * 1000], + [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.RAW, PublicCodec.RAW], + [ + b"123", + gzip.compress(b"456", mtime=time_for_mocks), + b"789", + b"0" * 1000, + 
], + ), + ( + None, + [b"123", b"456", b"789" * 1000, b"0"], + [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.GZIP, PublicCodec.GZIP], + [ + b"123", + gzip.compress(b"456", mtime=time_for_mocks), + gzip.compress(b"789" * 1000, mtime=time_for_mocks), + gzip.compress(b"0", mtime=time_for_mocks), + ], + ), + ], + ) + async def test_select_codecs( + self, + default_driver: aio.Driver, + default_settings: WriterSettings, + monkeypatch, + write_datas: List[typing.Optional[bytes]], + codec: typing.Optional[PublicCodec], + expected_codecs: List[PublicCodec], + expected_datas: List[bytes], + ): + assert len(write_datas) == len(expected_datas) + assert len(expected_codecs) == len(expected_datas) + + settings = copy.copy(default_settings) + settings.codec = codec + settings.auto_seqno = True + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + added_messages = asyncio.Queue() # type: asyncio.Queue[List[InternalMessage]] + + def add_messages(_self, messages: typing.List[InternalMessage]): + added_messages.put_nowait(messages) + + monkeypatch.setattr( + WriterAsyncIOReconnector, "_add_messages_to_send_queue", add_messages + ) + monkeypatch.setattr( + "time.time", lambda: TestWriterAsyncIOReconnector.time_for_mocks + ) + + for i in range(len(expected_datas)): + await reconnector.write_with_ack_future( + [PublicMessage(data=write_datas[i])] + ) + mess = await asyncio.wait_for(added_messages.get(), timeout=600) + mess = mess[0] + + assert mess.codec == expected_codecs[i] + assert mess.get_bytes() == expected_datas[i] + + await reconnector.close(flush=False) + + @pytest.mark.parametrize( + "codec,datas", + [ + ( + PublicCodec.RAW, + [b"123", b"456", b"789", b"0"], + ), + ( + PublicCodec.GZIP, + [b"123", b"456", b"789", b"0"], + ), + ], + ) + async def test_encode_data_inplace( + self, + reconnector: WriterAsyncIOReconnector, + codec: PublicCodec, + datas: List[bytes], + ): + f = reconnector._codec_functions[codec] + expected_datas = [f(data) for data in datas] + + messages = [InternalMessage(PublicMessage(data)) for data in datas] + await reconnector._encode_data_inplace(codec, messages) + + for index, mess in enumerate(messages): + assert mess.codec == codec + assert mess.get_bytes() == expected_datas[index] + @pytest.mark.asyncio class TestWriterAsyncIO: diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index dc7b7fbd..48649875 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -9,7 +9,6 @@ PublicWriterSettings, TopicWriterError, PublicWriterInitInfo, - PublicMessage, PublicWriteResult, MessageType, ) @@ -92,7 +91,7 @@ def wait_init(self, timeout: Optional[TimeoutType] = None) -> PublicWriterInitIn def write( self, - messages: Union[PublicMessage, List[PublicMessage]], + messages: Union[MessageType, List[MessageType]], timeout: Union[float, None] = None, ): self._call_sync(self._async_writer.write(messages), timeout=timeout) diff --git a/ydb/topic.py b/ydb/topic.py index 45b8b073..1ed60686 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,7 +1,11 @@ +from __future__ import annotations + +import concurrent.futures import datetime +from dataclasses import dataclass from typing import List, Union, Mapping, Optional, Dict -from . import aio, Credentials, _apis +from . import aio, Credentials, _apis, issues from . 
import driver @@ -43,11 +47,25 @@ class TopicClientAsyncIO: + _closed: bool _driver: aio.Driver _credentials: Union[Credentials, None] + _settings: TopicClientSettings + _executor: concurrent.futures.Executor - def __init__(self, driver: aio.Driver, settings: "TopicClientSettings" = None): + def __init__(self, driver: aio.Driver, settings: Optional[TopicClientSettings]): + if not settings: + settings = TopicClientSettings() + self._closed = False self._driver = driver + self._settings = settings + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=settings.encode_decode_threads_count, + thread_name_prefix="topic_asyncio_executor", + ) + + def __del__(self): + self.close() async def create_topic( self, @@ -148,21 +166,56 @@ def writer( partition_id: Union[int, None] = None, auto_seqno: bool = True, auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool ) -> TopicWriterAsyncIO: args = locals() del args["self"] + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + return TopicWriterAsyncIO(self._driver, settings) + def close(self): + if self._closed: + return + + self._closed = True + self._executor.shutdown(wait=False, cancel_futures=True) + + def _check_closed(self): + if not self._closed: + return + + raise RuntimeError("Topic client closed") + class TopicClient: + _closed: bool _driver: driver.Driver _credentials: Union[Credentials, None] + _settings: TopicClientSettings + _executor: concurrent.futures.Executor - def __init__( - self, driver: driver.Driver, topic_client_settings: "TopicClientSettings" = None - ): + def __init__(self, driver: driver.Driver, settings: Optional[TopicClientSettings]): + if not settings: + settings = TopicClientSettings() + + self._closed = False self._driver = driver + self._settings = settings + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=settings.encode_decode_threads_count, + thread_name_prefix="topic_asyncio_executor", + ) + + def __del__(self): + self.close() def create_topic( self, @@ -198,6 +251,8 @@ def create_topic( """ args = locals().copy() del args["self"] + self._check_closed() + req = _ydb_topic_public_types.CreateTopicRequestParams(**args) req = _ydb_topic.CreateTopicRequest.from_public(req) self._driver( @@ -212,6 +267,8 @@ def describe_topic( ) -> TopicDescription: args = locals().copy() del args["self"] + self._check_closed() + req = _ydb_topic_public_types.DescribeTopicRequestParams(**args) res = self._driver( req.to_proto(), @@ -222,6 +279,8 @@ def describe_topic( return res.to_public() def drop_topic(self, path: str): + self._check_closed() + req = _ydb_topic_public_types.DropTopicRequestParams(path=path) self._driver( req.to_proto(), @@ -251,6 +310,8 @@ def reader( ) -> TopicReader: args = locals() del args["self"] + self._check_closed() + settings = TopicReaderSettings(**args) return TopicReader(self._driver, settings) @@ -263,16 +324,40 @@ def writer( partition_id: Union[int, None] = None, auto_seqno: bool = True, auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool ) -> TopicWriter: args = locals() del args["self"] + self._check_closed() + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + 
settings.encoder_executor = self._executor + return TopicWriter(self._driver, settings) + def close(self): + if self._closed: + return + + self._closed = True + self._executor.shutdown(wait=False, cancel_futures=True) + + def _check_closed(self): + if not self._closed: + return + raise RuntimeError("Topic client closed") + + +@dataclass class TopicClientSettings: - pass + encode_decode_threads_count: int = 4 -class StubEvent: +class TopicError(issues.Error): pass From 1104d9d45b329e6bf5f6a45519dea396f55497bc Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 09:59:14 +0100 Subject: [PATCH 090/147] add custom encoders --- ydb/_topic_writer/topic_writer.py | 4 ++- ydb/_topic_writer/topic_writer_asyncio.py | 5 +++ .../topic_writer_asyncio_test.py | 31 ++++++++++++++++++- ydb/topic.py | 8 ++++- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 2c9807d9..7bfb3f30 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -34,8 +34,10 @@ class PublicWriterSettings: encoder_executor: Optional[ concurrent.futures.Executor ] = None # default shared client executor pool + encoders: Optional[ + typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]] + ] = None # get_last_seqno: bool = False - # encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None # serializer: Union[Callable[[Any], bytes], None] = None # send_buffer_count: Optional[int] = 10000 # send_buffer_bytes: Optional[int] = 100 * 1024 * 1024 diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f658efe3..59daa48d 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -198,6 +198,11 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): PublicCodec.RAW: lambda data: data, PublicCodec.GZIP: gzip.compress, } + + if settings.encoders: + for codec, encoder in settings.encoders.items(): + self._codec_functions[codec] = encoder + self._encode_executor = settings.encoder_executor self._codec_selector_batch_num = 0 diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index abb1f427..921c6aa4 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -27,7 +27,7 @@ TopicWriterError, ) from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec -from .._topic_common.test_helpers import StreamMock +from .._topic_common.test_helpers import StreamMock, wait_for_fast from .topic_writer_asyncio import ( WriterAsyncIOStream, @@ -572,6 +572,35 @@ async def test_encode_data_inplace( assert mess.codec == codec assert mess.get_bytes() == expected_datas[index] + async def test_custom_encoder( + self, default_driver, default_settings, get_stream_writer + ): + codec = 10001 + + settings = copy.copy(default_settings) + settings.encoders = {codec: lambda x: bytes(reversed(x))} + settings.codec = codec + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + now = datetime.datetime.now() + seqno = self.init_last_seqno + 1 + + await reconnector.write_with_ack_future( + [PublicMessage(data=b"123", seqno=seqno, created_at=now)] + ) + + stream_writer = get_stream_writer() + sent_messages = await wait_for_fast(stream_writer.from_client.get()) + + expected_mess = InternalMessage( + PublicMessage(data=b"321", seqno=seqno, created_at=now) + ) + expected_mess.codec = 
codec + + assert sent_messages == [expected_mess] + + await reconnector.close(flush=False) + @pytest.mark.asyncio class TestWriterAsyncIO: diff --git a/ydb/topic.py b/ydb/topic.py index 1ed60686..a4162208 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -3,7 +3,7 @@ import concurrent.futures import datetime from dataclasses import dataclass -from typing import List, Union, Mapping, Optional, Dict +from typing import List, Union, Mapping, Optional, Dict, Callable from . import aio, Credentials, _apis, issues @@ -167,6 +167,9 @@ def writer( auto_seqno: bool = True, auto_created_at: bool = True, codec: Optional[TopicCodec] = None, # default means auto-select + encoders: Optional[ + Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]] + ] = None, encoder_executor: Optional[ concurrent.futures.Executor ] = None, # default shared client executor pool @@ -325,6 +328,9 @@ def writer( auto_seqno: bool = True, auto_created_at: bool = True, codec: Optional[TopicCodec] = None, # default means auto-select + encoders: Optional[ + Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]] + ] = None, encoder_executor: Optional[ concurrent.futures.Executor ] = None, # default shared client executor pool From a537c65760a45d78f2d4371dbccc54c6cb315458 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 10:21:48 +0100 Subject: [PATCH 091/147] add comment --- ydb/_topic_writer/topic_writer_asyncio.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 59daa48d..5c337448 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -442,6 +442,8 @@ async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: if self._codec_selector_last_codec is None: available_codecs = await self._get_available_codecs() + # use every available encoder at the start, to prevent problems + # with rarely used encoders (on the writer or reader side) if self._codec_selector_batch_num < len(available_codecs): codec_index = self._codec_selector_batch_num % len(available_codecs) codec = available_codecs[codec_index] @@ -482,6 +484,10 @@ async def _get_available_codecs(self) -> List[PublicCodec]: async def _codec_selector_by_check_compress( self, messages: List[InternalMessage] ) -> PublicCodec: + """ + Try to compress messages and choose the codec with the smallest result size. 
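# Editor's sketch (not part of the patch): the selection idea above in
# isolation - compress a sample with every candidate codec and keep the
# cheapest one; gzip.compress stands in for the writer's real codec table.
import gzip

def pick_codec(sample: bytes) -> str:
    candidates = {"raw": lambda d: d, "gzip": gzip.compress}
    sizes = {name: len(encode(sample)) for name, encode in candidates.items()}
    return min(sizes, key=sizes.get)  # codec with the smallest output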
+ """ + test_messages = messages if len(test_messages) > 10: test_messages = test_messages[:10] From 51c9bd34aa4f06c3ab47a0edb4b29524bf30ae88 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 10:59:11 +0100 Subject: [PATCH 092/147] add sync tx tests --- tests/table/test_tx.py | 91 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/table/test_tx.py diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py new file mode 100644 index 00000000..32d9b763 --- /dev/null +++ b/tests/table/test_tx.py @@ -0,0 +1,91 @@ +import pytest +from contextlib import suppress + +import ydb.iam + + +def test_tx_commit(driver_sync, database): + session = driver_sync.table_client.session().create() + prepared = session.prepare( + "DECLARE $param as Int32;\n SELECT $param as value", + ) + + tx = session.transaction() + tx.execute(prepared, {"$param": 2}) + tx.commit() + tx.commit() + + +def test_tx_rollback(driver_sync, database): + session = driver_sync.table_client.session().create() + prepared = session.prepare( + "DECLARE $param as Int32;\n SELECT $param as value", + ) + + tx = session.transaction() + tx.execute(prepared, {"$param": 2}) + tx.rollback() + tx.rollback() + + +def test_tx_begin(driver_sync, database): + session = driver_sync.table_client.session().create() + session.create() + + tx = session.transaction() + tx.begin() + tx.begin() + tx.rollback() + + +def test_credentials(): + credentials = ydb.iam.MetadataUrlCredentials() + raised = False + try: + credentials.auth_metadata() + except Exception: + raised = True + + assert raised + + +def test_tx_snapshot_ro(driver_sync, database): + session = driver_sync.table_client.session().create() + description = ( + ydb.TableDescription() + .with_primary_keys("key") + .with_columns( + ydb.Column("key", ydb.OptionalType(ydb.PrimitiveType.Uint64)), + ydb.Column("value", ydb.OptionalType(ydb.PrimitiveType.Uint64)), + ) + ) + tb_name = f"{database}/test" + with suppress(ydb.issues.SchemeError): + session.drop_table(tb_name) + session.create_table(tb_name, description) + session.transaction(ydb.SerializableReadWrite()).execute( + """INSERT INTO `test` (`key`, `value`) VALUES (1, 1), (2, 2)""", + commit_tx=True, + ) + + ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly()) + data1 = ro_tx.execute("SELECT value FROM `test` WHERE key = 1") + + session.transaction(ydb.SerializableReadWrite()).execute( + "UPDATE `test` SET value = value + 1", commit_tx=True + ) + + data2 = ro_tx.execute("SELECT value FROM `test` WHERE key = 1") + assert data1[0].rows == data2[0].rows == [{"value": 1}] + + ro_tx.commit() + + with pytest.raises(ydb.issues.GenericError) as exc_info: + ro_tx.execute("UPDATE `test` SET value = value + 1") + assert "read only transaction" in exc_info.value.message + + data = session.transaction(tx_mode=ydb.SnapshotReadOnly()).execute( + "SELECT value FROM `test` WHERE key = 1", + commit_tx=True, + ) + assert data[0].rows == [{"value": 2}] From b0a17e12c5561797964cbf056b0ba848074c33ea Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 12:21:46 +0100 Subject: [PATCH 093/147] Deny split transactions by default --- CHANGELOG.md | 2 ++ tests/aio/test_tx.py | 62 +++++++++++++++++++++++++++++++++ tests/conftest.py | 27 +++++++++++++++ tests/table/test_tx.py | 59 +++++++++++++++++++++++++++++++ ydb/aio/table.py | 13 +++++-- ydb/table.py | 79 +++++++++++++++++++++++++++++++++++++++--- 6 files changed, 235 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md 
b/CHANGELOG.md index decbb671..4cfa110d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* BROKEN CHANGE: deny any action in transaction after commit/rollback + ## 3.0.1b6 ## * BROKEN CHANGES: remove writer.write(mess1, mess2) variant, use list instead: writer.write([mess1, mess2]) * BROKEN CHANGES: change names of public method in topic client diff --git a/tests/aio/test_tx.py b/tests/aio/test_tx.py index 2161ddeb..fae09909 100644 --- a/tests/aio/test_tx.py +++ b/tests/aio/test_tx.py @@ -85,6 +85,7 @@ async def test_tx_snapshot_ro(driver, database): await ro_tx.commit() + ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly()) with pytest.raises(ydb.issues.GenericError) as exc_info: await ro_tx.execute("UPDATE `test` SET value = value + 1") assert "read only transaction" in exc_info.value.message @@ -94,3 +95,64 @@ async def test_tx_snapshot_ro(driver, database): commit_tx=True, ) assert data[0].rows == [{"value": 2}] + + +@pytest.mark.asyncio +async def test_split_transactions_deny_split(driver, table_name): + async with ydb.aio.SessionPool(driver, 1) as pool: + + async def check_transaction(s: ydb.aio.table.Session): + async with s.transaction(deny_split_transactions=True) as tx: + await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + await tx.commit() + + with pytest.raises(RuntimeError): + await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + await tx.commit() + + async with s.transaction() as tx: + rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + await pool.retry_operation(check_transaction) + + +@pytest.mark.asyncio +async def test_split_transactions_allow_split(driver, table_name): + async with ydb.aio.SessionPool(driver, 1) as pool: + + async def check_transaction(s: ydb.aio.table.Session): + async with s.transaction(deny_split_transactions=False) as tx: + await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + await tx.commit() + + await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + await tx.commit() + + async with s.transaction() as tx: + rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 2 + + await pool.retry_operation(check_transaction) + + +@pytest.mark.asyncio +async def test_split_transactions_default(driver, table_name): + async with ydb.aio.SessionPool(driver, 1) as pool: + + async def check_transaction(s: ydb.aio.table.Session): + async with s.transaction() as tx: + await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + await tx.commit() + + with pytest.raises(RuntimeError): + await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + tx.commit() + + async with s.transaction() as tx: + rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + await pool.retry_operation(check_transaction) diff --git a/tests/conftest.py b/tests/conftest.py index 99123660..e7809847 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -105,6 +105,33 @@ async def driver_sync(endpoint, database, event_loop): driver.stop(timeout=10) +@pytest.fixture() +def table_name(driver_sync, database): + table_name = "table" + + with ydb.SessionPool(driver_sync) as pool: + + def create_table(s): + try: + s.drop_table(database + "/" + table_name) + except ydb.SchemeError: + pass + + s.execute_scheme( + """ +CREATE TABLE %s ( +id Int64 NOT NULL, +i64Val Int64, +PRIMARY KEY(id) +) +""" + % table_name + ) + + pool.retry_operation_sync(create_table) + return table_name + 
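# Editor's sketch (not part of the patch): the behaviour the split-transaction
# tests above pin down, assuming an active ydb table session `session`.
def demo_deny_split(session):
    tx = session.transaction()  # deny_split_transactions defaults to True
    tx.execute("SELECT 1")
    tx.commit()
    tx.commit()  # a repeated commit stays allowed - it is safe
    try:
        tx.execute("SELECT 2")  # any new statement is now denied
    except RuntimeError:
        pass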
+ @pytest.fixture() def topic_consumer(): return "fixture-consumer" diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py index 32d9b763..095fb72f 100644 --- a/tests/table/test_tx.py +++ b/tests/table/test_tx.py @@ -80,6 +80,7 @@ def test_tx_snapshot_ro(driver_sync, database): ro_tx.commit() + ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly()) with pytest.raises(ydb.issues.GenericError) as exc_info: ro_tx.execute("UPDATE `test` SET value = value + 1") assert "read only transaction" in exc_info.value.message @@ -89,3 +90,61 @@ def test_tx_snapshot_ro(driver_sync, database): commit_tx=True, ) assert data[0].rows == [{"value": 2}] + + +def test_split_transactions_deny_split(driver_sync, table_name): + with ydb.SessionPool(driver_sync, 1) as pool: + + def check_transaction(s: ydb.table.Session): + with s.transaction(deny_split_transactions=True) as tx: + tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + tx.commit() + + with pytest.raises(RuntimeError): + tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + tx.commit() + + with s.transaction() as tx: + rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + pool.retry_operation_sync(check_transaction) + + +def test_split_transactions_allow_split(driver_sync, table_name): + with ydb.SessionPool(driver_sync, 1) as pool: + + def check_transaction(s: ydb.table.Session): + with s.transaction(deny_split_transactions=False) as tx: + tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + tx.commit() + + tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + tx.commit() + + with s.transaction() as tx: + rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 2 + + pool.retry_operation_sync(check_transaction) + + +def test_split_transactions_default(driver_sync, table_name): + with ydb.SessionPool(driver_sync, 1) as pool: + + def check_transaction(s: ydb.table.Session): + with s.transaction() as tx: + tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) + tx.commit() + + with pytest.raises(RuntimeError): + tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + + tx.commit() + + with s.transaction() as tx: + rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) + assert rs[0].rows[0].cnt == 1 + + pool.retry_operation_sync(check_transaction) diff --git a/ydb/aio/table.py b/ydb/aio/table.py index 9df797ea..95e2723d 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -120,8 +120,14 @@ async def alter_table( set_read_replicas_settings, ) - def transaction(self, tx_mode=None): - return TxContext(self._driver, self._state, self, tx_mode) + def transaction(self, tx_mode=None, *, deny_split_transactions=True): + return TxContext( + self._driver, + self._state, + self, + tx_mode, + deny_split_transactions=deny_split_transactions, + ) async def describe_table(self, path, settings=None): # pylint: disable=W0236 return await super().describe_table(path, settings) @@ -184,6 +190,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def execute( self, query, parameters=None, commit_tx=False, settings=None ): # pylint: disable=W0236 + + self._check_split() + return await super().execute(query, parameters, commit_tx, settings) async def commit(self, settings=None): # pylint: disable=W0236 diff --git a/ydb/table.py b/ydb/table.py index d60f138a..e2d9ad86 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -1173,7 +1173,7 @@ def execute_scheme(self, yql_text, settings=None): pass @abstractmethod - def transaction(self, 
tx_mode=None): + def transaction(self, tx_mode=None, deny_split_transactions=True): pass @abstractmethod @@ -1677,8 +1677,14 @@ def execute_scheme(self, yql_text, settings=None): self._state.endpoint, ) - def transaction(self, tx_mode=None): - return TxContext(self._driver, self._state, self, tx_mode) + def transaction(self, tx_mode=None, deny_split_transactions=True): + return TxContext( + self._driver, + self._state, + self, + tx_mode, + deny_split_transactions=deny_split_transactions, + ) def has_prepared(self, query): return query in self._state @@ -2189,9 +2195,27 @@ def begin(self, settings=None): class BaseTxContext(ITxContext): - __slots__ = ("_tx_state", "_session_state", "_driver", "session") + __slots__ = ( + "_tx_state", + "_session_state", + "_driver", + "session", + "_finished", + "_deny_split_transactions", + ) - def __init__(self, driver, session_state, session, tx_mode=None): + _COMMIT = "commit" + _ROLLBACK = "rollback" + + def __init__( + self, + driver, + session_state, + session, + tx_mode=None, + *, + deny_split_transactions=True + ): """ An object that provides a simple transaction context manager that allows statements execution in a transaction. You don't have to open transaction explicitly, because context manager encapsulates @@ -2214,6 +2238,8 @@ def __init__(self, driver, session_state, session, tx_mode=None): self._tx_state = _tx_ctx_impl.TxState(tx_mode) self._session_state = session_state self.session = session + self._finished = "" + self._deny_split_transactions = deny_split_transactions def __enter__(self): """ @@ -2271,6 +2297,9 @@ def execute(self, query, parameters=None, commit_tx=False, settings=None): :return: A result sets or exception in case of execution errors """ + + self._check_split() + return self._driver( _tx_ctx_impl.execute_request_factory( self._session_state, @@ -2297,8 +2326,12 @@ def commit(self, settings=None): :return: A committed transaction or exception if commit is failed """ + + self._set_finish(self._COMMIT) + if self._tx_state.tx_id is None and not self._tx_state.dead: return self + return self._driver( _tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2318,8 +2351,12 @@ def rollback(self, settings=None): :return: A rolled back transaction or exception if rollback is failed """ + + self._set_finish(self._ROLLBACK) + if self._tx_state.tx_id is None and not self._tx_state.dead: return self + return self._driver( _tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2340,6 +2377,9 @@ def begin(self, settings=None): """ if self._tx_state.tx_id is not None: return self + + self._check_split() + return self._driver( _tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2350,6 +2390,21 @@ def begin(self, settings=None): self._session_state.endpoint, ) + def _set_finish(self, val): + self._check_split(val) + self._finished = val + + def _check_split(self, allow=""): + """ + Deny all operaions with transaction after commit/rollback. 
+ Exception: double commits and double rollbacks, because they are safe + """ + if not self._deny_split_transactions: + return + + if self._finished != "" and self._finished != allow: + raise RuntimeError("Any operation on a finished transaction is denied") + class TxContext(BaseTxContext): @_utilities.wrap_async_call_exceptions @@ -2365,6 +2420,9 @@ def async_execute(self, query, parameters=None, commit_tx=False, settings=None): :return: A future of query execution """ + + self._check_split() + return self._driver.future( _tx_ctx_impl.execute_request_factory( self._session_state, @@ -2396,8 +2454,12 @@ def async_commit(self, settings=None): :return: A future of commit call """ + self._check_split() + self._finished = True + if self._tx_state.tx_id is None and not self._tx_state.dead: return _utilities.wrap_result_in_future(self) + return self._driver.future( _tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2418,8 +2480,12 @@ def async_rollback(self, settings=None): :return: A future of rollback call """ + self._check_split() + self._finished = True + if self._tx_state.tx_id is None and not self._tx_state.dead: return _utilities.wrap_result_in_future(self) + return self._driver.future( _tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, @@ -2441,6 +2507,9 @@ def async_begin(self, settings=None): """ if self._tx_state.tx_id is not None: return _utilities.wrap_result_in_future(self) + + self._check_split() + return self._driver.future( _tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state), _apis.TableService.Stub, From 7bcee518deef5c452f5dd113130f747c8618b4a4 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 13:06:18 +0100 Subject: [PATCH 094/147] fix commit/rollback markers, fix typos --- tests/aio/test_tx.py | 2 +- ydb/table.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/aio/test_tx.py b/tests/aio/test_tx.py index fae09909..be5c6806 100644 --- a/tests/aio/test_tx.py +++ b/tests/aio/test_tx.py @@ -149,7 +149,7 @@ async def check_transaction(s: ydb.aio.table.Session): with pytest.raises(RuntimeError): await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) - tx.commit() + await tx.commit() async with s.transaction() as tx: rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) diff --git a/ydb/table.py b/ydb/table.py index e2d9ad86..eaee78ec 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -2454,8 +2454,7 @@ def async_commit(self, settings=None): :return: A future of commit call """ - self._check_split() - self._finished = True + self._set_finish(self._COMMIT) if self._tx_state.tx_id is None and not self._tx_state.dead: return _utilities.wrap_result_in_future(self) @@ -2480,8 +2479,7 @@ def async_rollback(self, settings=None): :return: A future of rollback call """ - self._check_split() - self._finished = True + self._set_finish(self._ROLLBACK) if self._tx_state.tx_id is None and not self._tx_state.dead: return _utilities.wrap_result_in_future(self) From a123228ae3dcedccba042a17bf7a11fdbcb87d83 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 15:22:17 +0100 Subject: [PATCH 095/147] truncated response raises exception by default --- tests/aio/test_tx.py | 43 ++++++++++++++++++++++++++++++++++++++++++ tests/conftest.py | 5 +++++ tests/table/test_tx.py | 43 ++++++++++++++++++++++++++++++++++++++++++ ydb/convert.py | 10 +++++++++- ydb/issues.py | 4 ++++ ydb/table.py | 6 ++++++ 6 files 
changed, 110 insertions(+), 1 deletion(-) diff --git a/tests/aio/test_tx.py b/tests/aio/test_tx.py index be5c6806..da66769e 100644 --- a/tests/aio/test_tx.py +++ b/tests/aio/test_tx.py @@ -156,3 +156,46 @@ async def check_transaction(s: ydb.aio.table.Session): assert rs[0].rows[0].cnt == 1 await pool.retry_operation(check_transaction) + + +@pytest.mark.asyncio +async def test_truncated_response(driver, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + await driver.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = driver.table_client # default table client with driver's settings + s = table_client.session() + await s.create() + t = s.transaction() + with pytest.raises(ydb.TruncatedResponseError): + await t.execute("SELECT * FROM %s" % table_name) + + +@pytest.mark.asyncio +async def test_truncated_response_allow(driver, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + await driver.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = ydb.TableClient( + driver, ydb.TableClientSettings().with_allow_truncated_result(True) + ) + s = table_client.session() + await s.create() + t = s.transaction() + result = await t.execute("SELECT * FROM %s" % table_name) + assert result[0].truncated + assert len(result[0].rows) == 1000 diff --git a/tests/conftest.py b/tests/conftest.py index e7809847..675ef7b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -132,6 +132,11 @@ def create_table(s): return table_name +@pytest.fixture() +def table_path(database, table_name) -> str: + return database + "/" + table_name + + @pytest.fixture() def topic_consumer(): return "fixture-consumer" diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py index 095fb72f..bd703fa8 100644 --- a/tests/table/test_tx.py +++ b/tests/table/test_tx.py @@ -148,3 +148,46 @@ def check_transaction(s: ydb.table.Session): assert rs[0].rows[0].cnt == 1 pool.retry_operation_sync(check_transaction) + + +def test_truncated_response(driver_sync, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + driver_sync.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = ( + driver_sync.table_client + ) # default table client with driver's settings + s = table_client.session() + s.create() + t = s.transaction() + with pytest.raises(ydb.TruncatedResponseError): + t.execute("SELECT * FROM %s" % table_name) + + +def test_truncated_response_allow(driver_sync, table_name, table_path): + column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) + + rows = [] + + rows_count = 1100 + for i in range(rows_count): + rows.append({"id": i}) + + driver_sync.table_client.bulk_upsert(table_path, rows, column_types) + + table_client = ydb.TableClient( + driver_sync, ydb.TableClientSettings().with_allow_truncated_result(True) + ) + s = table_client.session() + s.create() + t = s.transaction() + result = t.execute("SELECT * FROM %s" % table_name) + assert result[0].truncated + assert len(result[0].rows) == 1000 diff --git a/ydb/convert.py b/ydb/convert.py index 70bc638e..567900a1 100644 --- a/ydb/convert.py +++ 
b/ydb/convert.py @@ -489,5 +489,13 @@ def __init__(self, result_sets_pb, table_client_settings=None): _ResultSet.from_message if not make_lazy else _ResultSet.lazy_from_message ) for result_set in result_sets_pb: - result_sets.append(initializer(result_set, table_client_settings)) + result_set = initializer(result_set, table_client_settings) + if ( + result_set.truncated + and not table_client_settings._allow_truncated_result + ): + raise issues.TruncatedResponseError( + "Response for the request was truncated by server" + ) + result_sets.append(result_set) super(ResultSets, self).__init__(result_sets) diff --git a/ydb/issues.py b/ydb/issues.py index 5a57f4d2..55c14cea 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -52,6 +52,10 @@ def __init__(self, message, issues=None): self.message = message +class TruncatedResponseError(Error): + status = None + + class ConnectionError(Error): status = None diff --git a/ydb/table.py b/ydb/table.py index eaee78ec..40431c62 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -1002,6 +1002,7 @@ def __init__(self): self._native_json_in_result_sets = False self._native_interval_in_result_sets = False self._native_timestamp_in_result_sets = False + self._allow_truncated_result = False def with_native_timestamp_in_result_sets(self, enabled): # type:(bool) -> ydb.TableClientSettings @@ -1038,6 +1039,11 @@ def with_lazy_result_sets(self, enabled): self._make_result_sets_lazy = enabled return self + def with_allow_truncated_result(self, enabled): + # type:(bool) -> ydb.TableClientSettings + self._allow_truncated_result = enabled + return self + class ScanQueryResult(object): def __init__(self, result, table_client_settings): From a66f0d0b9a236d69877692223d80aa2f84485221 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 16:56:44 +0100 Subject: [PATCH 096/147] fix style --- ydb/_topic_writer/topic_writer_asyncio.py | 23 +++++++++-------------- ydb/topic.py | 4 +++- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5c337448..1b3b3d25 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -45,7 +45,6 @@ class WriterAsyncIO: _loop: asyncio.AbstractEventLoop _reconnector: "WriterAsyncIOReconnector" _closed: bool - _compressor_thread_pool: concurrent.futures.Executor @property def last_seqno(self) -> int: @@ -112,13 +111,14 @@ async def write_with_ack_future( For wait with timeout use asyncio.wait_for. 
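# Editor's sketch (not part of the patch), looking back at the ydb/convert.py
# change above: opting back in to truncated results, mirroring the tests from
# that patch; `driver` is assumed to be an already started ydb.Driver.
import ydb

def make_permissive_table_client(driver):
    settings = ydb.TableClientSettings().with_allow_truncated_result(True)
    # Queries through this client keep returning truncated result sets
    # instead of raising ydb.TruncatedResponseError.
    return ydb.TableClient(driver, settings)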
""" input_single_message = not isinstance(messages, list) + converted_messages = [] if isinstance(messages, list): - for index, m in enumerate(messages): - messages[index] = PublicMessage._create_message(m) + for m in messages: + converted_messages.append(PublicMessage._create_message(m)) else: - messages = [PublicMessage._create_message(messages)] + converted_messages = [PublicMessage._create_message(messages)] - futures = await self._reconnector.write_with_ack_future(messages) + futures = await self._reconnector.write_with_ack_future(converted_messages) if input_single_message: return futures[0] else: @@ -200,8 +200,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): } if settings.encoders: - for codec, encoder in settings.encoders.items(): - self._codec_functions[codec] = encoder + self._codec_functions.update(settings.encoders) self._encode_executor = settings.encoder_executor @@ -211,8 +210,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._codec = self._settings.codec if self._codec and self._codec not in self._codec_functions: - known_codecs = [key for key in self._codec_functions] - known_codecs.sort() + known_codecs = sorted(self._codec_functions.keys()) raise ValueError( "Unknown codec for writer: %s, supported codecs: %s" % (self._codec, known_codecs) @@ -445,8 +443,7 @@ async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: # use every of available encoders at start for prevent problems # with rare used encoders (on writer or reader side) if self._codec_selector_batch_num < len(available_codecs): - codec_index = self._codec_selector_batch_num % len(available_codecs) - codec = available_codecs[codec_index] + codec = available_codecs[self._codec_selector_batch_num] else: codec = await self._codec_selector_by_check_compress(messages) self._codec_selector_last_codec = codec @@ -488,9 +485,7 @@ async def _codec_selector_by_check_compress( Try to compress messages and choose codec with the smallest result size. 
""" - test_messages = messages - if len(test_messages) > 10: - test_messages = test_messages[:10] + test_messages = messages[:10] available_codecs = await self._get_available_codecs() if len(available_codecs) == 1: diff --git a/ydb/topic.py b/ydb/topic.py index a4162208..3ccdda08 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -53,7 +53,9 @@ class TopicClientAsyncIO: _settings: TopicClientSettings _executor: concurrent.futures.Executor - def __init__(self, driver: aio.Driver, settings: Optional[TopicClientSettings]): + def __init__( + self, driver: aio.Driver, settings: Optional[TopicClientSettings] = None + ): if not settings: settings = TopicClientSettings() self._closed = False From ccc6eecc2b4dc6f733e8e225d6fa541056570919 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 17:08:41 +0100 Subject: [PATCH 097/147] rename messagetype --- ydb/_topic_writer/topic_writer.py | 6 ++---- ydb/_topic_writer/topic_writer_asyncio.py | 8 ++++---- ydb/_topic_writer/topic_writer_sync.py | 8 ++++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 7bfb3f30..92212f65 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -13,7 +13,7 @@ from .._grpc.grpcwrapper.common_utils import IToProto from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec -MessageType = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] +Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] @dataclass @@ -116,9 +116,7 @@ def __init__( self.data = data @staticmethod - def _create_message( - data: Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] - ) -> "PublicMessage": + def _create_message(data: Message) -> "PublicMessage": if isinstance(data, PublicMessage): return data return PublicMessage(data=data) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 1b3b3d25..5e3bb455 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -17,7 +17,7 @@ TopicWriterError, messages_to_proto_requests, PublicWriteResultTypes, - MessageType, + Message, ) from .. import ( _apis, @@ -79,7 +79,7 @@ async def close(self, *, flush: bool = True): async def write_with_ack( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]: """ IT IS SLOWLY WAY. IT IS BAD CHOISE IN MOST CASES. @@ -100,7 +100,7 @@ async def write_with_ack( async def write_with_ack_future( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Union[asyncio.Future, List[asyncio.Future]]: """ send one or number of messages to server. @@ -126,7 +126,7 @@ async def write_with_ack_future( async def write( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ): """ send one or number of messages to server. 
diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 48649875..4713c07d 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -10,7 +10,7 @@ TopicWriterError, PublicWriterInitInfo, PublicWriteResult, - MessageType, + Message, ) from .topic_writer_asyncio import WriterAsyncIO @@ -91,20 +91,20 @@ def wait_init(self, timeout: Optional[TimeoutType] = None) -> PublicWriterInitIn def write( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], timeout: Union[float, None] = None, ): self._call_sync(self._async_writer.write(messages), timeout=timeout) def async_write_with_ack( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: return self._call(self._async_writer.write_with_ack(messages)) def write_with_ack( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], timeout: Union[float, None] = None, ) -> Union[PublicWriteResult, List[PublicWriteResult]]: return self._call_sync( From 253f4548c2bf201d8698a677e858da247e5163c8 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 10 Mar 2023 10:59:44 +0100 Subject: [PATCH 098/147] remove magic of old package renaming --- kikimr/__init__.py | 0 kikimr/public/__init__.py | 0 kikimr/public/api/__init__.py | 6 ------ kikimr/public/sdk/__init__.py | 0 kikimr/public/sdk/python/__init__.py | 0 kikimr/public/sdk/python/client/__init__.py | 20 ------------------- .../sdk/python/client/frameworks/__init__.py | 10 ---------- kikimr/public/sdk/python/iam/__init__.py | 7 ------- kikimr/stub.txt | 1 + ydb/public/__init__.py | 0 ydb/public/api/__init__.py | 0 ydb/public/api/grpc/__init__.py | 8 -------- ydb/public/api/protos/__init__.py | 8 -------- ydb/public/stub.txt | 1 + 14 files changed, 2 insertions(+), 59 deletions(-) delete mode 100644 kikimr/__init__.py delete mode 100644 kikimr/public/__init__.py delete mode 100644 kikimr/public/api/__init__.py delete mode 100644 kikimr/public/sdk/__init__.py delete mode 100644 kikimr/public/sdk/python/__init__.py delete mode 100644 kikimr/public/sdk/python/client/__init__.py delete mode 100644 kikimr/public/sdk/python/client/frameworks/__init__.py delete mode 100644 kikimr/public/sdk/python/iam/__init__.py create mode 100644 kikimr/stub.txt delete mode 100644 ydb/public/__init__.py delete mode 100644 ydb/public/api/__init__.py delete mode 100644 ydb/public/api/grpc/__init__.py delete mode 100644 ydb/public/api/protos/__init__.py create mode 100644 ydb/public/stub.txt diff --git a/kikimr/__init__.py b/kikimr/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kikimr/public/__init__.py b/kikimr/public/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kikimr/public/api/__init__.py b/kikimr/public/api/__init__.py deleted file mode 100644 index c97daf69..00000000 --- a/kikimr/public/api/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from ydb.public.api import * # noqa -import sys -import warnings - -sys.modules['kikimr.public.api'] = sys.modules['ydb.public.api'] -warnings.warn("using kikimr.public.api module is deprecated. 
please use ydb.public.api import instead") diff --git a/kikimr/public/sdk/__init__.py b/kikimr/public/sdk/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kikimr/public/sdk/python/__init__.py b/kikimr/public/sdk/python/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kikimr/public/sdk/python/client/__init__.py b/kikimr/public/sdk/python/client/__init__.py deleted file mode 100644 index 157c103e..00000000 --- a/kikimr/public/sdk/python/client/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- -from ydb import * # noqa -import sys -import warnings - -warnings.warn("module kikimr.public.sdk.python.client is deprecated. please use ydb instead") - - -for name, module in sys.modules.copy().items(): - if not name.startswith("ydb"): - continue - - if name.startswith("ydb.public"): - continue - - module_import_path = name.split('.') - if len(module_import_path) < 2: - continue - - sys.modules['kikimr.public.sdk.python.client.' + '.'.join(module_import_path[1:])] = module diff --git a/kikimr/public/sdk/python/client/frameworks/__init__.py b/kikimr/public/sdk/python/client/frameworks/__init__.py deleted file mode 100644 index e5c4940a..00000000 --- a/kikimr/public/sdk/python/client/frameworks/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -try: - from ydb.tornado import * # noqa - import sys - import warnings - - warnings.warn("module kikimr.public.sdk.python.client.frameworks is deprecated. please use ydb.tornado instead") - - sys.modules['kikimr.public.sdk.python.client.frameworks.tornado_helpers'] = sys.modules['ydb.tornado.tornado_helpers'] -except ImportError: - pass diff --git a/kikimr/public/sdk/python/iam/__init__.py b/kikimr/public/sdk/python/iam/__init__.py deleted file mode 100644 index 884d5c1b..00000000 --- a/kikimr/public/sdk/python/iam/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ydb.iam import * # noqa -import sys -import warnings - -warnings.warn("using kikimr.public.sdk.python.iam module is deprecated. please use ydb.iam import instead") - -sys.modules['kikimr.public.sdk.python.iam.auth'] = sys.modules['ydb.iam.auth'] diff --git a/kikimr/stub.txt b/kikimr/stub.txt new file mode 100644 index 00000000..7467d31a --- /dev/null +++ b/kikimr/stub.txt @@ -0,0 +1 @@ +the folder must not use for prevent issues with intersect with old packages. diff --git a/ydb/public/__init__.py b/ydb/public/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ydb/public/api/__init__.py b/ydb/public/api/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ydb/public/api/grpc/__init__.py b/ydb/public/api/grpc/__init__.py deleted file mode 100644 index 08dd1066..00000000 --- a/ydb/public/api/grpc/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from ydb._grpc.common import * # noqa -import sys -import warnings - -sys.modules["ydb.public.api.grpc"] = sys.modules["ydb._grpc.common"] -warnings.warn( - "using ydb.public.api.grpc module is deprecated. Don't use direct grpc dependencies." -) diff --git a/ydb/public/api/protos/__init__.py b/ydb/public/api/protos/__init__.py deleted file mode 100644 index 204ff3b9..00000000 --- a/ydb/public/api/protos/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from ydb._grpc.common.protos import * # noqa -import sys -import warnings - -sys.modules["ydb.public.api.protos"] = sys.modules["ydb._grpc.common.protos"] -warnings.warn( - "using ydb.public.api.protos module is deprecated. Don't use direct grpc dependencies." 
-) diff --git a/ydb/public/stub.txt new file mode 100644 index 00000000..6c4e84e7 --- /dev/null +++ b/ydb/public/stub.txt @@ -0,0 +1 @@ +this folder must not be used; it exists to prevent conflicts with old generated protobuf packages. From d522cb4ec6f57f7d700fc171f9cdac0eb3a77519 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 10 Mar 2023 12:43:29 +0100 Subject: [PATCH 099/147] add special compatibility with arcadia --- ydb/_grpc/common/__init__.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/ydb/_grpc/common/__init__.py b/ydb/_grpc/common/__init__.py index 4a5ef87b..793eb4bf 100644 --- a/ydb/_grpc/common/__init__.py +++ b/ydb/_grpc/common/__init__.py @@ -1,4 +1,5 @@ import sys +import importlib.util import google.protobuf from packaging.version import Version @@ -8,15 +9,24 @@ # sdk code must always import from ydb._grpc.common protobuf_version = Version(google.protobuf.__version__) -if protobuf_version < Version("4.0"): - from ydb._grpc.v3 import * # noqa - from ydb._grpc.v3 import protos # noqa +# for compatibility with arcadia +if importlib.util.find_spec("ydb.public.api"): + from ydb.public.api.grpc import * + sys.modules["ydb._grpc.common"] = sys.modules["ydb.public.api.grpc"] - sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v3"] - sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v3.protos"] + from ydb.public.api import protos + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb.public.api.protos"] else: - from ydb._grpc.v4 import * # noqa - from ydb._grpc.v4 import protos # noqa + # common way, outside of arcadia + if protobuf_version < Version("4.0"): + from ydb._grpc.v3 import * # noqa + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v3"] - sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v4"] - sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v4.protos"] + from ydb._grpc.v3 import protos # noqa + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v3.protos"] + else: + from ydb._grpc.v4 import * # noqa + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v4"] + + from ydb._grpc.v4 import protos # noqa + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v4.protos"] From 8175140855ea008cf4c7d51e8de5051f8df4b64a Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 10 Mar 2023 12:57:35 +0100 Subject: [PATCH 100/147] linter --- ydb/_grpc/common/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ydb/_grpc/common/__init__.py b/ydb/_grpc/common/__init__.py index 793eb4bf..1a077800 100644 --- a/ydb/_grpc/common/__init__.py +++ b/ydb/_grpc/common/__init__.py @@ -11,22 +11,28 @@ # for compatibility with arcadia if importlib.util.find_spec("ydb.public.api"): - from ydb.public.api.grpc import * + from ydb.public.api.grpc import * # noqa + sys.modules["ydb._grpc.common"] = sys.modules["ydb.public.api.grpc"] from ydb.public.api import protos + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb.public.api.protos"] else: # common way, outside of arcadia if protobuf_version < Version("4.0"): from ydb._grpc.v3 import * # noqa + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v3"] from ydb._grpc.v3 import protos # noqa + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v3.protos"] else: from ydb._grpc.v4 import * # noqa + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v4"] from ydb._grpc.v4 import protos # noqa + sys.modules["ydb._grpc.common.protos"] =
sys.modules["ydb._grpc.v4.protos"] From 3c3b68871c2592989e0bbbbe6df88e6b862ccd11 Mon Sep 17 00:00:00 2001 From: robot Date: Fri, 10 Mar 2023 14:35:48 +0000 Subject: [PATCH 101/147] Release: 3.0.1b7 --- CHANGELOG.md | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4cfa110d..6cac0f6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b7 ## * BROKEN CHANGE: deny any action in transaction after commit/rollback ## 3.0.1b6 ## diff --git a/setup.py b/setup.py index e585796e..9d694e18 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="3.0.1b6", # AUTOVERSION + version="3.0.1b7", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", From 1f8b1908e4349b283b23e0f80b1f0c5bee730ab8 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 10 Mar 2023 15:39:56 +0100 Subject: [PATCH 102/147] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cac0f6e..16b2e564 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ ## 3.0.1b7 ## * BROKEN CHANGE: deny any action in transaction after commit/rollback +* BROKEN CHANGE: raise exception for truncated response by default +* Compatible protobuf detection for arcadia +* Add codecs support for topic writer ## 3.0.1b6 ## * BROKEN CHANGES: remove writer.write(mess1, mess2) variant, use list instead: writer.write([mess1, mess2]) From c729d58bda684f7bcba1aba2f4573b15c5d99080 Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Thu, 9 Mar 2023 15:31:46 +0300 Subject: [PATCH 103/147] topic writer: update auth token loop --- tests/topics/test_topic_writer.py | 6 +- ydb/_topic_common/common.py | 1 - ydb/_topic_reader/topic_reader.py | 2 - ydb/_topic_reader/topic_reader_asyncio.py | 5 -- ydb/_topic_writer/topic_writer.py | 2 +- ydb/_topic_writer/topic_writer_asyncio.py | 79 +++++++++++++------ .../topic_writer_asyncio_test.py | 15 +++- ydb/credentials.py | 6 ++ 8 files changed, 76 insertions(+), 40 deletions(-) diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index c53ce0db..68c34a8e 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -94,9 +94,9 @@ async def test_write_multi_message_with_ack( ) async def test_write_encoded(self, driver: ydb.Driver, topic_path: str, codec): async with driver.topic_client.writer(topic_path, codec=codec) as writer: - writer.write("a" * 1000) - writer.write("b" * 1000) - writer.write("c" * 1000) + await writer.write("a" * 1000) + await writer.write("b" * 1000) + await writer.write("c" * 1000) class TestTopicWriterSync: diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index f2d6ca9b..c569daca 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -7,7 +7,6 @@ from ..
import operation, issues from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType -TokenGetterFuncType = typing.Optional[typing.Callable[[], str]] TimeoutType = typing.Union[int, float] diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 4c9e63e1..759bbab7 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -8,7 +8,6 @@ ) from ..table import RetrySettings -from .._topic_common.common import TokenGetterFuncType from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange @@ -28,7 +27,6 @@ class PublicReaderSettings: consumer: str topic: str buffer_size_bytes: int = 50 * 1024 * 1024 - _token_getter: Optional[TokenGetterFuncType] = None # on_commit: Callable[["Events.OnCommit"], None] = None # on_get_partition_start_offset: Callable[ # ["Events.OnPartitionGetStartOffsetRequest"], diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 3e0e362e..c1bb321d 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -12,9 +12,6 @@ from ..issues import Error as YdbError, _process_response from . import datatypes from . import topic_reader -from .._topic_common.common import ( - TokenGetterFuncType, -) from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, @@ -264,7 +261,6 @@ class ReaderStream: _id: int _reader_reconnector_id: int - _token_getter: Optional[TokenGetterFuncType] _session_id: str _stream: Optional[IGrpcWrapperAsyncIO] _started: bool @@ -282,7 +278,6 @@ def __init__( ): self._id = ReaderStream._static_id_counter.inc_and_get() self._reader_reconnector_id = reader_reconnector_id - self._token_getter = settings._token_getter self._session_id = "not initialized" self._stream = None self._started = False diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 92212f65..78349a88 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -44,7 +44,7 @@ class PublicWriterSettings: # codec: Optional[int] = None # codec_autoselect: bool = True # retry_policy: Optional["RetryPolicy"] = None - # update_token_interval: Union[int, float] = 3600 + update_token_interval: Union[int, float] = 3600 def __post_init__(self): if self.producer_id is None: diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5e3bb455..9ee40250 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -26,10 +26,8 @@ RetrySettings, ) from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec -from .._topic_common.common import ( - TokenGetterFuncType, -) from .._grpc.grpcwrapper.ydb_topic import ( + UpdateTokenRequest, UpdateTokenResponse, StreamWriteMessage, WriterMessagesFromServerToClient, @@ -159,8 +157,6 @@ class WriterAsyncIOReconnector: _loop: asyncio.AbstractEventLoop _credentials: Union[ydb.credentials.Credentials, None] _driver: ydb.aio.Driver - _update_token_interval: int - _token_get_function: TokenGetterFuncType _init_message: StreamWriteMessage.InitRequest _init_info: asyncio.Future _stream_connected: asyncio.Event @@ -237,11 +233,8 @@ async def close(self, flush: bool): self._closed = True self._stop(TopicWriterStopped()) - background_tasks = self._background_tasks - - for task in background_tasks: + for task in self._background_tasks: task.cancel() - await asyncio.wait(self._background_tasks) # if work was stopped 
before close by error - raise the error @@ -349,7 +342,10 @@ async def _connection_loop(self): stream_writer = None try: stream_writer = await WriterAsyncIOStream.create( - self._driver, self._init_message, self._get_token + self._driver, + self._init_message, + 3, + # self._settings.update_token_interval, ) try: self._last_known_seq_no = stream_writer.last_seqno @@ -371,12 +367,10 @@ async def _connection_loop(self): self._read_loop(stream_writer), name="writer receive loop" ) - pending = [send_loop, receive_loop] - done, pending = await asyncio.wait( [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED ) - stream_writer.close() + await stream_writer.close() done.pop().result() except issues.Error as err: # todo log error @@ -394,7 +388,7 @@ async def _connection_loop(self): return finally: if stream_writer: - stream_writer.close() + await stream_writer.close() if len(pending) > 0: for task in pending: task.cancel() @@ -561,9 +555,6 @@ def _stop(self, reason: Exception): self._stop_reason.set_result(reason) - def _get_token(self) -> str: - raise NotImplementedError() - async def flush(self): self._check_stop() if not self._messages_future: @@ -575,29 +566,43 @@ async def flush(self): class WriterAsyncIOStream: # todo slots + _closed: bool last_seqno: int supported_codecs: Optional[List[PublicCodec]] _stream: IGrpcWrapperAsyncIO - _token_getter: TokenGetterFuncType _requests: asyncio.Queue _responses: AsyncIterator + _update_token_interval: int + _update_token_task: asyncio.Task + _update_token_event: asyncio.Event + _get_token_function: Callable[[], str] + def __init__( - self, - token_getter: TokenGetterFuncType, + self, update_token_interval: int, get_token_function: Callable[[], str] ): - self._token_getter = token_getter + self._closed = False + + self._update_token_interval = update_token_interval + self._get_token_function = get_token_function + self._update_token_event = asyncio.Event() - def close(self): + async def close(self): + if self._closed: + return + self._closed = True + + self._update_token_task.cancel() + await self._update_token_task self._stream.close() @staticmethod async def create( driver: SupportedDriverType, init_request: StreamWriteMessage.InitRequest, - token_getter: TokenGetterFuncType, + update_token_interval: int, ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) @@ -605,7 +610,11 @@ async def create( driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite ) - writer = WriterAsyncIOStream(token_getter) + creds = driver._credentials + writer = WriterAsyncIOStream( + update_token_interval=update_token_interval, + get_token_function=creds.get_auth_token if creds else lambda: "", + ) await writer._start(stream, init_request) return writer @@ -616,6 +625,7 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: if isinstance(item, StreamWriteMessage.WriteResponse): return item if isinstance(item, UpdateTokenResponse): + self._update_token_event.set() continue # todo log unknown messages instead of raise exception @@ -636,6 +646,11 @@ async def _start( self._stream = stream + self._update_token_event.set() + self._update_token_task = asyncio.create_task( + self._update_token_loop(), name="update_token_loop" + ) + @staticmethod def _ensure_ok(message: WriterMessagesFromServerToClient): if not message.status.is_success(): @@ -644,5 +659,21 @@ def _ensure_ok(message: WriterMessagesFromServerToClient): ) def write(self, messages: List[InternalMessage]): + if self._closed: + raise 
RuntimeError("Can not write on closed stream.") + for request in messages_to_proto_requests(messages): self._stream.write(request) + + async def _update_token_loop(self): + while True: + await asyncio.sleep(self._update_token_interval) + await self._update_token(token=self._get_token_function()) + + async def _update_token(self, token: str): + await self._update_token_event.wait() + try: + msg = StreamWriteMessage.FromClient(UpdateTokenRequest(token)) + self._stream.write(msg) + finally: + self._update_token_event.clear() diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 921c6aa4..7e440f1c 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -15,7 +15,12 @@ from .. import aio from .. import StatusCode, issues -from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage +from .._grpc.grpcwrapper.ydb_topic import ( + Codec, + StreamWriteMessage, + UpdateTokenRequest, + UpdateTokenResponse, +) from .._grpc.grpcwrapper.common_utils import ServerStatus from .topic_writer import ( InternalMessage, @@ -35,11 +40,13 @@ WriterAsyncIO, ) +from ..credentials import AnonymousCredentials + @pytest.fixture def default_driver() -> aio.Driver: driver = mock.Mock(spec=aio.Driver) - driver._credentials = mock.Mock() + driver._credentials = AnonymousCredentials() return driver @@ -66,7 +73,7 @@ async def writer_and_stream(self, stream) -> WriterWithMockedStream: ) ) - writer = WriterAsyncIOStream(None) + writer = WriterAsyncIOStream(1, lambda: "") await writer._start( stream, init_message=StreamWriteMessage.InitRequest( @@ -107,7 +114,7 @@ async def test_init_writer(self, stream): ) ) - writer = WriterAsyncIOStream(None) + writer = WriterAsyncIOStream(1, lambda: "") await writer._start(stream, init_message) sent_message = await stream.from_client.get() diff --git a/ydb/credentials.py b/ydb/credentials.py index 13b45b20..2a2dea3b 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -39,6 +39,12 @@ def auth_metadata(self): """ pass + def get_auth_token(self) -> str: + for header, token in self.auth_metadata(): + if header == YDB_AUTH_TICKET_HEADER: + return token + return "" + class OneToManyValue(object): def __init__(self): From b4a5006de55d974d8de55e7bc4a0ed902b57f268 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 10 Mar 2023 18:26:15 +0100 Subject: [PATCH 104/147] fix leak tasks --- ydb/_topic_writer/topic_writer_asyncio.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 9ee40250..e3fe1e0d 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -336,7 +336,7 @@ async def _connection_loop(self): while True: attempt = 0 # todo calc and reset - pending = [] + tasks = [] # noinspection PyBroadException stream_writer = None @@ -367,7 +367,8 @@ async def _connection_loop(self): self._read_loop(stream_writer), name="writer receive loop" ) - done, pending = await asyncio.wait( + tasks = [send_loop, receive_loop] + done, _ = await asyncio.wait( [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED ) await stream_writer.close() @@ -389,10 +390,9 @@ async def _connection_loop(self): finally: if stream_writer: await stream_writer.close() - if len(pending) > 0: - for task in pending: - task.cancel() - await asyncio.wait(pending) + for task in tasks: + task.cancel() + await 
asyncio.wait(tasks) async def _encode_loop(self): while True: From 8a32f24e3989ca110c550e12d7f9bd0753dd6280 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 14 Mar 2023 09:00:04 +0100 Subject: [PATCH 105/147] fix check truncated setting with None table_client_settings --- CHANGELOG.md | 2 ++ ydb/convert.py | 11 +++++++---- ydb/table.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16b2e564..d818dad5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Fixed exception while creating ResultSet with None table_settings + ## 3.0.1b7 ## * BROKEN CHANGE: deny any action in transaction after commit/rollback * BROKEN CHANGE: raise exception for truncated response by default diff --git a/ydb/convert.py b/ydb/convert.py index 567900a1..d09f8e79 100644 --- a/ydb/convert.py +++ b/ydb/convert.py @@ -12,6 +12,7 @@ _DecimalInfRepr = 10**35 _DecimalSignedInfRepr = -(10**35) _primitive_type_by_id = {} +_default_allow_truncated_result = False def _initialize(): @@ -484,16 +485,18 @@ def __init__(self, result_sets_pb, table_client_settings=None): if table_client_settings is None else table_client_settings._make_result_sets_lazy ) + + allow_truncated_result = _default_allow_truncated_result + if table_client_settings: + allow_truncated_result = table_client_settings._allow_truncated_result + result_sets = [] initializer = ( _ResultSet.from_message if not make_lazy else _ResultSet.lazy_from_message ) for result_set in result_sets_pb: result_set = initializer(result_set, table_client_settings) - if ( - result_set.truncated - and not table_client_settings._allow_truncated_result - ): + if result_set.truncated and not allow_truncated_result: raise issues.TruncatedResponseError( "Response for the request was truncated by server" ) diff --git a/ydb/table.py b/ydb/table.py index 40431c62..fcb6de5a 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -1002,7 +1002,7 @@ def __init__(self): self._native_json_in_result_sets = False self._native_interval_in_result_sets = False self._native_timestamp_in_result_sets = False - self._allow_truncated_result = False + self._allow_truncated_result = convert._default_allow_truncated_result def with_native_timestamp_in_result_sets(self, enabled): # type:(bool) -> ydb.TableClientSettings From 5f1335db8812a8f4e5a27a72841fe68dcaac4d57 Mon Sep 17 00:00:00 2001 From: robot Date: Tue, 14 Mar 2023 08:16:56 +0000 Subject: [PATCH 106/147] Release: 3.0.1b8 --- CHANGELOG.md | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d818dad5..006c8496 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b8 ## * Fixed exception while creating ResultSet with None table_settings ## 3.0.1b7 ## diff --git a/setup.py b/setup.py index 9d694e18..008c407d 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="3.0.1b7", # AUTOVERSION + version="3.0.1b8", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", From 02889a68472d3f004e62c4391f3cc1e096872083 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 14 Mar 2023 13:02:21 +0100 Subject: [PATCH 107/147] change internal deny split to positive allow split (with false by default) - for readability.
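A usage sketch of the renamed flag and the global override added below (session setup omitted; names illustrative):

    ydb.global_allow_split_transactions(True)  # process-wide opt-in to the old behaviour

    # or per transaction:
    with session.transaction(allow_split_transactions=True) as tx:
        tx.execute("UPSERT ...")
        tx.commit()
        tx.execute("SELECT ...")  # allowed after commit only because of the flag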
Add global settings function --- CHANGELOG.md | 2 ++ ydb/__init__.py | 1 + ydb/aio/table.py | 7 +++++-- ydb/global_settings.py | 10 ++++++++++ ydb/table.py | 20 +++++++++++++------- 5 files changed, 31 insertions(+), 9 deletions(-) create mode 100644 ydb/global_settings.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 006c8496..ea3f9ab8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Add function for global revert broken behaviour: ydb.global_allow_truncated_result, ydb.global_allow_split_transactions + ## 3.0.1b8 ## * Fixed exception while create ResultSet with None table_settings diff --git a/ydb/__init__.py b/ydb/__init__.py index 6607e1a4..f305fdbc 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -1,5 +1,6 @@ from .credentials import * # noqa from .driver import * # noqa +from .global_settings import * # noqa from .table import * # noqa from .issues import * # noqa from .types import * # noqa diff --git a/ydb/aio/table.py b/ydb/aio/table.py index 95e2723d..f937a928 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -13,6 +13,7 @@ _scan_query_request_factory, _wrap_scan_query_response, BaseTxContext, + _allow_split_transaction, ) from . import _utilities from ydb import _apis, _session_impl @@ -120,13 +121,15 @@ async def alter_table( set_read_replicas_settings, ) - def transaction(self, tx_mode=None, *, deny_split_transactions=True): + def transaction( + self, tx_mode=None, *, allow_split_transactions=_allow_split_transaction + ): return TxContext( self._driver, self._state, self, tx_mode, - deny_split_transactions=deny_split_transactions, + allow_split_transactions=allow_split_transactions, ) async def describe_table(self, path, settings=None): # pylint: disable=W0236 diff --git a/ydb/global_settings.py b/ydb/global_settings.py new file mode 100644 index 00000000..879690f2 --- /dev/null +++ b/ydb/global_settings.py @@ -0,0 +1,10 @@ +from . import convert +from . 
import table + + +def global_allow_truncated_result(enabled: bool = True): + convert._default_allow_truncated_result = enabled + + +def global_allow_split_transactions(enabled: bool): + table._allow_split_transaction = enabled diff --git a/ydb/table.py index fcb6de5a..737afd76 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -27,6 +27,8 @@ except ImportError: interceptor = None +_allow_split_transaction = False + logger = logging.getLogger(__name__) ################################################################## @@ -1179,7 +1181,9 @@ def execute_scheme(self, yql_text, settings=None): pass @abstractmethod - def transaction(self, tx_mode=None, deny_split_transactions=True): + def transaction( + self, tx_mode=None, allow_split_transactions=_allow_split_transaction + ): pass @abstractmethod @@ -1683,13 +1687,15 @@ def execute_scheme(self, yql_text, settings=None): self._state.endpoint, ) - def transaction(self, tx_mode=None, deny_split_transactions=True): + def transaction( + self, tx_mode=None, allow_split_transactions=_allow_split_transaction + ): return TxContext( self._driver, self._state, self, tx_mode, - deny_split_transactions=deny_split_transactions, + allow_split_transactions=allow_split_transactions, ) def has_prepared(self, query): @@ -2207,7 +2213,7 @@ class BaseTxContext(ITxContext): "_driver", "session", "_finished", - "_deny_split_transactions", + "_allow_split_transactions", ) _COMMIT = "commit" @@ -2220,7 +2226,7 @@ def __init__( session, tx_mode=None, *, - deny_split_transactions=True + allow_split_transactions=_allow_split_transaction ): """ An object that provides a simple transaction context manager that allows statements execution @@ -2245,7 +2251,7 @@ def __init__( self._session_state = session_state self.session = session self._finished = "" - self._deny_split_transactions = deny_split_transactions + self._allow_split_transactions = allow_split_transactions def __enter__(self): """ @@ -2405,7 +2411,7 @@ def _check_split(self, allow=""): Deny all operations with transaction after commit/rollback. Exception: double commit and double rollbacks, because it is safe """ - if not self._deny_split_transactions: + if self._allow_split_transactions: return if self._finished != "" and self._finished != allow: From 1e623b3cb1f4ccbce1d7df09f050bfcb49d83f62 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 14 Mar 2023 13:28:29 +0100 Subject: [PATCH 108/147] add warnings --- ydb/global_settings.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ydb/global_settings.py b/ydb/global_settings.py index 879690f2..3b0368ad 100644 --- a/ydb/global_settings.py +++ b/ydb/global_settings.py @@ -1,10 +1,26 @@ +import warnings + from . import convert from . import table def global_allow_truncated_result(enabled: bool = True): + if enabled: + warnings.warn("Global allow truncated response is deprecated behaviour.") + else: + warnings.warn( + "Global deny truncated response is default behaviour. You don't need to call this function." + ) + convert._default_allow_truncated_result = enabled def global_allow_split_transactions(enabled: bool): + if enabled: + warnings.warn("Global allow split transactions is deprecated behaviour.") + else: + warnings.warn( + "Global deny split transactions is default behaviour. You don't need to call this function."
+ ) + table._allow_split_transaction = enabled From 592bb442eae3bc1c4ec2d5c756c340f0df5e5c44 Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Mon, 13 Mar 2023 16:11:03 +0300 Subject: [PATCH 109/147] topic writer: add tests --- tests/topics/test_topic_reader.py | 2 + ydb/_topic_writer/topic_writer_asyncio.py | 31 ++++++----- .../topic_writer_asyncio_test.py | 53 ++++++++++++++++--- 3 files changed, 65 insertions(+), 21 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index a874c743..214a7620 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -23,6 +23,8 @@ async def test_read_and_commit_message( batch2 = await reader.receive_batch() assert batch.messages[0] != batch2.messages[0] + await reader.close() + class TestTopicReaderSync: def test_read_message( diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index e3fe1e0d..7cb1f1db 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -344,8 +344,7 @@ async def _connection_loop(self): stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, - 3, - # self._settings.update_token_interval, + self._settings.update_token_interval, ) try: self._last_known_seq_no = stream_writer.last_seqno @@ -575,34 +574,39 @@ class WriterAsyncIOStream: _requests: asyncio.Queue _responses: AsyncIterator - _update_token_interval: int - _update_token_task: asyncio.Task + _update_token_interval: Optional[Union[int, float]] + _update_token_task: Optional[asyncio.Task] _update_token_event: asyncio.Event - _get_token_function: Callable[[], str] + _get_token_function: Optional[Callable[[], str]] def __init__( - self, update_token_interval: int, get_token_function: Callable[[], str] + self, + update_token_interval: Optional[Union[int, float]] = None, + get_token_function: Optional[Callable[[], str]] = None, ): self._closed = False self._update_token_interval = update_token_interval self._get_token_function = get_token_function self._update_token_event = asyncio.Event() + self._update_token_task = None async def close(self): if self._closed: return self._closed = True - self._update_token_task.cancel() - await asyncio.wait([self._update_token_task]) + if self._update_token_task: + self._update_token_task.cancel() + await asyncio.wait([self._update_token_task]) + self._stream.close() @staticmethod async def create( driver: SupportedDriverType, init_request: StreamWriteMessage.InitRequest, - update_token_interval: int, + update_token_interval: Optional[Union[int, float]] = None, ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) @@ -646,10 +650,11 @@ async def _start( self._stream = stream - self._update_token_event.set() - self._update_token_task = asyncio.create_task( - self._update_token_loop(), name="update_token_loop" - ) + if self._update_token_interval is not None: + self._update_token_event.set() + self._update_token_task = asyncio.create_task( + self._update_token_loop(), name="update_token_loop" + ) @staticmethod def _ensure_ok(message: WriterMessagesFromServerToClient): diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 7e440f1c..73c959f9 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -7,7 +7,7 @@ import gzip import typing from queue import Queue, Empty -from typing import List 
+from typing import List, Callable, Optional from unittest import mock import freezegun @@ -59,10 +59,12 @@ class WriterWithMockedStream: @pytest.fixture def stream(self): - return StreamMock() + stream = StreamMock() + yield stream + stream.close() @pytest.fixture - async def writer_and_stream(self, stream) -> WriterWithMockedStream: + async def writer_and_stream(self, stream, request) -> WriterWithMockedStream: stream.from_server.put_nowait( StreamWriteMessage.InitResponse( last_seq_no=4, @@ -73,7 +75,9 @@ async def writer_and_stream(self, stream) -> WriterWithMockedStream: ) ) - writer = WriterAsyncIOStream(1, lambda: "") + params = getattr(request, "param", ()) + writer = WriterAsyncIOStream(*params) + await writer._start( stream, init_message=StreamWriteMessage.InitRequest( @@ -88,11 +92,13 @@ async def writer_and_stream(self, stream) -> WriterWithMockedStream: ) await stream.from_client.get() - return TestWriterAsyncIOStream.WriterWithMockedStream( + yield TestWriterAsyncIOStream.WriterWithMockedStream( stream=stream, writer=writer, ) + await writer.close() + async def test_init_writer(self, stream): init_seqno = 4 init_message = StreamWriteMessage.InitRequest( @@ -114,7 +120,7 @@ async def test_init_writer(self, stream): ) ) - writer = WriterAsyncIOStream(1, lambda: "") + writer = WriterAsyncIOStream() await writer._start(stream, init_message) sent_message = await stream.from_client.get() @@ -123,6 +129,8 @@ async def test_init_writer(self, stream): assert expected_message == sent_message assert writer.last_seqno == init_seqno + await writer.close() + async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): data = "123".encode() now = datetime.datetime.now() @@ -156,6 +164,30 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): sent_message = await writer_and_stream.stream.from_client.get() assert expected_message == sent_message + @pytest.mark.parametrize( + "writer_and_stream", [(0.1, lambda: "foo-bar")], indirect=True + ) + async def test_update_token(self, writer_and_stream: WriterWithMockedStream): + assert writer_and_stream.stream.from_client.empty() + + expected = StreamWriteMessage.FromClient(UpdateTokenRequest(token="foo-bar")) + got = await wait_for_fast(writer_and_stream.stream.from_client.get()) + assert expected == got, "send update token request" + + await asyncio.sleep(0.2) + assert ( + writer_and_stream.stream.from_client.empty() + ), "no answer - no new update request" + + await writer_and_stream.stream.from_server.put(UpdateTokenResponse()) + receive_task = asyncio.create_task(writer_and_stream.writer.receive()) + + got = await wait_for_fast(writer_and_stream.stream.from_client.get()) + assert expected == got + + receive_task.cancel() + await asyncio.wait([receive_task]) + @pytest.mark.asyncio class TestWriterAsyncIOReconnector: @@ -171,7 +203,11 @@ class StreamWriterMock: _closed: bool - def __init__(self): + def __init__( + self, + update_token_interval: Optional[int, float] = None, + get_token_function: Optional[Callable[[], str]] = None, + ): self.last_seqno = 0 self.from_server = asyncio.Queue() self.from_client = asyncio.Queue() @@ -193,7 +229,7 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: raise item return item - def close(self): + async def close(self): if self._closed: return self._closed = True @@ -251,6 +287,7 @@ def default_settings(self) -> WriterSettings: auto_seqno=False, auto_created_at=False, codec=PublicCodec.RAW, + update_token_interval=3600, ) ) From 
39615f69311b4b51c7d1f1d2eff4ae884942597d Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 14 Mar 2023 23:35:45 +0300 Subject: [PATCH 110/147] fix args in tests --- tests/aio/test_tx.py | 4 ++-- tests/table/test_tx.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/aio/test_tx.py b/tests/aio/test_tx.py index da66769e..fad6c295 100644 --- a/tests/aio/test_tx.py +++ b/tests/aio/test_tx.py @@ -102,7 +102,7 @@ async def test_split_transactions_deny_split(driver, table_name): async with ydb.aio.SessionPool(driver, 1) as pool: async def check_transaction(s: ydb.aio.table.Session): - async with s.transaction(deny_split_transactions=True) as tx: + async with s.transaction(allow_split_transactions=False) as tx: await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) await tx.commit() @@ -123,7 +123,7 @@ async def test_split_transactions_allow_split(driver, table_name): async with ydb.aio.SessionPool(driver, 1) as pool: async def check_transaction(s: ydb.aio.table.Session): - async with s.transaction(deny_split_transactions=False) as tx: + async with s.transaction(allow_split_transactions=True) as tx: await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) await tx.commit() diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py index bd703fa8..a6ee1d9e 100644 --- a/tests/table/test_tx.py +++ b/tests/table/test_tx.py @@ -96,7 +96,7 @@ def test_split_transactions_deny_split(driver_sync, table_name): with ydb.SessionPool(driver_sync, 1) as pool: def check_transaction(s: ydb.table.Session): - with s.transaction(deny_split_transactions=True) as tx: + with s.transaction(allow_split_transactions=False) as tx: tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) tx.commit() @@ -116,7 +116,7 @@ def test_split_transactions_allow_split(driver_sync, table_name): with ydb.SessionPool(driver_sync, 1) as pool: def check_transaction(s: ydb.table.Session): - with s.transaction(deny_split_transactions=False) as tx: + with s.transaction(allow_split_transactions=True) as tx: tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) tx.commit() From 0b3dc066a9ae65500ca3d4db20bc742f146f3a8b Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Mar 2023 00:42:22 +0300 Subject: [PATCH 111/147] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea3f9ab8..fabad68b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ * Add function for global revert broken behaviour: ydb.global_allow_truncated_result, ydb.global_allow_split_transactions +* Change argument names from deny_split_transactions to allow_split_transactions (with reverse value) ## 3.0.1b8 ## * Fixed exception while creating ResultSet with None table_settings From 584e3931905683aabc16216eee8048effcb82ab7 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Mar 2023 18:44:08 +0300 Subject: [PATCH 112/147] add support reader codecs (#214) * add support reader codecs * fix writer codec bug --- tests/topics/test_topic_reader.py | 62 +++++++ ydb/_topic_reader/datatypes.py | 3 +- ydb/_topic_reader/topic_reader.py | 11 +- ydb/_topic_reader/topic_reader_asyncio.py | 73 +++++++- .../topic_reader_asyncio_test.py | 164 +++++++++++++++++- ydb/_topic_reader/topic_reader_sync.py | 6 + ydb/_topic_writer/topic_writer.py | 4 +- ydb/topic.py | 22 ++- 8 files changed, 331 insertions(+), 14 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 214a7620..2a451baf 100644 ---
a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -1,5 +1,7 @@ import pytest +import ydb + @pytest.mark.asyncio class TestTopicReaderAsyncIO: @@ -25,6 +27,36 @@ async def test_read_and_commit_message( await reader.close() + async def test_read_compressed_messages(self, driver, topic_path, topic_consumer): + async with driver.topic_client.writer( + topic_path, codec=ydb.TopicCodec.GZIP + ) as writer: + await writer.write("123") + + async with driver.topic_client.reader(topic_consumer, topic_path) as reader: + batch = await reader.receive_batch() + assert batch.messages[0].data.decode() == "123" + + async def test_read_custom_encoded(self, driver, topic_path, topic_consumer): + codec = 10001 + + def encode(b: bytes): + return bytes(reversed(b)) + + def decode(b: bytes): + return bytes(reversed(b)) + + async with driver.topic_client.writer( + topic_path, codec=codec, encoders={codec: encode} + ) as writer: + await writer.write("123") + + async with driver.topic_client.reader( + topic_consumer, topic_path, decoders={codec: decode} + ) as reader: + batch = await reader.receive_batch() + assert batch.messages[0].data.decode() == "123" + class TestTopicReaderSync: def test_read_message( @@ -45,3 +77,33 @@ def test_read_and_commit_message( reader = driver_sync.topic_client.reader(topic_consumer, topic_path) batch2 = reader.receive_batch() assert batch.messages[0] != batch2.messages[0] + + def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer): + with driver_sync.topic_client.writer( + topic_path, codec=ydb.TopicCodec.GZIP + ) as writer: + writer.write("123") + + with driver_sync.topic_client.reader(topic_consumer, topic_path) as reader: + batch = reader.receive_batch() + assert batch.messages[0].data.decode() == "123" + + def test_read_custom_encoded(self, driver_sync, topic_path, topic_consumer): + codec = 10001 + + def encode(b: bytes): + return bytes(reversed(b)) + + def decode(b: bytes): + return bytes(reversed(b)) + + with driver_sync.topic_client.writer( + topic_path, codec=codec, encoders={codec: encode} + ) as writer: + writer.write("123") + + with driver_sync.topic_client.reader( + topic_consumer, topic_path, decoders={codec: decode} + ) as reader: + batch = reader.receive_batch() + assert batch.messages[0].data.decode() == "123" diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 6ca7681c..3845995f 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -9,7 +9,7 @@ import datetime from typing import Mapping, Union, Any, List, Dict, Deque, Optional -from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange, Codec from ydb._topic_reader import topic_reader_asyncio @@ -168,6 +168,7 @@ class PublicBatch(ICommittable, ISessionAlive): messages: List[PublicMessage] _partition_session: PartitionSession _bytes_size: int + _codec: Codec def _commit_get_partition_session(self) -> PartitionSession: return self.messages[0]._commit_get_partition_session() diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 759bbab7..14474be6 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -1,3 +1,4 @@ +import concurrent.futures import enum import datetime from dataclasses import dataclass @@ -5,6 +6,8 @@ Union, Optional, List, + Mapping, + Callable, ) from ..table import RetrySettings @@ -27,6 +30,13 @@ class PublicReaderSettings: consumer: str topic: str 
buffer_size_bytes: int = 50 * 1024 * 1024 + + decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None + """decoders: map[codec_code] func(encoded_bytes)->decoded_bytes""" + + # decoder_executor, must be set for handle non raw messages + decoder_executor: Optional[concurrent.futures.Executor] = None + # on_commit: Callable[["Events.OnCommit"], None] = None # on_get_partition_start_offset: Callable[ # ["Events.OnPartitionGetStartOffsetRequest"], @@ -35,7 +45,6 @@ class PublicReaderSettings: # on_partition_session_start: Callable[["StubEvent"], None] = None # on_partition_session_stop: Callable[["StubEvent"], None] = None # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? - # decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None # deserializer: Union[Callable[[bytes], Any], None] = None # one_attempt_connection_timeout: Union[float, None] = 1 # connection_timeout: Union[float, None] = None diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index c1bb321d..5bf11cdd 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +import concurrent.futures +import gzip import typing from asyncio import Task from collections import deque @@ -17,7 +19,7 @@ SupportedDriverType, GrpcWrapperAsyncIO, ) -from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage +from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec from .._errors import check_retriable_error @@ -25,6 +27,10 @@ class TopicReaderError(YdbError): pass +class TopicReaderUnexpectedCodec(YdbError): + pass + + class TopicReaderCommitToExpiredPartition(TopicReaderError): """ Commit message when partition read session are dropped. 
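For context, a usage sketch of the decoders support introduced by this patch, mirroring the tests above (driver and topic names are illustrative):

    def my_decoder(data: bytes) -> bytes:
        return bytes(reversed(data))

    reader = driver.topic_client.reader(
        "my-consumer", "/database/topic", decoders={10001: my_decoder}
    )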
@@ -57,10 +63,10 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._reconnector = ReaderReconnector(driver, settings) async def __aenter__(self): - raise NotImplementedError() + return self async def __aexit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError() + await self.close() def __del__(self): if not self._closed: @@ -259,6 +265,7 @@ def _set_first_error(self, err: issues.Error): class ReaderStream: _static_id_counter = AtomicCounter() + _loop: asyncio.AbstractEventLoop _id: int _reader_reconnector_id: int _session_id: str @@ -267,6 +274,15 @@ class ReaderStream: _background_tasks: Set[asyncio.Task] _partition_sessions: Dict[int, datatypes.PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only + _decode_executor: concurrent.futures.Executor + _decoders: Dict[ + int, typing.Callable[[bytes], bytes] + ] # dict[codec_code] func(encoded_bytes)->decoded_bytes + + if typing.TYPE_CHECKING: + _batches_to_decode: asyncio.Queue[datatypes.PublicBatch] + else: + _batches_to_decode: asyncio.Queue _state_changed: asyncio.Event _closed: bool @@ -276,6 +292,7 @@ class ReaderStream: def __init__( self, reader_reconnector_id: int, settings: topic_reader.PublicReaderSettings ): + self._loop = asyncio.get_running_loop() self._id = ReaderStream._static_id_counter.inc_and_get() self._reader_reconnector_id = reader_reconnector_id self._session_id = "not initialized" @@ -284,10 +301,16 @@ def __init__( self._background_tasks = set() self._partition_sessions = dict() self._buffer_size_bytes = settings.buffer_size_bytes + self._decode_executor = settings.decoder_executor + + self._decoders = {Codec.CODEC_GZIP: gzip.decompress} + if settings.decoders: + self._decoders.update(settings.decoders) self._state_changed = asyncio.Event() self._closed = False self._first_error = asyncio.get_running_loop().create_future() + self._batches_to_decode = asyncio.Queue() self._message_batches = deque() @staticmethod @@ -324,8 +347,10 @@ async def _start( "Unexpected message after InitRequest: %s", init_response ) - read_messages_task = asyncio.create_task(self._read_messages_loop(stream)) - self._background_tasks.add(read_messages_task) + self._background_tasks.add( + asyncio.create_task(self._read_messages_loop(stream)) + ) + self._background_tasks.add(asyncio.create_task(self._decode_batches_loop())) async def wait_error(self): raise await self._first_error @@ -486,10 +511,12 @@ def _on_partition_session_stop( ) def _on_read_response(self, message: StreamReadMessage.ReadResponse): - batches = self._read_response_to_batches(message) - self._message_batches.extend(batches) self._buffer_consume_bytes(message.bytes_size) + batches = self._read_response_to_batches(message) + for batch in batches: + self._batches_to_decode.put_nowait(batch) + def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): for partition_offset in message.partitions_committed_offsets: session = self._partition_sessions.get( @@ -561,12 +588,44 @@ def _read_response_to_batches( messages=messages, _partition_session=partition_session, _bytes_size=bytes_per_batch, + _codec=Codec(server_batch.codec), ) batches.append(batch) batches[-1]._bytes_size += additional_bytes_to_last_batch return batches + async def _decode_batches_loop(self): + while True: + batch = await self._batches_to_decode.get() + await self._decode_batch_inplace(batch) + self._message_batches.append(batch) + self._state_changed.set() + + async def _decode_batch_inplace(self, 
batch): + if batch._codec == Codec.CODEC_RAW: + return + + try: + decode_func = self._decoders[batch._codec] + except KeyError: + raise TopicReaderUnexpectedCodec( + "Receive message with unexpected codec: %s" % batch._codec + ) + + decode_data_futures = [] + for message in batch.messages: + future = self._loop.run_in_executor( + self._decode_executor, decode_func, message.data + ) + decode_data_futures.append(future) + + decoded_data = await asyncio.gather(*decode_data_futures) + for index, message in enumerate(batch.messages): + message.data = decoded_data[index] + + batch._codec = Codec.CODEC_RAW + def _set_first_error(self, err: YdbError): try: self._first_error.set_result(err) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index e4609ea0..917fff21 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1,5 +1,8 @@ import asyncio +import concurrent.futures +import copy import datetime +import gzip import typing from dataclasses import dataclass from unittest import mock @@ -36,10 +39,20 @@ def handler(loop, context): @pytest.fixture() -def default_reader_settings(): +def default_executor(): + executor = concurrent.futures.ThreadPoolExecutor( + max_workers=2, thread_name_prefix="decoder_executor" + ) + yield executor + executor.shutdown() + + +@pytest.fixture() +def default_reader_settings(default_executor): return PublicReaderSettings( consumer="test-consumer", topic="test-topic", + decoder_executor=default_executor, ) @@ -358,6 +371,149 @@ async def test_commit_ranges_for_received_messages( received = stream_reader_started.receive_batch_nowait().messages assert received == [m2] + # noinspection PyTypeChecker + @pytest.mark.parametrize( + "batch,data_out", + [ + ( + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123")], + ), + ( + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123")], + ), + ( + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"456", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123"), 
bytes(rb"456")], + ), + ( + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"456"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123"), bytes(rb"456")], + ), + ], + ) + async def test_decode_loop( + self, stream_reader, batch: PublicBatch, data_out: typing.List[bytes] + ): + assert len(batch.messages) == len(data_out) + + expected = copy.deepcopy(batch) + expected._codec = Codec.CODEC_RAW + + for index, data in enumerate(data_out): + expected.messages[index].data = data + + await wait_for_fast(stream_reader._decode_batch_inplace(batch)) + + assert batch == expected + async def test_error_from_status_code( self, stream, stream_reader_finish_with_error ): @@ -620,6 +776,7 @@ def reader_batch_count(): ], _partition_session=partition_session, _bytes_size=bytes_size, + _codec=Codec.CODEC_RAW, ) async def test_read_batches( @@ -743,6 +900,7 @@ async def test_read_batches( ], _partition_session=partition_session, _bytes_size=1, + _codec=Codec.CODEC_RAW, ) assert last1 == PublicBatch( session_metadata=session_meta, @@ -763,6 +921,7 @@ async def test_read_batches( ], _partition_session=second_partition_session, _bytes_size=1, + _codec=Codec.CODEC_RAW, ) assert last2 == PublicBatch( session_metadata=session_meta2, @@ -796,6 +955,7 @@ async def test_read_batches( ], _partition_session=second_partition_session, _bytes_size=1, + _codec=Codec.CODEC_RAW, ) async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): @@ -815,6 +975,7 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi messages=[mess1], _partition_session=mess1._partition_session, _bytes_size=self.default_batch_size, + _codec=Codec.CODEC_RAW, ) received = stream_reader.receive_batch_nowait() @@ -823,6 +984,7 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi messages=[mess2], _partition_session=mess2._partition_session, _bytes_size=self.default_batch_size, + _codec=Codec.CODEC_RAW, ) assert ( diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index 9652cb84..ec243337 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -46,6 +46,12 @@ async def create_reader(): def __del__(self): self.close() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + def _call(self, coro) -> concurrent.futures.Future: """ Call async function and return future fow wait result diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 78349a88..dab0371f 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -9,7 +9,7 @@ import typing import ydb.aio -from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage +from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage from 
.._grpc.grpcwrapper.common_utils import IToProto from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec @@ -208,7 +208,7 @@ def messages_to_proto_requests( req = StreamWriteMessage.FromClient( StreamWriteMessage.WriteRequest( messages=[msg.to_message_data()], - codec=Codec.CODEC_RAW, + codec=msg.codec, ) ) res.append(req) diff --git a/ydb/topic.py b/ydb/topic.py index 3ccdda08..efe62219 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -140,6 +140,11 @@ def reader( consumer: str, topic: str, buffer_size_bytes: int = 50 * 1024 * 1024, + # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes + decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + decoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool # on_commit: Callable[["Events.OnCommit"], None] = None # on_get_partition_start_offset: Callable[ # ["Events.OnPartitionGetStartOffsetRequest"], @@ -148,15 +153,20 @@ def reader( # on_partition_session_start: Callable[["StubEvent"], None] = None # on_partition_session_stop: Callable[["StubEvent"], None] = None # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? - # decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None # deserializer: Union[Callable[[bytes], Any], None] = None # one_attempt_connection_timeout: Union[float, None] = 1 # connection_timeout: Union[float, None] = None # retry_policy: Union["RetryPolicy", None] = None ) -> TopicReaderAsyncIO: + + if not decoder_executor: + decoder_executor = self._executor + args = locals() del args["self"] + settings = TopicReaderSettings(**args) + return TopicReaderAsyncIO(self._driver, settings) def writer( @@ -299,6 +309,11 @@ def reader( consumer: str, topic: str, buffer_size_bytes: int = 50 * 1024 * 1024, + # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes + decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, + decoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool # on_commit: Callable[["Events.OnCommit"], None] = None # on_get_partition_start_offset: Callable[ # ["Events.OnPartitionGetStartOffsetRequest"], @@ -307,17 +322,20 @@ def reader( # on_partition_session_start: Callable[["StubEvent"], None] = None # on_partition_session_stop: Callable[["StubEvent"], None] = None # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? 
- # decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None # deserializer: Union[Callable[[bytes], Any], None] = None # one_attempt_connection_timeout: Union[float, None] = 1 # connection_timeout: Union[float, None] = None # retry_policy: Union["RetryPolicy", None] = None ) -> TopicReader: + if not decoder_executor: + decoder_executor = self._executor + args = locals() del args["self"] self._check_closed() settings = TopicReaderSettings(**args) + return TopicReader(self._driver, settings) def writer( From b3cd3b9908b3825ce192063685e945564f6ab46f Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Mon, 13 Mar 2023 16:11:03 +0300 Subject: [PATCH 113/147] topic writer: add tests --- ydb/_topic_writer/topic_writer_asyncio_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 73c959f9..145fd0ce 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -259,7 +259,7 @@ def get_second(self): return self.get_second() def _create(self): - writer = TestWriterAsyncIOReconnector.StreamWriterMock() + writer = TestWriterAsyncIOReconnector.StreamWriterMock(1, lambda: "") writer.last_seqno = TestWriterAsyncIOReconnector.init_last_seqno self._first.put_nowait(writer) self._second.put_nowait(writer) From 805d91ffb873dbb79df0c037a095e8b18fd727a4 Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Mon, 13 Mar 2023 20:23:12 +0300 Subject: [PATCH 114/147] topic-reader: update auth-token loop --- ydb/_grpc/grpcwrapper/ydb_topic.py | 9 ++ ydb/_topic_reader/topic_reader.py | 1 + ydb/_topic_reader/topic_reader_asyncio.py | 119 ++++++++++++------ .../topic_reader_asyncio_test.py | 54 +++++++- .../topic_writer_asyncio_test.py | 2 +- 5 files changed, 139 insertions(+), 46 deletions(-) diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index ad8a8e72..4784d486 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -686,6 +686,8 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: res.commit_offset_request.CopyFrom(self.client_message.to_proto()) elif isinstance(self.client_message, StreamReadMessage.InitRequest): res.init_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, UpdateTokenRequest): + res.update_token_request.CopyFrom(self.client_message.to_proto()) elif isinstance( self.client_message, StreamReadMessage.StartPartitionSessionResponse ): @@ -737,6 +739,13 @@ def from_proto( msg.start_partition_session_request ), ) + elif mess_type == "update_token_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=UpdateTokenResponse.from_proto( + msg.update_token_response + ), + ) # todo replace exception to log raise NotImplementedError() diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 14474be6..148d63b3 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -49,6 +49,7 @@ class PublicReaderSettings: # one_attempt_connection_timeout: Union[float, None] = 1 # connection_timeout: Union[float, None] = None # retry_policy: Union["RetryPolicy", None] = None + update_token_interval: Union[int, float] = 3600 def _init_message(self) -> StreamReadMessage.InitRequest: return StreamReadMessage.InitRequest( diff --git a/ydb/_topic_reader/topic_reader_asyncio.py 
b/ydb/_topic_reader/topic_reader_asyncio.py index 5bf11cdd..bb87d3cc 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -6,9 +6,9 @@ import typing from asyncio import Task from collections import deque -from typing import Optional, Set, Dict +from typing import Optional, Set, Dict, Union, Callable -from .. import _apis, issues, RetrySettings +from .. import _apis, issues from .._utilities import AtomicCounter from ..aio import Driver from ..issues import Error as YdbError, _process_response @@ -19,7 +19,12 @@ SupportedDriverType, GrpcWrapperAsyncIO, ) -from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec +from .._grpc.grpcwrapper.ydb_topic import ( + StreamReadMessage, + UpdateTokenRequest, + UpdateTokenResponse, + Codec, +) from .._errors import check_retriable_error @@ -194,7 +199,6 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._settings = settings self._driver = driver self._background_tasks = set() - self._retry_settins = RetrySettings(idempotent=True) # get from settings self._state_changed = asyncio.Event() self._stream_reader = None @@ -227,7 +231,7 @@ async def wait_message(self): if self._first_error.done(): raise self._first_error.result() - if self._stream_reader is not None: + if self._stream_reader: try: await self._stream_reader.wait_messages() return @@ -289,8 +293,15 @@ class ReaderStream: _message_batches: typing.Deque[datatypes.PublicBatch] _first_error: asyncio.Future[YdbError] + _update_token_interval: Union[int, float] + _update_token_event: asyncio.Event + _get_token_function: Callable[[], str] + def __init__( - self, reader_reconnector_id: int, settings: topic_reader.PublicReaderSettings + self, + reader_reconnector_id: int, + settings: topic_reader.PublicReaderSettings, + get_token_function: Optional[Callable[[], str]] = None, ): self._loop = asyncio.get_running_loop() self._id = ReaderStream._static_id_counter.inc_and_get() @@ -313,6 +324,10 @@ def __init__( self._batches_to_decode = asyncio.Queue() self._message_batches = deque() + self._update_token_interval = settings.update_token_interval + self._get_token_function = get_token_function + self._update_token_event = asyncio.Event() + @staticmethod async def create( reader_reconnector_id: int, @@ -325,7 +340,12 @@ async def create( driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead ) - reader = ReaderStream(reader_reconnector_id, settings) + creds = driver._credentials + reader = ReaderStream( + reader_reconnector_id, + settings, + get_token_function=creds.get_auth_token if creds else None, + ) await reader._start(stream, settings._init_message()) return reader @@ -347,35 +367,41 @@ async def _start( "Unexpected message after InitRequest: %s", init_response ) + self._update_token_event.set() + self._background_tasks.add( - asyncio.create_task(self._read_messages_loop(stream)) + asyncio.create_task(self._read_messages_loop(), name="read_messages_loop") ) self._background_tasks.add(asyncio.create_task(self._decode_batches_loop())) + if self._get_token_function: + self._background_tasks.add( + asyncio.create_task(self._update_token_loop(), name="update_token_loop") + ) async def wait_error(self): raise await self._first_error async def wait_messages(self): while True: - if self._get_first_error() is not None: + if self._get_first_error(): raise self._get_first_error() - if len(self._message_batches) > 0: + if self._message_batches: return await self._state_changed.wait() 
self._state_changed.clear() def receive_batch_nowait(self): - if self._get_first_error() is not None: + if self._get_first_error(): raise self._get_first_error() - try: - batch = self._message_batches.popleft() - self._buffer_release_bytes(batch._bytes_size) - return batch - except IndexError: - return None + if not self._message_batches: + return + + batch = self._message_batches.popleft() + self._buffer_release_bytes(batch._bytes_size) + return batch def commit( self, batch: datatypes.ICommittable @@ -413,7 +439,7 @@ def commit( return waiter - async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): + async def _read_messages_loop(self): try: self._stream.write( StreamReadMessage.FromClient( @@ -423,24 +449,34 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): ) ) while True: - message = await stream.receive() # type: StreamReadMessage.FromServer + message = ( + await self._stream.receive() + ) # type: StreamReadMessage.FromServer _process_response(message.server_status) + if isinstance(message.server_message, StreamReadMessage.ReadResponse): self._on_read_response(message.server_message) + elif isinstance( message.server_message, StreamReadMessage.CommitOffsetResponse ): self._on_commit_response(message.server_message) + elif isinstance( message.server_message, StreamReadMessage.StartPartitionSessionRequest, ): self._on_start_partition_session(message.server_message) + elif isinstance( message.server_message, StreamReadMessage.StopPartitionSessionRequest, ): self._on_partition_session_stop(message.server_message) + + elif isinstance(message.server_message, UpdateTokenResponse): + self._update_token_event.set() + else: raise NotImplementedError( "Unexpected type of StreamReadMessage.FromServer message: %s" @@ -450,7 +486,20 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): self._state_changed.set() except Exception as e: self._set_first_error(e) - raise e + raise + + async def _update_token_loop(self): + while True: + await asyncio.sleep(self._update_token_interval) + await self._update_token(token=self._get_token_function()) + + async def _update_token(self, token: str): + await self._update_token_event.wait() + try: + msg = StreamReadMessage.FromClient(UpdateTokenRequest(token)) + self._stream.write(msg) + finally: + self._update_token_event.clear() def _on_start_partition_session( self, message: StreamReadMessage.StartPartitionSessionRequest @@ -491,14 +540,12 @@ def _on_start_partition_session( def _on_partition_session_stop( self, message: StreamReadMessage.StopPartitionSessionRequest ): - try: - partition = self._partition_sessions[message.partition_session_id] - except KeyError: + if message.partition_session_id not in self._partition_sessions: # may if receive stop partition with graceful=false after response on stop partition # with graceful=true and remove partition from internal dictionary return - del self._partition_sessions[message.partition_session_id] + partition = self._partition_sessions.pop(message.partition_session_id) partition.close() if message.graceful: @@ -519,11 +566,10 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse): def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): for partition_offset in message.partitions_committed_offsets: - session = self._partition_sessions.get( - partition_offset.partition_session_id - ) - if session is None: + if partition_offset.partition_session_id not in self._partition_sessions: continue + + session = 
self._partition_sessions[partition_offset.partition_session_id] session.ack_notify(partition_offset.committed_offset) def _buffer_consume_bytes(self, bytes_size): @@ -544,12 +590,9 @@ def _read_response_to_batches( ) -> typing.List[datatypes.PublicBatch]: batches = [] - batch_count = 0 - for partition_data in message.partition_data: - batch_count += len(partition_data.batches) - + batch_count = sum(len(p.batches) for p in message.partition_data) if batch_count == 0: - return [] + return batches bytes_per_batch = message.bytes_size // batch_count additional_bytes_to_last_batch = ( @@ -577,12 +620,11 @@ def _read_response_to_batches( _commit_end_offset=message_data.offset + 1, ) messages.append(mess) - partition_session._next_message_start_commit_offset = ( mess._commit_end_offset ) - if len(messages) > 0: + if messages: batch = datatypes.PublicBatch( session_metadata=server_batch.write_session_meta, messages=messages, @@ -637,14 +679,12 @@ def _set_first_error(self, err: YdbError): def _get_first_error(self) -> Optional[YdbError]: if self._first_error.done(): return self._first_error.result() - else: - return None async def close(self): if self._closed: - raise TopicReaderError(message="Double closed ReaderStream") - + return self._closed = True + self._set_first_error(TopicReaderStreamClosedError()) self._state_changed.set() self._stream.close() @@ -654,5 +694,4 @@ async def close(self): for task in self._background_tasks: task.cancel() - await asyncio.wait(self._background_tasks) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 917fff21..214e1bd6 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -15,7 +15,13 @@ from .topic_reader import PublicReaderSettings from .topic_reader_asyncio import ReaderStream, ReaderReconnector from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus -from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec, OffsetsRange +from .._grpc.grpcwrapper.ydb_topic import ( + StreamReadMessage, + Codec, + OffsetsRange, + UpdateTokenRequest, + UpdateTokenResponse, +) from .._topic_common.test_helpers import ( StreamMock, wait_condition, @@ -121,12 +127,14 @@ def second_partition_session( @pytest.fixture() async def stream_reader_started( - self, - stream, - default_reader_settings, + self, stream, default_reader_settings, request ) -> ReaderStream: + + settings, token_getter = getattr( + request, "param", (default_reader_settings, None) + ) reader = ReaderStream( - self.default_reader_reconnector_id, default_reader_settings + self.default_reader_reconnector_id, settings, token_getter ) init_message = object() @@ -1004,6 +1012,42 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() + @pytest.mark.parametrize( + "stream_reader_started", + [ + ( + PublicReaderSettings( + consumer="test-consumer", + topic="test-topic", + update_token_interval=0.1, + ), + lambda: "foo-bar", + ) + ], + indirect=True, + ) + async def test_update_token(self, stream, stream_reader_started: ReaderStream): + assert stream.from_client.empty() + + expected = StreamReadMessage.FromClient(UpdateTokenRequest(token="foo-bar")) + got = await wait_for_fast(stream.from_client.get()) + assert expected == got, "send update token request" + + await asyncio.sleep(0.2) + assert stream.from_client.empty(), "no answer - no new update request" + + await 
stream.from_server.put( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=UpdateTokenResponse(), + ) + ) + + got = await wait_for_fast(stream.from_client.get()) + assert expected == got + + await stream_reader_started.close() + @pytest.mark.asyncio class TestReaderReconnector: diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 145fd0ce..73c959f9 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -259,7 +259,7 @@ def get_second(self): return self.get_second() def _create(self): - writer = TestWriterAsyncIOReconnector.StreamWriterMock(1, lambda: "") + writer = TestWriterAsyncIOReconnector.StreamWriterMock() writer.last_seqno = TestWriterAsyncIOReconnector.init_last_seqno self._first.put_nowait(writer) self._second.put_nowait(writer) From 618f25c29d08362c901fc45484b7d0abed2d7ca7 Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Thu, 16 Mar 2023 16:43:42 +0300 Subject: [PATCH 115/147] better tests fot topic_writer/topic_reader --- .../topic_reader_asyncio_test.py | 46 ++++++++----------- .../topic_writer_asyncio_test.py | 35 +++++++------- 2 files changed, 37 insertions(+), 44 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 214e1bd6..2924cb4d 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -125,17 +125,8 @@ def second_partition_session( return stream_reader_started._partition_sessions[partition_session.id] - @pytest.fixture() - async def stream_reader_started( - self, stream, default_reader_settings, request - ) -> ReaderStream: - - settings, token_getter = getattr( - request, "param", (default_reader_settings, None) - ) - reader = ReaderStream( - self.default_reader_reconnector_id, settings, token_getter - ) + async def get_started_reader(self, stream, *args, **kwargs) -> ReaderStream: + reader = ReaderStream(self.default_reader_reconnector_id, *args, **kwargs) init_message = object() # noinspection PyTypeChecker @@ -164,6 +155,12 @@ async def stream_reader_started( return reader + @pytest.fixture() + async def stream_reader_started( + self, stream, default_reader_settings + ) -> ReaderStream: + return await self.get_started_reader(stream, default_reader_settings) + @pytest.fixture() async def stream_reader(self, stream_reader_started: ReaderStream): yield stream_reader_started @@ -1012,21 +1009,16 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() - @pytest.mark.parametrize( - "stream_reader_started", - [ - ( - PublicReaderSettings( - consumer="test-consumer", - topic="test-topic", - update_token_interval=0.1, - ), - lambda: "foo-bar", - ) - ], - indirect=True, - ) - async def test_update_token(self, stream, stream_reader_started: ReaderStream): + async def test_update_token(self, stream): + settings = PublicReaderSettings( + consumer="test-consumer", + topic="test-topic", + update_token_interval=0.1, + ) + reader = await self.get_started_reader( + stream, settings, get_token_function=lambda: "foo-bar" + ) + assert stream.from_client.empty() expected = StreamReadMessage.FromClient(UpdateTokenRequest(token="foo-bar")) @@ -1046,7 +1038,7 @@ async def test_update_token(self, stream, stream_reader_started: ReaderStream): got = await 
wait_for_fast(stream.from_client.get()) assert expected == got - await stream_reader_started.close() + await reader.close() @pytest.mark.asyncio diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 73c959f9..b5b3fcc8 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -63,8 +63,8 @@ def stream(self): yield stream stream.close() - @pytest.fixture - async def writer_and_stream(self, stream, request) -> WriterWithMockedStream: + @staticmethod + async def get_started_writer(stream, *args, **kwargs) -> WriterAsyncIOStream: stream.from_server.put_nowait( StreamWriteMessage.InitResponse( last_seq_no=4, @@ -75,9 +75,7 @@ async def writer_and_stream(self, stream, request) -> WriterWithMockedStream: ) ) - params = getattr(request, "param", ()) - writer = WriterAsyncIOStream(*params) - + writer = WriterAsyncIOStream(*args, **kwargs) await writer._start( stream, init_message=StreamWriteMessage.InitRequest( @@ -91,6 +89,11 @@ async def writer_and_stream(self, stream, request) -> WriterWithMockedStream: ), ) await stream.from_client.get() + return writer + + @pytest.fixture + async def writer_and_stream(self, stream) -> WriterWithMockedStream: + writer = await self.get_started_writer(stream) yield TestWriterAsyncIOStream.WriterWithMockedStream( stream=stream, @@ -164,25 +167,23 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): sent_message = await writer_and_stream.stream.from_client.get() assert expected_message == sent_message - @pytest.mark.parametrize( - "writer_and_stream", [(0.1, lambda: "foo-bar")], indirect=True - ) - async def test_update_token(self, writer_and_stream: WriterWithMockedStream): - assert writer_and_stream.stream.from_client.empty() + async def test_update_token(self, stream: StreamMock): + writer = await self.get_started_writer( + stream, update_token_interval=0.1, get_token_function=lambda: "foo-bar" + ) + assert stream.from_client.empty() expected = StreamWriteMessage.FromClient(UpdateTokenRequest(token="foo-bar")) - got = await wait_for_fast(writer_and_stream.stream.from_client.get()) + got = await wait_for_fast(stream.from_client.get()) assert expected == got, "send update token request" await asyncio.sleep(0.2) - assert ( - writer_and_stream.stream.from_client.empty() - ), "no answer - no new update request" + assert stream.from_client.empty(), "no answer - no new update request" - await writer_and_stream.stream.from_server.put(UpdateTokenResponse()) - receive_task = asyncio.create_task(writer_and_stream.writer.receive()) + await stream.from_server.put(UpdateTokenResponse()) + receive_task = asyncio.create_task(writer.receive()) - got = await wait_for_fast(writer_and_stream.stream.from_client.get()) + got = await wait_for_fast(stream.from_client.get()) assert expected == got receive_task.cancel() From 350f92f69699eedb111cab184c93f035416dc0e2 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 17 Mar 2023 18:49:48 +0300 Subject: [PATCH 116/147] Update ydb_version.py --- ydb/ydb_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/ydb_version.py b/ydb/ydb_version.py index e783fc72..b0649f64 100644 --- a/ydb/ydb_version.py +++ b/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "2.13.3" +VERSION = "3.0.1b8" From 0d48427c5412ee68b049bc53bd3df66f7b9843a5 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Sat, 18 Mar 2023 15:43:29 +0300 Subject: [PATCH 117/147] init read message --- 
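[Review note] This patch layers single-message reads on top of the existing
batch queue: PublicBatch gains is_empty() and pop_message(), and the readers
gain receive_message()/receive_message_nowait(). A minimal usage sketch,
assuming a reachable local YDB instance; the endpoint, topic and consumer
names are placeholders rather than values from the patch:

import asyncio

import ydb


async def read_one_message():
    driver = ydb.aio.Driver(
        connection_string="grpc://localhost:2135?database=/local",
        credentials=ydb.credentials.AnonymousCredentials(),
    )
    await driver.wait(5)
    reader = ydb.TopicClientAsyncIO(driver).reader(
        consumer="consumer", topic="/local/topic"
    )
    # receive_message() blocks until a message is available; wrap it in
    # asyncio.wait_for when a deadline is needed, as the docstring suggests.
    msg = await asyncio.wait_for(reader.receive_message(), timeout=10)
    print(msg.data)
    await reader.close()


asyncio.run(read_one_message())

The message is popped from the left-most batch, and the batch is dropped from
the deque once it is empty, so per-message and per-batch reads can be mixed on
the same reader.
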
ydb/_topic_reader/datatypes.py | 11 + ydb/_topic_reader/topic_reader_asyncio.py | 35 +- .../topic_reader_asyncio_test.py | 461 +++++++++++------- ydb/_topic_reader/topic_reader_sync.py | 31 +- 4 files changed, 358 insertions(+), 180 deletions(-) diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 3845995f..860525ab 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -179,6 +179,9 @@ def _commit_get_offsets_range(self) -> OffsetsRange: self.messages[-1]._commit_get_offsets_range().end, ) + def is_empty(self) -> bool: + return len(self.messages) == 0 + # ISessionAlive implementation @property def is_alive(self) -> bool: @@ -187,3 +190,11 @@ def is_alive(self) -> bool: state == PartitionSession.State.Active or state == PartitionSession.State.GracefulShutdown ) + + def pop_message(self) -> PublicMessage: + if len(self.messages) == 0: + raise IndexError() + + res = self.messages[0] + self.messages = self.messages[1:] + return res diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index bb87d3cc..c74f7d09 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -95,14 +95,6 @@ def messages( """ raise NotImplementedError() - async def receive_message(self) -> typing.Union[topic_reader.PublicMessage, None]: - """ - Block until receive new message - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - def batches( self, *, @@ -133,6 +125,15 @@ async def receive_batch( await self._reconnector.wait_message() return self._reconnector.receive_batch_nowait() + async def receive_message(self) -> typing.Union[datatypes.PublicMessage, None]: + """ + Block until receive new message + + use asyncio.wait_for for wait with timeout. 
+ """ + await self._reconnector.wait_message() + return self._reconnector.receive_message_nowait() + async def commit_on_exit( self, mess: datatypes.ICommittable ) -> typing.AsyncContextManager: @@ -244,6 +245,9 @@ async def wait_message(self): def receive_batch_nowait(self): return self._stream_reader.receive_batch_nowait() + def receive_message_nowait(self): + return self._stream_reader.receive_message_nowait() + def commit( self, batch: datatypes.ICommittable ) -> datatypes.PartitionSession.CommitAckWaiter: @@ -397,12 +401,25 @@ def receive_batch_nowait(self): raise self._get_first_error() if not self._message_batches: - return + return None batch = self._message_batches.popleft() self._buffer_release_bytes(batch._bytes_size) return batch + def receive_message_nowait(self): + try: + batch = self._message_batches[0] + message = batch.pop_message() + except IndexError: + return None + + if batch.is_empty(): + self._message_batches.popleft() + + return message + + def commit( self, batch: datatypes.ICommittable ) -> datatypes.PartitionSession.CommitAckWaiter: diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 2924cb4d..a153e25c 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -4,6 +4,7 @@ import datetime import gzip import typing +from collections import deque from dataclasses import dataclass from unittest import mock @@ -53,6 +54,34 @@ def default_executor(): executor.shutdown() +def stub_partition_session(): + return datatypes.PartitionSession( + id=0, + state=datatypes.PartitionSession.State.Active, + topic_path="asd", + partition_id=1, + committed_offset=0, + reader_reconnector_id=415, + reader_stream_id=513, + ) + + +def stub_message(id: int): + return PublicMessage( + seqno=id, + created_at=datetime.datetime(2023, 3, 18, 14, 15), + message_group_id="", + session_metadata={}, + offset=0, + written_at=datetime.datetime(2023, 3, 18, 14, 15), + producer_id="", + data=bytes(), + _partition_session=stub_partition_session(), + _commit_start_offset=0, + _commit_end_offset=1, + ) + + @pytest.fixture() def default_reader_settings(default_executor): return PublicReaderSettings( @@ -85,7 +114,7 @@ def stream(self): @pytest.fixture() def partition_session( - self, default_reader_settings, stream_reader_started: ReaderStream + self, default_reader_settings, stream_reader_started: ReaderStream ) -> datatypes.PartitionSession: partition_session = datatypes.PartitionSession( id=2, @@ -106,7 +135,7 @@ def partition_session( @pytest.fixture() def second_partition_session( - self, default_reader_settings, stream_reader_started: ReaderStream + self, default_reader_settings, stream_reader_started: ReaderStream ): partition_session = datatypes.PartitionSession( id=12, @@ -157,7 +186,7 @@ async def get_started_reader(self, stream, *args, **kwargs) -> ReaderStream: @pytest.fixture() async def stream_reader_started( - self, stream, default_reader_settings + self, stream, default_reader_settings ) -> ReaderStream: return await self.get_started_reader(stream, default_reader_settings) @@ -170,7 +199,7 @@ async def stream_reader(self, stream_reader_started: ReaderStream): @pytest.fixture() async def stream_reader_finish_with_error( - self, stream_reader_started: ReaderStream + self, stream_reader_started: ReaderStream ): yield stream_reader_started @@ -179,7 +208,7 @@ async def stream_reader_finish_with_error( @staticmethod def create_message( - partition_session: 
datatypes.PartitionSession, seqno: int, offset_delta: int + partition_session: typing.Optional[datatypes.PartitionSession], seqno: int, offset_delta: int ): return PublicMessage( seqno=seqno, @@ -187,17 +216,17 @@ def create_message( message_group_id="test-message-group", session_metadata={}, offset=partition_session._next_message_start_commit_offset - + offset_delta - - 1, + + offset_delta + - 1, written_at=datetime.datetime(2023, 2, 3, 14, 16), producer_id="test-producer-id", data=bytes(), _partition_session=partition_session, _commit_start_offset=partition_session._next_message_start_commit_offset - + offset_delta - - 1, + + offset_delta + - 1, _commit_end_offset=partition_session._next_message_start_commit_offset - + offset_delta, + + offset_delta, ) async def send_message(self, stream_reader, message: PublicMessage): @@ -257,28 +286,28 @@ class TestError(Exception): "commit,send_range", [ ( - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - True, + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + True, ), ( - OffsetsRange( - partition_session_committed_offset - 1, - partition_session_committed_offset, - ), - False, + OffsetsRange( + partition_session_committed_offset - 1, + partition_session_committed_offset, + ), + False, ), ], ) async def test_send_commit_messages( - self, - stream, - stream_reader: ReaderStream, - partition_session, - commit: OffsetsRange, - send_range: bool, + self, + stream, + stream_reader: ReaderStream, + partition_session, + commit: OffsetsRange, + send_range: bool, ): @dataclass class Commitable(datatypes.ICommittable): @@ -318,7 +347,7 @@ async def wait_message(): assert start_ack_waiters == partition_session._ack_waiters async def test_commit_ack_received( - self, stream_reader, stream, partition_session, second_partition_session + self, stream_reader, stream, partition_session, second_partition_session ): offset1 = self.partition_session_committed_offset + 1 waiter1 = partition_session.add_waiter(offset1) @@ -348,7 +377,7 @@ async def test_commit_ack_received( await wait_for_fast(waiter2.future) async def test_close_ack_waiters_when_close_stream_reader( - self, stream_reader_started: ReaderStream, partition_session + self, stream_reader_started: ReaderStream, partition_session ): waiter = partition_session.add_waiter( self.partition_session_committed_offset + 1 @@ -359,7 +388,7 @@ async def test_close_ack_waiters_when_close_stream_reader( waiter.future.result() async def test_commit_ranges_for_received_messages( - self, stream, stream_reader_started: ReaderStream, partition_session + self, stream, stream_reader_started: ReaderStream, partition_session ): m1 = self.create_message(partition_session, 1, 1) m2 = self.create_message(partition_session, 2, 10) @@ -381,131 +410,131 @@ async def test_commit_ranges_for_received_messages( "batch,data_out", [ ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=rb"123", - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ) - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - [bytes(rb"123")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", 
+ session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123")], ), ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=gzip.compress(rb"123"), - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ) - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_GZIP, - ), - [bytes(rb"123")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123")], ), ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=rb"123", - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ), - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=rb"456", - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ), - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - [bytes(rb"123"), bytes(rb"456")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"456", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123"), bytes(rb"456")], ), ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=gzip.compress(rb"123"), - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ), - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=gzip.compress(rb"456"), - _partition_session=None, - _commit_start_offset=5, - 
_commit_end_offset=15, - ), - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_GZIP, - ), - [bytes(rb"123"), bytes(rb"456")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"456"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123"), bytes(rb"456")], ), ], ) async def test_decode_loop( - self, stream_reader, batch: PublicBatch, data_out: typing.List[bytes] + self, stream_reader, batch: PublicBatch, data_out: typing.List[bytes] ): assert len(batch.messages) == len(data_out) @@ -520,7 +549,7 @@ async def test_decode_loop( assert batch == expected async def test_error_from_status_code( - self, stream, stream_reader_finish_with_error + self, stream, stream_reader_finish_with_error ): # noinspection PyTypeChecker stream.from_server.put_nowait( @@ -580,11 +609,11 @@ async def test_init_reader(self, stream, default_reader_settings): await reader.close() async def test_start_partition( - self, - stream_reader: ReaderStream, - stream, - default_reader_settings, - partition_session, + self, + stream_reader: ReaderStream, + stream, + default_reader_settings, + partition_session, ): def session_count(): return len(stream_reader._partition_sessions) @@ -624,8 +653,8 @@ def session_count(): assert len(stream_reader._partition_sessions) == initial_session_count + 1 assert stream_reader._partition_sessions[ - test_partition_session_id - ] == datatypes.PartitionSession( + test_partition_session_id + ] == datatypes.PartitionSession( id=test_partition_session_id, state=datatypes.PartitionSession.State.Active, topic_path=test_topic_path, @@ -660,7 +689,7 @@ def session_count(): assert partition_session.id not in stream_reader._partition_sessions async def test_partition_stop_graceful( - self, stream, stream_reader, partition_session + self, stream, stream_reader, partition_session ): def session_count(): return len(stream_reader._partition_sessions) @@ -703,11 +732,11 @@ def session_count(): stream.from_client.get_nowait() async def test_receive_message_from_server( - self, - stream_reader, - stream, - partition_session: datatypes.PartitionSession, - second_partition_session, + self, + stream_reader, + stream, + partition_session: datatypes.PartitionSession, + second_partition_session, ): def reader_batch_count(): return len(stream_reader._message_batches) @@ -785,7 +814,7 @@ def reader_batch_count(): ) async def test_read_batches( - self, stream_reader, partition_session, second_partition_session + self, stream_reader, partition_session, second_partition_session ): created_at = datetime.datetime(2020, 2, 1, 18, 12) created_at2 = datetime.datetime(2020, 2, 2, 18, 12) @@ -963,6 +992,102 @@ async def test_read_batches( _codec=Codec.CODEC_RAW, ) + @pytest.mark.parametrize( + 'batches_before,expected_message,batches_after', + [ + ( + [], + None, + [] + ), + ( + [PublicBatch( + session_metadata={}, + 
messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + )], + stub_message(1), + [] + ), + ( + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(1), stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ) + ], + stub_message(1), + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ) + ], + ), + ( + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + stub_message(1), + [PublicBatch( + session_metadata={}, + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + )], + ), + + ] + ) + async def test_read_message( + self, + stream_reader, + batches_before: typing.List[datatypes.PublicBatch], + expected_message: PublicMessage, + batches_after: typing.List[datatypes.PublicBatch], + ): + stream_reader._message_batches = deque(batches_before) + mess = stream_reader.receive_message_nowait() + + assert mess == expected_message + assert list(stream_reader._message_batches) == batches_after + async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): assert stream_reader.receive_batch_nowait() is None @@ -993,17 +1118,17 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi ) assert ( - stream_reader._buffer_size_bytes - == initial_buffer_size + 2 * self.default_batch_size + stream_reader._buffer_size_bytes + == initial_buffer_size + 2 * self.default_batch_size ) assert ( - StreamReadMessage.ReadRequest(self.default_batch_size) - == stream.from_client.get_nowait().client_message + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message ) assert ( - StreamReadMessage.ReadRequest(self.default_batch_size) - == stream.from_client.get_nowait().client_message + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message ) with pytest.raises(asyncio.QueueEmpty): @@ -1068,9 +1193,9 @@ async def wait_messages(): stream_index = 0 async def stream_create( - reader_reconnector_id: int, - driver: SupportedDriverType, - settings: PublicReaderSettings, + reader_reconnector_id: int, + driver: SupportedDriverType, + settings: PublicReaderSettings, ): nonlocal stream_index stream_index += 1 diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index ec243337..17706c79 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -5,6 +5,7 @@ from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType from ydb._topic_common.common import _get_shared_event_loop +from ydb._topic_reader import datatypes from 
ydb._topic_reader.datatypes import PublicMessage, PublicBatch, ICommittable from ydb._topic_reader.topic_reader import ( PublicReaderSettings, @@ -72,6 +73,19 @@ def _call_sync(self, coro: Coroutine, timeout): f.cancel() raise + def _call_nowait(self, callback: typing.Callable[[], typing.Any]) -> typing.Any: + res = concurrent.futures.Future() + + def call(): + try: + res.set_result(call()) + except BaseException as err: + res.set_exception(err) + + self._loop.call_soon_threadsafe(call) + + return res.result() + def async_sessions_stat(self) -> concurrent.futures.Future: """ Receive stat from the server, return feature. @@ -100,15 +114,23 @@ def messages( """ raise NotImplementedError() - def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage: + def receive_message(self, *, timeout: Union[float, None] = None) -> datatypes.PublicMessage: """ Block until receive new message It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + receive_message(timeout=0) may return None even right after async_wait_message() is ok - because lost of partition + or connection to server lost if no new message in timeout seconds (default - infinite): raise TimeoutError() if timeout <= 0 - it will fast non block method, get messages from internal buffer only. """ - raise NotImplementedError() + if timeout <= 0: + return self._receive_message_nowait() + + return self._call_sync(self._async_reader.receive_message(), timeout) + + def _receive_message_nowait(self) -> Optional[datatypes.PublicMessage]: + return self._call_nowait(lambda: self._async_reader._reconnector.receive_message_nowait()) def async_wait_message(self) -> concurrent.futures.Future: """ @@ -118,7 +140,7 @@ def async_wait_message(self) -> concurrent.futures.Future: Possible situation when receive signal about message available, but no messages when try to receive a message. If message expired between send event and try to retrieve message (for example connection broken). """ - raise NotImplementedError() + return self._call(self._async_reader._reconnector.wait_message()) def batches( self, @@ -157,6 +179,9 @@ def receive_batch( timeout, ) + def _receive_batch_nowait(self) -> Optional[PublicBatch]: + return self._call_nowait(lambda: self._async_reader._reconnector.receive_batch_nowait()) + def commit(self, mess: ICommittable): """ Put commit message to internal buffer. 
From d0d409fe0fa638a2a2d2ae830779ea437f94a959 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 10:41:32 +0300 Subject: [PATCH 118/147] implement calls from sync to async --- ydb/_topic_common/common.py | 86 +++++++++++++++ ydb/_topic_common/common_test.py | 180 +++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index c569daca..b651ee40 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -60,3 +60,89 @@ def start_event_loop(): _shared_event_loop = event_loop_set_done.result() return _shared_event_loop + + +class CallFromSyncToAsync: + _loop: asyncio.AbstractEventLoop + + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + + def unsafe_call_with_future(self, coro: typing.Coroutine) -> concurrent.futures.Future: + """ + returned result from coro may be lost + """ + return asyncio.run_coroutine_threadsafe(coro, self._loop) + + def unsafe_call_with_result(self, coro: typing.Coroutine, timeout: typing.Union[int, float, None]): + """ + returned result from coro may be lost by race future cancel by timeout and return value from coroutine + """ + f = self.unsafe_call_with_future(coro) + try: + return f.result(timeout) + except concurrent.futures.TimeoutError: + raise TimeoutError() + finally: + f.cancel() + + def safe_call_with_result(self, coro: typing.Coroutine, timeout: typing.Union[int, float]): + """ + no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. + """ + + if timeout <= 0: + return self._safe_call_fast(coro) + + async def call_coro(): + task = self._loop.create_task(coro) + try: + res = await asyncio.wait_for(task, timeout) + return res + except BaseException as err: + try: + res = await task + return res + except asyncio.CancelledError: + pass + + # return builtin TimeoutError instead of asyncio.TimeoutError + raise TimeoutError() + + + return asyncio.run_coroutine_threadsafe(call_coro(), self._loop).result() + + def _safe_call_fast(self, coro: typing.Coroutine): + """ + no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. + Wait coroutine result only one loop. 
+ """ + res = concurrent.futures.Future() + + async def call_coro(): + try: + res.set_result(await coro) + except asyncio.CancelledError: + res.set_exception(TimeoutError()) + + async def sleep0(): + await asyncio.sleep(0) + + coro_future = asyncio.run_coroutine_threadsafe(call_coro(), self._loop) + asyncio.run_coroutine_threadsafe(sleep0(), self._loop).result() + coro_future.cancel() + return res.result() + + def call_sync(self, callback: typing.Callable[[], typing.Any]) -> typing.Any: + result = concurrent.futures.Future() + + def call_callback(): + try: + res = callback() + result.set_result(res) + except BaseException as err: + result.set_exception(err) + + self._loop.call_soon_threadsafe(call_callback) + + return result.result() diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index 445abdcf..b292862f 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -1,9 +1,12 @@ import asyncio +import threading +import time import typing import grpc import pytest +from .common import CallFromSyncToAsync from .._grpc.grpcwrapper.common_utils import ( GrpcWrapperAsyncIO, ServerStatus, @@ -25,6 +28,23 @@ ) +@pytest.fixture() +def separate_loop(): + loop = asyncio.new_event_loop() + + def run_loop(): + loop.run_forever() + pass + + t = threading.Thread(target=run_loop, name="test separate loop") + t.start() + + yield loop + + loop.call_soon_threadsafe(lambda: loop.stop()) + t.join() + + @pytest.mark.asyncio class Test: async def test_callback_from_asyncio(self): @@ -111,3 +131,163 @@ def test_failed(self): assert not status.is_success() with pytest.raises(issues.Overloaded): issues._process_response(status) + + +@pytest.mark.asyncio +class TestCallFromSyncToAsync: + @pytest.fixture() + def caller(self, separate_loop): + return CallFromSyncToAsync(separate_loop) + + def test_unsafe_call_with_future(self, separate_loop, caller): + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return 1 + + f = caller.unsafe_call_with_future(callback()) + + assert f.result() == 1 + assert callback_loop is separate_loop + + def test_unsafe_call_with_result_ok(self, separate_loop, caller): + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return 1 + + res = caller.unsafe_call_with_result(callback(), None) + + assert res == 1 + assert callback_loop is separate_loop + + def test_unsafe_call_with_result_timeout(self, separate_loop, caller): + timeout = 0.01 + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + await asyncio.sleep(1) + return 1 + + start = time.monotonic() + with pytest.raises(TimeoutError): + caller.unsafe_call_with_result(callback(), timeout) + finished = time.monotonic() + + assert callback_loop is separate_loop + assert finished - start > timeout + + def test_safe_call_with_result_ok(self, separate_loop, caller): + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return 1 + + res = caller.safe_call_with_result(callback(), 1) + + assert res == 1 + assert callback_loop is separate_loop + + def test_safe_call_with_result_timeout(self, separate_loop, caller): + timeout = 0.01 + callback_loop = None + cancelled = False + + async def callback(): + nonlocal callback_loop, cancelled + callback_loop = asyncio.get_running_loop() + try: + await asyncio.sleep(1) + except 
asyncio.CancelledError: + cancelled = True + raise + + return 1 + + start = time.monotonic() + with pytest.raises(TimeoutError): + caller.safe_call_with_result(callback(), timeout) + finished = time.monotonic() + + async def sleep0(): + await asyncio.sleep(0) + + # wait one loop for handle task cancelation + asyncio.run_coroutine_threadsafe(sleep0(), separate_loop) + + assert callback_loop is separate_loop + assert finished - start > timeout + assert cancelled + + def test_safe_callback_with_0_timeout_ok(self, separate_loop, caller): + callback_loop = None + + async def f1(): + return 1 + + async def f2(): + return await f1() + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return await f2() + + res = caller.safe_call_with_result(callback(), 0) + assert callback_loop is separate_loop + assert res == 1 + + def test_safe_callback_with_0_timeout_timeout(self, separate_loop, caller): + callback_loop = None + cancelled = False + + async def callback(): + try: + nonlocal callback_loop, cancelled + + callback_loop = asyncio.get_running_loop() + await asyncio.sleep(1) + except asyncio.CancelledError: + cancelled = True + raise + + with pytest.raises(TimeoutError): + caller.safe_call_with_result(callback(), 0) + + assert callback_loop is separate_loop + assert cancelled + + def test_call_sync_ok(self, separate_loop, caller): + callback_eventloop = None + + def callback(): + nonlocal callback_eventloop + callback_eventloop = asyncio.get_running_loop() + return 1 + + res = caller.call_sync(callback) + assert callback_eventloop is separate_loop + assert res == 1 + + def test_call_sync_error(self, separate_loop, caller): + callback_eventloop = None + + class TestError(RuntimeError): + pass + + def callback(): + nonlocal callback_eventloop + callback_eventloop = asyncio.get_running_loop() + raise TestError + + with pytest.raises(TestError): + caller.call_sync(callback) + assert callback_eventloop is separate_loop From 9d3e99dfabbe7f455e3d6fc9b22c25aec2445f9a Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 11:04:15 +0300 Subject: [PATCH 119/147] rewrite reader to sync caller --- ydb/_topic_reader/topic_reader_sync.py | 106 +++++++++++-------------- 1 file changed, 48 insertions(+), 58 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index ec243337..839f8fc7 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -4,7 +4,8 @@ from typing import List, Union, Iterable, Optional, Coroutine from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType -from ydb._topic_common.common import _get_shared_event_loop +from ydb._topic_common.common import _get_shared_event_loop, CallFromSyncToAsync +from ydb._topic_reader import datatypes from ydb._topic_reader.datatypes import PublicMessage, PublicBatch, ICommittable from ydb._topic_reader.topic_reader import ( PublicReaderSettings, @@ -18,29 +19,31 @@ class TopicReaderSync: - _loop: asyncio.AbstractEventLoop + _caller: CallFromSyncToAsync _async_reader: PublicAsyncIOReader _closed: bool def __init__( - self, - driver: SupportedDriverType, - settings: PublicReaderSettings, - *, - eventloop: Optional[asyncio.AbstractEventLoop] = None, + self, + driver: SupportedDriverType, + settings: PublicReaderSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, ): self._closed = False if eventloop: - self._loop = eventloop + loop = eventloop else: - self._loop = _get_shared_event_loop() 
+ loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) async def create_reader(): return PublicAsyncIOReader(driver, settings) self._async_reader = asyncio.run_coroutine_threadsafe( - create_reader(), self._loop + create_reader(), loop ).result() def __del__(self): @@ -52,26 +55,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def _call(self, coro) -> concurrent.futures.Future: - """ - Call async function and return future fow wait result - """ - if self._closed: - raise TopicReaderClosedError() - - return asyncio.run_coroutine_threadsafe(coro, self._loop) - - def _call_sync(self, coro: Coroutine, timeout): - """ - Call async function, wait and return result - """ - f = self._call(coro) - try: - return f.result(timeout) - except TimeoutError: - f.cancel() - raise - def async_sessions_stat(self) -> concurrent.futures.Future: """ Receive stat from the server, return feature. @@ -87,7 +70,7 @@ async def sessions_stat(self) -> List[SessionStat]: raise NotImplementedError() def messages( - self, *, timeout: Union[float, None] = None + self, *, timeout: Union[float, None] = None ) -> Iterable[PublicMessage]: """ todo? @@ -121,11 +104,11 @@ def async_wait_message(self) -> concurrent.futures.Future: raise NotImplementedError() def batches( - self, - *, - max_messages: Union[int, None] = None, - max_bytes: Union[int, None] = None, - timeout: Union[float, None] = None, + self, + *, + max_messages: Union[int, None] = None, + max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None, ) -> Iterable[PublicBatch]: """ Block until receive new batch. @@ -137,11 +120,11 @@ def batches( raise NotImplementedError() def receive_batch( - self, - *, - max_messages: typing.Union[int, None] = None, - max_bytes: typing.Union[int, None] = None, - timeout: Union[float, None] = None, + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: Union[float, None] = None, ) -> Union[PublicBatch, None]: """ Get one messages batch from reader @@ -150,37 +133,42 @@ def receive_batch( if no new message in timeout seconds (default - infinite): raise TimeoutError() if timeout <= 0 - it will fast non block method, get messages from internal buffer only. """ - return self._call_sync( - self._async_reader.receive_batch( - max_messages=max_messages, max_bytes=max_bytes - ), - timeout, - ) + self._check_closed() - def commit(self, mess: ICommittable): + return self._caller.safe_call_with_result( + self._async_reader.receive_batch(max_messages=max_messages, max_bytes=max_bytes), + timeout) + + def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]): """ Put commit message to internal buffer. For the method no way check the commit result (for example if lost connection - commits will not re-send and committed messages will receive again) """ - self._call_sync(self._async_reader.commit(mess), None) + self._check_closed() + + self._caller.call_sync(self._async_reader.commit(mess)) def commit_with_ack( - self, mess: ICommittable + self, mess: ICommittable, timeout: typing.Union[int, float, None] = None ) -> Union[CommitResult, List[CommitResult]]: """ write commit message to a buffer and wait ack from the server. 
if receive in timeout seconds (default - infinite): raise TimeoutError() """ - return self._call_sync(self._async_reader.commit_with_ack(mess), None) + self._check_closed() + + return self._caller.unsafe_call_with_result(self._async_reader.commit_with_ack(mess), timeout) - def async_commit_with_ack(self, mess: ICommittable) -> concurrent.futures.Future: + def async_commit_with_ack(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]) -> concurrent.futures.Future: """ write commit message to a buffer and return Future for wait result. """ - return self._call(self._async_reader.commit_with_ack(mess), None) + self._check_closed() + + return self._caller.unsafe_call_with_future(self._async_reader.commit_with_ack(mess)) def async_flush(self) -> concurrent.futures.Future: """ @@ -194,12 +182,14 @@ def flush(self): """ raise NotImplementedError() - def close(self): + def close(self, *, timeout: typing.Union[int, float, None] = None): if self._closed: return + self._closed = True - # for no call self._call_sync on closed object - asyncio.run_coroutine_threadsafe( - self._async_reader.close(), self._loop - ).result() + self._caller.safe_call_with_result(self._async_reader.close(), timeout) + + def _check_closed(self): + if self._closed: + raise TopicReaderClosedError() From 99e3ab910543c54b442c8be42de50ee39950ce2b Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 11:06:13 +0300 Subject: [PATCH 120/147] fix typo --- ydb/_topic_common/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index b651ee40..9f297b7a 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -91,7 +91,7 @@ def safe_call_with_result(self, coro: typing.Coroutine, timeout: typing.Union[in no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. """ - if timeout <= 0: + if timeout is not None and timeout <= 0: return self._safe_call_fast(coro) async def call_coro(): From 387c17c3986e029708a3b1d16840ea01d48502f1 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 11:25:24 +0300 Subject: [PATCH 121/147] impl writer --- ydb/_topic_common/common.py | 4 +- ydb/_topic_writer/topic_writer.py | 5 ++ ydb/_topic_writer/topic_writer_sync.py | 76 +++++++++++++------------- 3 files changed, 45 insertions(+), 40 deletions(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index 9f297b7a..5d2fb446 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -7,7 +7,7 @@ from .. import operation, issues from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType -TimeoutType = typing.Union[int, float] +TimeoutType = typing.Union[int, float, None] def wrap_operation(rpc_state, response_pb, driver=None): @@ -86,7 +86,7 @@ def unsafe_call_with_result(self, coro: typing.Coroutine, timeout: typing.Union[ finally: f.cancel() - def safe_call_with_result(self, coro: typing.Coroutine, timeout: typing.Union[int, float]): + def safe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType): """ no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. 
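
         A minimal usage sketch (illustrative only; `loop` and `some_coroutine`
         are placeholders, and `loop` is assumed to be already running in a
         dedicated background thread):

             caller = CallFromSyncToAsync(loop)
             # blocks the calling thread; on expiry the coroutine is cancelled
             # and TimeoutError is raised
             result = caller.safe_call_with_result(some_coroutine(), timeout=5)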
""" diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index dab0371f..97c05420 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -176,6 +176,11 @@ def __init__(self, message: str): super(TopicWriterError, self).__init__(message) +class TopicWriterClosedError(ydb.Error): + def __init__(self): + super(TopicWriterClosedError, self).__init__("Topic writer already closed") + + class TopicWriterRepeatableError(TopicWriterError): pass diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 4713c07d..908fd56b 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import typing from concurrent.futures import Future from typing import Union, List, Optional, Coroutine @@ -10,15 +11,15 @@ TopicWriterError, PublicWriterInitInfo, PublicWriteResult, - Message, + Message, TopicWriterClosedError, ) from .topic_writer_asyncio import WriterAsyncIO -from .._topic_common.common import _get_shared_event_loop, TimeoutType +from .._topic_common.common import _get_shared_event_loop, TimeoutType, CallFromSyncToAsync class WriterSync: - _loop: asyncio.AbstractEventLoop + _caller: CallFromSyncToAsync _async_writer: WriterAsyncIO _closed: bool @@ -33,16 +34,16 @@ def __init__( self._closed = False if eventloop: - self._loop = eventloop + loop = eventloop else: - self._loop = _get_shared_event_loop() + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) async def create_async_writer(): return WriterAsyncIO(driver, settings) - self._async_writer = asyncio.run_coroutine_threadsafe( - create_async_writer(), self._loop - ).result() + self._async_writer = self._caller.safe_call_with_result(create_async_writer(), None) def __enter__(self): return self @@ -50,63 +51,62 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def _call(self, coro): + def close(self, *, flush: bool = True, timeout: typing.Union[int, float, None] = None): if self._closed: - raise TopicWriterError("writer is closed") + return - return asyncio.run_coroutine_threadsafe(coro, self._loop) + self._closed = True - def _call_sync(self, coro: Coroutine, timeout): - f = self._call(coro) - try: - return f.result(timeout) - except TimeoutError: - f.cancel() - raise + self._caller.safe_call_with_result(self._async_writer.close(flush=flush), timeout) - def close(self, flush: bool = True): + def _check_closed(self): if self._closed: - return + raise TopicWriterClosedError() - self._closed = True + def async_flush(self) -> Future: + self._check_closed() - # for no call self._call_sync on closed object - asyncio.run_coroutine_threadsafe( - self._async_writer.close(flush=flush), self._loop - ).result() + return self._caller.unsafe_call_with_future(self._async_writer.flush()) - def async_flush(self) -> Future: - if self._closed: - raise TopicWriterError("writer is closed") - return self._call(self._async_writer.flush()) + def flush(self, *, timeout=None): + self._check_closed() - def flush(self, timeout=None): - self._call_sync(self._async_writer.flush(), timeout) + return self._caller.unsafe_call_with_result(self._async_writer.flush(), timeout) def async_wait_init(self) -> Future[PublicWriterInitInfo]: - return self._call(self._async_writer.wait_init()) + self._check_closed() + + return self._caller.unsafe_call_with_future(self._async_writer.wait_init()) - def wait_init(self, timeout: 
Optional[TimeoutType] = None) -> PublicWriterInitInfo: - return self._call_sync(self._async_writer.wait_init(), timeout) + def wait_init(self, *, timeout: TimeoutType = None) -> PublicWriterInitInfo: + self._check_closed() + + return self._caller.unsafe_call_with_result(self._async_writer.wait_init(), timeout) def write( self, messages: Union[Message, List[Message]], - timeout: Union[float, None] = None, + timeout: TimeoutType = None, ): - self._call_sync(self._async_writer.write(messages), timeout=timeout) + self._check_closed() + + self._caller.safe_call_with_result(self._async_writer.write(messages), timeout) def async_write_with_ack( self, messages: Union[Message, List[Message]], ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: - return self._call(self._async_writer.write_with_ack(messages)) + self._check_closed() + + return self._caller.unsafe_call_with_future(self._async_writer.write_with_ack(messages)) def write_with_ack( self, messages: Union[Message, List[Message]], timeout: Union[float, None] = None, ) -> Union[PublicWriteResult, List[PublicWriteResult]]: - return self._call_sync( + self._check_closed() + + return self._caller.unsafe_call_with_result( self._async_writer.write_with_ack(messages), timeout=timeout ) From 6e749edef1e8f32fb1a52978a99e7b70bc3987a5 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 11:27:14 +0300 Subject: [PATCH 122/147] linter --- ydb/_topic_common/common.py | 11 +++-- ydb/_topic_reader/topic_reader_sync.py | 61 +++++++++++++++----------- ydb/_topic_writer/topic_writer_sync.py | 32 ++++++++++---- 3 files changed, 66 insertions(+), 38 deletions(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index 5d2fb446..6e381c2f 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -68,13 +68,17 @@ class CallFromSyncToAsync: def __init__(self, loop: asyncio.AbstractEventLoop): self._loop = loop - def unsafe_call_with_future(self, coro: typing.Coroutine) -> concurrent.futures.Future: + def unsafe_call_with_future( + self, coro: typing.Coroutine + ) -> concurrent.futures.Future: """ returned result from coro may be lost """ return asyncio.run_coroutine_threadsafe(coro, self._loop) - def unsafe_call_with_result(self, coro: typing.Coroutine, timeout: typing.Union[int, float, None]): + def unsafe_call_with_result( + self, coro: typing.Coroutine, timeout: typing.Union[int, float, None] + ): """ returned result from coro may be lost by race future cancel by timeout and return value from coroutine """ @@ -99,7 +103,7 @@ async def call_coro(): try: res = await asyncio.wait_for(task, timeout) return res - except BaseException as err: + except asyncio.TimeoutError: try: res = await task return res @@ -109,7 +113,6 @@ async def call_coro(): # return builtin TimeoutError instead of asyncio.TimeoutError raise TimeoutError() - return asyncio.run_coroutine_threadsafe(call_coro(), self._loop).result() def _safe_call_fast(self, coro: typing.Coroutine): diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index 839f8fc7..b82fa58b 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -1,7 +1,7 @@ import asyncio import concurrent.futures import typing -from typing import List, Union, Iterable, Optional, Coroutine +from typing import List, Union, Iterable, Optional from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType from ydb._topic_common.common import _get_shared_event_loop, CallFromSyncToAsync @@ -24,11 
+24,11 @@ class TopicReaderSync: _closed: bool def __init__( - self, - driver: SupportedDriverType, - settings: PublicReaderSettings, - *, - eventloop: Optional[asyncio.AbstractEventLoop] = None, + self, + driver: SupportedDriverType, + settings: PublicReaderSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, ): self._closed = False @@ -70,7 +70,7 @@ async def sessions_stat(self) -> List[SessionStat]: raise NotImplementedError() def messages( - self, *, timeout: Union[float, None] = None + self, *, timeout: Union[float, None] = None ) -> Iterable[PublicMessage]: """ todo? @@ -104,11 +104,11 @@ def async_wait_message(self) -> concurrent.futures.Future: raise NotImplementedError() def batches( - self, - *, - max_messages: Union[int, None] = None, - max_bytes: Union[int, None] = None, - timeout: Union[float, None] = None, + self, + *, + max_messages: Union[int, None] = None, + max_bytes: Union[int, None] = None, + timeout: Union[float, None] = None, ) -> Iterable[PublicBatch]: """ Block until receive new batch. @@ -120,11 +120,11 @@ def batches( raise NotImplementedError() def receive_batch( - self, - *, - max_messages: typing.Union[int, None] = None, - max_bytes: typing.Union[int, None] = None, - timeout: Union[float, None] = None, + self, + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: Union[float, None] = None, ) -> Union[PublicBatch, None]: """ Get one messages batch from reader @@ -136,10 +136,15 @@ def receive_batch( self._check_closed() return self._caller.safe_call_with_result( - self._async_reader.receive_batch(max_messages=max_messages, max_bytes=max_bytes), - timeout) - - def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]): + self._async_reader.receive_batch( + max_messages=max_messages, max_bytes=max_bytes + ), + timeout, + ) + + def commit( + self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): """ Put commit message to internal buffer. @@ -151,7 +156,7 @@ def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBat self._caller.call_sync(self._async_reader.commit(mess)) def commit_with_ack( - self, mess: ICommittable, timeout: typing.Union[int, float, None] = None + self, mess: ICommittable, timeout: typing.Union[int, float, None] = None ) -> Union[CommitResult, List[CommitResult]]: """ write commit message to a buffer and wait ack from the server. @@ -160,15 +165,21 @@ def commit_with_ack( """ self._check_closed() - return self._caller.unsafe_call_with_result(self._async_reader.commit_with_ack(mess), timeout) + return self._caller.unsafe_call_with_result( + self._async_reader.commit_with_ack(mess), timeout + ) - def async_commit_with_ack(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]) -> concurrent.futures.Future: + def async_commit_with_ack( + self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ) -> concurrent.futures.Future: """ write commit message to a buffer and return Future for wait result. 
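
         A minimal sketch (illustrative only; `reader` and `batch` are
         placeholders for an open sync reader and a batch received from it):

             fut = reader.async_commit_with_ack(batch)  # concurrent.futures.Future
             commit_result = fut.result(timeout=10)     # block until server ack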
""" self._check_closed() - return self._caller.unsafe_call_with_future(self._async_reader.commit_with_ack(mess)) + return self._caller.unsafe_call_with_future( + self._async_reader.commit_with_ack(mess) + ) def async_flush(self) -> concurrent.futures.Future: """ diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 908fd56b..c2329ae5 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -3,19 +3,23 @@ import asyncio import typing from concurrent.futures import Future -from typing import Union, List, Optional, Coroutine +from typing import Union, List, Optional from .._grpc.grpcwrapper.common_utils import SupportedDriverType from .topic_writer import ( PublicWriterSettings, - TopicWriterError, PublicWriterInitInfo, PublicWriteResult, - Message, TopicWriterClosedError, + Message, + TopicWriterClosedError, ) from .topic_writer_asyncio import WriterAsyncIO -from .._topic_common.common import _get_shared_event_loop, TimeoutType, CallFromSyncToAsync +from .._topic_common.common import ( + _get_shared_event_loop, + TimeoutType, + CallFromSyncToAsync, +) class WriterSync: @@ -43,7 +47,9 @@ def __init__( async def create_async_writer(): return WriterAsyncIO(driver, settings) - self._async_writer = self._caller.safe_call_with_result(create_async_writer(), None) + self._async_writer = self._caller.safe_call_with_result( + create_async_writer(), None + ) def __enter__(self): return self @@ -51,13 +57,17 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def close(self, *, flush: bool = True, timeout: typing.Union[int, float, None] = None): + def close( + self, *, flush: bool = True, timeout: typing.Union[int, float, None] = None + ): if self._closed: return self._closed = True - self._caller.safe_call_with_result(self._async_writer.close(flush=flush), timeout) + self._caller.safe_call_with_result( + self._async_writer.close(flush=flush), timeout + ) def _check_closed(self): if self._closed: @@ -81,7 +91,9 @@ def async_wait_init(self) -> Future[PublicWriterInitInfo]: def wait_init(self, *, timeout: TimeoutType = None) -> PublicWriterInitInfo: self._check_closed() - return self._caller.unsafe_call_with_result(self._async_writer.wait_init(), timeout) + return self._caller.unsafe_call_with_result( + self._async_writer.wait_init(), timeout + ) def write( self, @@ -98,7 +110,9 @@ def async_write_with_ack( ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: self._check_closed() - return self._caller.unsafe_call_with_future(self._async_writer.write_with_ack(messages)) + return self._caller.unsafe_call_with_future( + self._async_writer.write_with_ack(messages) + ) def write_with_ack( self, From 2931699f4e2d5cdd526b58feadb1ce54697e4f11 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 11:35:56 +0300 Subject: [PATCH 123/147] timeout style --- ydb/_topic_common/common.py | 4 +--- ydb/_topic_reader/topic_reader_sync.py | 10 +++++++--- ydb/_topic_writer/topic_writer_sync.py | 5 +---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index 6e381c2f..d064da01 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -76,9 +76,7 @@ def unsafe_call_with_future( """ return asyncio.run_coroutine_threadsafe(coro, self._loop) - def unsafe_call_with_result( - self, coro: typing.Coroutine, timeout: typing.Union[int, float, None] - ): + def unsafe_call_with_result(self, coro: 
typing.Coroutine, timeout: TimeoutType):
         """
         returned result from coro may be lost by race future cancel by timeout and return value from coroutine
         """
diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py
index b82fa58b..30bf92a1 100644
--- a/ydb/_topic_reader/topic_reader_sync.py
+++ b/ydb/_topic_reader/topic_reader_sync.py
@@ -4,7 +4,11 @@ from typing import List, Union, Iterable, Optional
 
 from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType
-from ydb._topic_common.common import _get_shared_event_loop, CallFromSyncToAsync
+from ydb._topic_common.common import (
+    _get_shared_event_loop,
+    CallFromSyncToAsync,
+    TimeoutType,
+)
 from ydb._topic_reader import datatypes
 from ydb._topic_reader.datatypes import PublicMessage, PublicBatch, ICommittable
 from ydb._topic_reader.topic_reader import (
@@ -156,7 +160,7 @@ def commit(
         self._caller.call_sync(self._async_reader.commit(mess))
 
     def commit_with_ack(
-        self, mess: ICommittable, timeout: typing.Union[int, float, None] = None
+        self, mess: ICommittable, timeout: TimeoutType = None
     ) -> Union[CommitResult, List[CommitResult]]:
         """
         write commit message to a buffer and wait ack from the server.
@@ -193,7 +197,7 @@ def flush(self):
         """
         raise NotImplementedError()
 
-    def close(self, *, timeout: typing.Union[int, float, None] = None):
+    def close(self, *, timeout: TimeoutType = None):
         if self._closed:
             return
 
diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py
index c2329ae5..e6b51238 100644
--- a/ydb/_topic_writer/topic_writer_sync.py
+++ b/ydb/_topic_writer/topic_writer_sync.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import asyncio
-import typing
 from concurrent.futures import Future
 from typing import Union, List, Optional
 
@@ -57,9 +56,7 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
-    def close(
-        self, *, flush: bool = True, timeout: typing.Union[int, float, None] = None
-    ):
+    def close(self, *, flush: bool = True, timeout: TimeoutType = None):
         if self._closed:
             return
 
From 4ede48a783b83feb01924ccd2bf232fed5f043cb Mon Sep 17 00:00:00 2001
From: Timofey Koolin
Date: Mon, 20 Mar 2023 11:46:15 +0300
Subject: [PATCH 124/147] remove warning for good choice

---
 ydb/global_settings.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/ydb/global_settings.py b/ydb/global_settings.py
index 3b0368ad..f1fabcc2 100644
--- a/ydb/global_settings.py
+++ b/ydb/global_settings.py
@@ -7,10 +7,6 @@ def global_allow_truncated_result(enabled: bool = True):
 
     if enabled:
         warnings.warn("Global allow truncated response is deprecated behaviour.")
-    else:
-        warnings.warn(
-            "Global deny truncated response is default behaviour. You don't need call the function."
-        )
 
     convert._default_allow_truncated_result = enabled
 
@@ -18,9 +14,5 @@ def global_allow_split_transactions(enabled: bool):
     if enabled:
         warnings.warn("Global allow truncated response is deprecated behaviour.")
-    else:
-        warnings.warn(
-            "Global deby truncated response is default behaviour. You don't need call the function."
- ) table._allow_split_transaction = enabled From 60632774b4b98a5db3cd96e1a85b266e7e3d2606 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 13:25:41 +0300 Subject: [PATCH 125/147] merge --- ydb/_topic_reader/topic_reader_asyncio.py | 1 - .../topic_reader_asyncio_test.py | 475 +++++++++--------- ydb/_topic_reader/topic_reader_sync.py | 13 +- 3 files changed, 247 insertions(+), 242 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index c74f7d09..7266ae43 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -419,7 +419,6 @@ def receive_message_nowait(self): return message - def commit( self, batch: datatypes.ICommittable ) -> datatypes.PartitionSession.CommitAckWaiter: diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index a153e25c..a310298e 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -114,7 +114,7 @@ def stream(self): @pytest.fixture() def partition_session( - self, default_reader_settings, stream_reader_started: ReaderStream + self, default_reader_settings, stream_reader_started: ReaderStream ) -> datatypes.PartitionSession: partition_session = datatypes.PartitionSession( id=2, @@ -135,7 +135,7 @@ def partition_session( @pytest.fixture() def second_partition_session( - self, default_reader_settings, stream_reader_started: ReaderStream + self, default_reader_settings, stream_reader_started: ReaderStream ): partition_session = datatypes.PartitionSession( id=12, @@ -186,7 +186,7 @@ async def get_started_reader(self, stream, *args, **kwargs) -> ReaderStream: @pytest.fixture() async def stream_reader_started( - self, stream, default_reader_settings + self, stream, default_reader_settings ) -> ReaderStream: return await self.get_started_reader(stream, default_reader_settings) @@ -199,7 +199,7 @@ async def stream_reader(self, stream_reader_started: ReaderStream): @pytest.fixture() async def stream_reader_finish_with_error( - self, stream_reader_started: ReaderStream + self, stream_reader_started: ReaderStream ): yield stream_reader_started @@ -208,7 +208,9 @@ async def stream_reader_finish_with_error( @staticmethod def create_message( - partition_session: typing.Optional[datatypes.PartitionSession], seqno: int, offset_delta: int + partition_session: typing.Optional[datatypes.PartitionSession], + seqno: int, + offset_delta: int, ): return PublicMessage( seqno=seqno, @@ -216,17 +218,17 @@ def create_message( message_group_id="test-message-group", session_metadata={}, offset=partition_session._next_message_start_commit_offset - + offset_delta - - 1, + + offset_delta + - 1, written_at=datetime.datetime(2023, 2, 3, 14, 16), producer_id="test-producer-id", data=bytes(), _partition_session=partition_session, _commit_start_offset=partition_session._next_message_start_commit_offset - + offset_delta - - 1, + + offset_delta + - 1, _commit_end_offset=partition_session._next_message_start_commit_offset - + offset_delta, + + offset_delta, ) async def send_message(self, stream_reader, message: PublicMessage): @@ -286,28 +288,28 @@ class TestError(Exception): "commit,send_range", [ ( - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - True, + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + True, ), ( - OffsetsRange( - 
partition_session_committed_offset - 1, - partition_session_committed_offset, - ), - False, + OffsetsRange( + partition_session_committed_offset - 1, + partition_session_committed_offset, + ), + False, ), ], ) async def test_send_commit_messages( - self, - stream, - stream_reader: ReaderStream, - partition_session, - commit: OffsetsRange, - send_range: bool, + self, + stream, + stream_reader: ReaderStream, + partition_session, + commit: OffsetsRange, + send_range: bool, ): @dataclass class Commitable(datatypes.ICommittable): @@ -347,7 +349,7 @@ async def wait_message(): assert start_ack_waiters == partition_session._ack_waiters async def test_commit_ack_received( - self, stream_reader, stream, partition_session, second_partition_session + self, stream_reader, stream, partition_session, second_partition_session ): offset1 = self.partition_session_committed_offset + 1 waiter1 = partition_session.add_waiter(offset1) @@ -377,7 +379,7 @@ async def test_commit_ack_received( await wait_for_fast(waiter2.future) async def test_close_ack_waiters_when_close_stream_reader( - self, stream_reader_started: ReaderStream, partition_session + self, stream_reader_started: ReaderStream, partition_session ): waiter = partition_session.add_waiter( self.partition_session_committed_offset + 1 @@ -388,7 +390,7 @@ async def test_close_ack_waiters_when_close_stream_reader( waiter.future.result() async def test_commit_ranges_for_received_messages( - self, stream, stream_reader_started: ReaderStream, partition_session + self, stream, stream_reader_started: ReaderStream, partition_session ): m1 = self.create_message(partition_session, 1, 1) m2 = self.create_message(partition_session, 2, 10) @@ -410,131 +412,131 @@ async def test_commit_ranges_for_received_messages( "batch,data_out", [ ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=rb"123", - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ) - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - [bytes(rb"123")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123")], ), ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=gzip.compress(rb"123"), - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ) - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_GZIP, - ), - [bytes(rb"123")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + 
_commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123")], ), ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=rb"123", - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ), - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=rb"456", - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ), - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - [bytes(rb"123"), bytes(rb"456")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"456", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123"), bytes(rb"456")], ), ( - PublicBatch( - session_metadata={}, - messages=[ - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=gzip.compress(rb"123"), - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ), - PublicMessage( - seqno=1, - created_at=datetime.datetime(2023, 3, 14, 15, 41), - message_group_id="", - session_metadata={}, - offset=1, - written_at=datetime.datetime(2023, 3, 14, 15, 42), - producer_id="asd", - data=gzip.compress(rb"456"), - _partition_session=None, - _commit_start_offset=5, - _commit_end_offset=15, - ), - ], - _partition_session=None, - _bytes_size=0, - _codec=Codec.CODEC_GZIP, - ), - [bytes(rb"123"), bytes(rb"456")], + PublicBatch( + session_metadata={}, + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"456"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123"), bytes(rb"456")], ), ], ) async def test_decode_loop( - self, stream_reader, batch: PublicBatch, data_out: typing.List[bytes] + self, stream_reader, batch: PublicBatch, data_out: typing.List[bytes] ): assert 
len(batch.messages) == len(data_out) @@ -549,7 +551,7 @@ async def test_decode_loop( assert batch == expected async def test_error_from_status_code( - self, stream, stream_reader_finish_with_error + self, stream, stream_reader_finish_with_error ): # noinspection PyTypeChecker stream.from_server.put_nowait( @@ -609,11 +611,11 @@ async def test_init_reader(self, stream, default_reader_settings): await reader.close() async def test_start_partition( - self, - stream_reader: ReaderStream, - stream, - default_reader_settings, - partition_session, + self, + stream_reader: ReaderStream, + stream, + default_reader_settings, + partition_session, ): def session_count(): return len(stream_reader._partition_sessions) @@ -653,8 +655,8 @@ def session_count(): assert len(stream_reader._partition_sessions) == initial_session_count + 1 assert stream_reader._partition_sessions[ - test_partition_session_id - ] == datatypes.PartitionSession( + test_partition_session_id + ] == datatypes.PartitionSession( id=test_partition_session_id, state=datatypes.PartitionSession.State.Active, topic_path=test_topic_path, @@ -689,7 +691,7 @@ def session_count(): assert partition_session.id not in stream_reader._partition_sessions async def test_partition_stop_graceful( - self, stream, stream_reader, partition_session + self, stream, stream_reader, partition_session ): def session_count(): return len(stream_reader._partition_sessions) @@ -732,11 +734,11 @@ def session_count(): stream.from_client.get_nowait() async def test_receive_message_from_server( - self, - stream_reader, - stream, - partition_session: datatypes.PartitionSession, - second_partition_session, + self, + stream_reader, + stream, + partition_session: datatypes.PartitionSession, + second_partition_session, ): def reader_batch_count(): return len(stream_reader._message_batches) @@ -814,7 +816,7 @@ def reader_batch_count(): ) async def test_read_batches( - self, stream_reader, partition_session, second_partition_session + self, stream_reader, partition_session, second_partition_session ): created_at = datetime.datetime(2020, 2, 1, 18, 12) created_at2 = datetime.datetime(2020, 2, 2, 18, 12) @@ -993,94 +995,93 @@ async def test_read_batches( ) @pytest.mark.parametrize( - 'batches_before,expected_message,batches_after', + "batches_before,expected_message,batches_after", [ + ([], None, []), ( - [], - None, - [] - ), - ( - [PublicBatch( + [ + PublicBatch( session_metadata={}, messages=[stub_message(1)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, - )], - stub_message(1), - [] + ) + ], + stub_message(1), + [], ), ( - [ - PublicBatch( - session_metadata={}, - messages=[stub_message(1), stub_message(2)], - _partition_session=stub_partition_session(), - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - PublicBatch( - session_metadata={}, - messages=[stub_message(3), stub_message(4)], - _partition_session=stub_partition_session(), - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ) - ], - stub_message(1), - [ - PublicBatch( - session_metadata={}, - messages=[stub_message(2)], - _partition_session=stub_partition_session(), - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - PublicBatch( - session_metadata={}, - messages=[stub_message(3), stub_message(4)], - _partition_session=stub_partition_session(), - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ) - ], + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(1), stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + 
PublicBatch( + session_metadata={}, + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + stub_message(1), + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], ), ( - [ - PublicBatch( - session_metadata={}, - messages=[stub_message(1)], - _partition_session=stub_partition_session(), - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - PublicBatch( - session_metadata={}, - messages=[stub_message(2), stub_message(3)], - _partition_session=stub_partition_session(), - _bytes_size=0, - _codec=Codec.CODEC_RAW, - ), - ], - stub_message(1), - [PublicBatch( + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + stub_message(1), + [ + PublicBatch( session_metadata={}, messages=[stub_message(2), stub_message(3)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, - )], + ) + ], ), - - ] + ], ) async def test_read_message( - self, - stream_reader, - batches_before: typing.List[datatypes.PublicBatch], - expected_message: PublicMessage, - batches_after: typing.List[datatypes.PublicBatch], + self, + stream_reader, + batches_before: typing.List[datatypes.PublicBatch], + expected_message: PublicMessage, + batches_after: typing.List[datatypes.PublicBatch], ): stream_reader._message_batches = deque(batches_before) mess = stream_reader.receive_message_nowait() @@ -1118,17 +1119,17 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi ) assert ( - stream_reader._buffer_size_bytes - == initial_buffer_size + 2 * self.default_batch_size + stream_reader._buffer_size_bytes + == initial_buffer_size + 2 * self.default_batch_size ) assert ( - StreamReadMessage.ReadRequest(self.default_batch_size) - == stream.from_client.get_nowait().client_message + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message ) assert ( - StreamReadMessage.ReadRequest(self.default_batch_size) - == stream.from_client.get_nowait().client_message + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message ) with pytest.raises(asyncio.QueueEmpty): @@ -1193,9 +1194,9 @@ async def wait_messages(): stream_index = 0 async def stream_create( - reader_reconnector_id: int, - driver: SupportedDriverType, - settings: PublicReaderSettings, + reader_reconnector_id: int, + driver: SupportedDriverType, + settings: PublicReaderSettings, ): nonlocal stream_index stream_index += 1 diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index 2d54ead4..72543e76 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -87,7 +87,9 @@ def messages( """ raise NotImplementedError() - def receive_message(self, *, timeout: Union[float, None] = None) -> datatypes.PublicMessage: + def receive_message( + self, *, timeout: Union[float, None] = None + ) -> 
datatypes.PublicMessage: """ Block until receive new message It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. @@ -103,7 +105,9 @@ def receive_message(self, *, timeout: Union[float, None] = None) -> datatypes.Pu return self._call_sync(self._async_reader.receive_message(), timeout) def _receive_message_nowait(self) -> Optional[datatypes.PublicMessage]: - return self._call_nowait(lambda: self._async_reader._reconnector.receive_message_nowait()) + return self._call_nowait( + lambda: self._async_reader._reconnector.receive_message_nowait() + ) def async_wait_message(self) -> concurrent.futures.Future: """ @@ -155,8 +159,9 @@ def receive_batch( ) def _receive_batch_nowait(self) -> Optional[PublicBatch]: - return self._caller.call_sync(lambda: self._async_reader._reconnector.receive_batch_nowait()) - + return self._caller.call_sync( + lambda: self._async_reader._reconnector.receive_batch_nowait() + ) def commit( self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] From a2cda766eb78b7c61a3bb7d433f5345c68353b23 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 13:39:19 +0300 Subject: [PATCH 126/147] read one message --- tests/topics/test_topic_reader.py | 32 ++++++++++++++++++++++++-- ydb/_topic_reader/topic_reader_sync.py | 32 +++++++++++--------------- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 2a451baf..84d61a43 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -5,12 +5,26 @@ @pytest.mark.asyncio class TestTopicReaderAsyncIO: + async def test_read_batch( + self, driver, topic_path, topic_with_messages, topic_consumer + ): + reader = driver.topic_client.reader(topic_consumer, topic_path) + batch = await reader.receive_batch() + + assert batch is not None + assert len(batch.messages) > 0 + + await reader.close() + async def test_read_message( self, driver, topic_path, topic_with_messages, topic_consumer ): reader = driver.topic_client.reader(topic_consumer, topic_path) + msg = await reader.receive_message() + + assert msg is not None + assert msg.seqno - assert await reader.receive_batch() is not None await reader.close() async def test_read_and_commit_message( @@ -59,12 +73,26 @@ def decode(b: bytes): class TestTopicReaderSync: + def test_read_batch( + self, driver_sync, topic_path, topic_with_messages, topic_consumer + ): + reader = driver_sync.topic_client.reader(topic_consumer, topic_path) + batch = reader.receive_batch() + + assert batch is not None + assert len(batch.messages) > 0 + + reader.close() + def test_read_message( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): reader = driver_sync.topic_client.reader(topic_consumer, topic_path) + msg = reader.receive_message() + + assert msg is not None + assert msg.seqno - assert reader.receive_batch() is not None reader.close() def test_read_and_commit_message( diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index 72543e76..ed9730fa 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -83,12 +83,13 @@ def messages( It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. 
if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, + get messages from internal buffer only. """ raise NotImplementedError() def receive_message( - self, *, timeout: Union[float, None] = None + self, *, timeout: TimeoutType = None ) -> datatypes.PublicMessage: """ Block until receive new message @@ -97,16 +98,12 @@ def receive_message( or connection to server lost if no new message in timeout seconds (default - infinite): raise TimeoutError() - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. """ - if timeout <= 0: - return self._receive_message_nowait() - - return self._call_sync(self._async_reader.receive_message(), timeout) + self._check_closed() - def _receive_message_nowait(self) -> Optional[datatypes.PublicMessage]: - return self._call_nowait( - lambda: self._async_reader._reconnector.receive_message_nowait() + return self._caller.safe_call_with_result( + self._async_reader.receive_message(), timeout ) def async_wait_message(self) -> concurrent.futures.Future: @@ -117,7 +114,11 @@ def async_wait_message(self) -> concurrent.futures.Future: Possible situation when receive signal about message available, but no messages when try to receive a message. If message expired between send event and try to retrieve message (for example connection broken). """ - return self._call(self._async_reader._reconnector.wait_message()) + self._check_closed() + + return self._caller.unsafe_call_with_future( + self._async_reader._reconnector.wait_message() + ) def batches( self, @@ -131,7 +132,7 @@ def batches( It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. """ raise NotImplementedError() @@ -147,7 +148,7 @@ def receive_batch( It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. if no new message in timeout seconds (default - infinite): raise TimeoutError() - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. 
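
         A minimal sketch (illustrative only; `reader` is a placeholder for an
         open sync reader):

             try:
                 batch = reader.receive_batch(timeout=5)
                 reader.commit(batch)  # enqueue commit for the whole batch
             except TimeoutError:
                 pass  # no new messages within 5 seconds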
""" self._check_closed() @@ -158,11 +159,6 @@ def receive_batch( timeout, ) - def _receive_batch_nowait(self) -> Optional[PublicBatch]: - return self._caller.call_sync( - lambda: self._async_reader._reconnector.receive_batch_nowait() - ) - def commit( self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] ): From 2ef00268d0690bcbb632dd7c999495ed39a9aa6e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 16:57:09 +0300 Subject: [PATCH 127/147] style fix --- ydb/_topic_common/common.py | 8 +++----- ydb/_topic_common/common_test.py | 5 +---- ydb/_topic_writer/topic_writer.py | 2 +- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index d064da01..ab9aae5c 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -86,7 +86,8 @@ def unsafe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType): except concurrent.futures.TimeoutError: raise TimeoutError() finally: - f.cancel() + if not f.done(): + f.cancel() def safe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType): """ @@ -126,11 +127,8 @@ async def call_coro(): except asyncio.CancelledError: res.set_exception(TimeoutError()) - async def sleep0(): - await asyncio.sleep(0) - coro_future = asyncio.run_coroutine_threadsafe(call_coro(), self._loop) - asyncio.run_coroutine_threadsafe(sleep0(), self._loop).result() + asyncio.run_coroutine_threadsafe(asyncio.sleep(0), self._loop).result() coro_future.cancel() return res.result() diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index b292862f..b31f9af9 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -217,11 +217,8 @@ async def callback(): caller.safe_call_with_result(callback(), timeout) finished = time.monotonic() - async def sleep0(): - await asyncio.sleep(0) - # wait one loop for handle task cancelation - asyncio.run_coroutine_threadsafe(sleep0(), separate_loop) + asyncio.run_coroutine_threadsafe(asyncio.sleep(0), separate_loop) assert callback_loop is separate_loop assert finished - start > timeout diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 97c05420..59ad74ff 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -178,7 +178,7 @@ def __init__(self, message: str): class TopicWriterClosedError(ydb.Error): def __init__(self): - super(TopicWriterClosedError, self).__init__("Topic writer already closed") + super().__init__("Topic writer already closed") class TopicWriterRepeatableError(TopicWriterError): From 8aae45029054dce5ef7ed01c3b448f9f94126cc8 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 18:34:09 +0300 Subject: [PATCH 128/147] Update CHANGELOG.md --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52c28763..3c1e5cf0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ * Add function for global revert broken behaviour: ydb.global_allow_truncated_result, ydb.global_allow_split_transactions * Change argument names from deny_split_transactions to allow_split_transactions (with reverse value +* Fixed check retriable for idempotent error +* Reader codecs +* Read one message +* fixed sqlalchemy get_columns method with not null columns ## 3.0.1b8 ## * Fixed exception while create ResultSet with None table_settings From b56b11992c911c248a6e10fd52a0976092f8f2f9 Mon Sep 17 00:00:00 2001 From: robot Date: Mon, 20 Mar 
2023 15:35:17 +0000 Subject: [PATCH 129/147] Release: 3.0.1b9 --- CHANGELOG.md | 1 + setup.py | 2 +- ydb/ydb_version.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c1e5cf0..59f08976 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b9 ## * Add function for global revert broken behaviour: ydb.global_allow_truncated_result, ydb.global_allow_split_transactions * Change argument names from deny_split_transactions to allow_split_transactions (with reverse value * Fixed check retriable for idempotent error diff --git a/setup.py b/setup.py index 008c407d..cd9569cb 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="3.0.1b8", # AUTOVERSION + version="3.0.1b9", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", diff --git a/ydb/ydb_version.py b/ydb/ydb_version.py index b0649f64..ef5ee52a 100644 --- a/ydb/ydb_version.py +++ b/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "3.0.1b8" +VERSION = "3.0.1b9" From dc901767497477c3952775bf170a111b8ff396aa Mon Sep 17 00:00:00 2001 From: Valeriya Popova Date: Mon, 20 Mar 2023 18:52:50 +0300 Subject: [PATCH 130/147] fix sqlalchemy nullable --- examples/_sqlalchemy_example/example.py | 2 +- tests/_sqlalchemy/_test_inspect.py | 24 ++++++++++++++++++++++++ ydb/_sqlalchemy/__init__.py | 13 ++++++++----- 3 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 tests/_sqlalchemy/_test_inspect.py diff --git a/examples/_sqlalchemy_example/example.py b/examples/_sqlalchemy_example/example.py index 96f47820..70f4c465 100644 --- a/examples/_sqlalchemy_example/example.py +++ b/examples/_sqlalchemy_example/example.py @@ -196,7 +196,7 @@ def run_example_core(engine): def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, - description="""\033[92mYandex.Database examples _sqlalchemy usage.\x1b[0m\n""", + description="""\033[92mYandex.Database examples sqlalchemy usage.\x1b[0m\n""", ) parser.add_argument( "-d", diff --git a/tests/_sqlalchemy/_test_inspect.py b/tests/_sqlalchemy/_test_inspect.py new file mode 100644 index 00000000..69a57243 --- /dev/null +++ b/tests/_sqlalchemy/_test_inspect.py @@ -0,0 +1,24 @@ +import ydb + +import sqlalchemy as sa + + +def test_get_columns(driver_sync, engine): + session = ydb.retry_operation_sync( + lambda: driver_sync.table_client.session().create() + ) + session.execute_scheme( + "CREATE TABLE test(id Int64 NOT NULL, value TEXT, num DECIMAL(22, 9), PRIMARY KEY (id))" + ) + inspect = sa.inspect(engine) + columns = inspect.get_columns("test") + for c in columns: + c["type"] = type(c["type"]) + + assert columns == [ + {"name": "id", "type": sa.INTEGER, "nullable": False}, + {"name": "value", "type": sa.TEXT, "nullable": True}, + {"name": "num", "type": sa.DECIMAL, "nullable": True}, + ] + + session.execute_scheme("DROP TABLE test") diff --git a/ydb/_sqlalchemy/__init__.py b/ydb/_sqlalchemy/__init__.py index 8336a9a8..d8931a5d 100644 --- a/ydb/_sqlalchemy/__init__.py +++ b/ydb/_sqlalchemy/__init__.py @@ -206,14 +206,16 @@ def upsert(table): } -def _get_column_type(t): +def _get_column_info(t): + nullable = False if isinstance(t, ydb.OptionalType): + nullable = True t = t.item if isinstance(t, ydb.DecimalType): - return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale) + return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable - return COLUMN_TYPES[t] + return COLUMN_TYPES[t], nullable class 
YqlDialect(DefaultDialect): @@ -268,11 +270,12 @@ def get_columns(self, connection, table_name, schema=None, **kw): columns = raw_conn.describe(qt) as_compatible = [] for column in columns: + col_type, nullable = _get_column_info(column.type) as_compatible.append( { "name": column.name, - "type": _get_column_type(column.type), - "nullable": True, + "type": col_type, + "nullable": nullable, } ) From 8aee8e15903dcccf6a3f0f6b281d147e089fac1b Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 20 Mar 2023 19:03:26 +0300 Subject: [PATCH 131/147] style fix --- ydb/_topic_common/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index ab9aae5c..9e8f1326 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -30,7 +30,7 @@ def wrapper(rpc_state, response_pb, driver=None): _shared_event_loop_lock = threading.Lock() -_shared_event_loop = None # type: Optional[asyncio.AbstractEventLoop] +_shared_event_loop: Optional[asyncio.AbstractEventLoop] = None def _get_shared_event_loop() -> asyncio.AbstractEventLoop: From c73d9f6475fc631eccce7083a9bc14082c380a01 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 08:10:10 +0300 Subject: [PATCH 132/147] style fix --- ydb/_topic_reader/datatypes.py | 9 ++------- ydb/_topic_reader/topic_reader_asyncio.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 860525ab..cff0fed8 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -179,7 +179,7 @@ def _commit_get_offsets_range(self) -> OffsetsRange: self.messages[-1]._commit_get_offsets_range().end, ) - def is_empty(self) -> bool: + def empty(self) -> bool: return len(self.messages) == 0 # ISessionAlive implementation @@ -192,9 +192,4 @@ def is_alive(self) -> bool: ) def pop_message(self) -> PublicMessage: - if len(self.messages) == 0: - raise IndexError() - - res = self.messages[0] - self.messages = self.messages[1:] - return res + return self.messages.pop() diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 7266ae43..0068e4ba 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -125,7 +125,7 @@ async def receive_batch( await self._reconnector.wait_message() return self._reconnector.receive_batch_nowait() - async def receive_message(self) -> typing.Union[datatypes.PublicMessage, None]: + async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: """ Block until receive new message @@ -414,7 +414,7 @@ def receive_message_nowait(self): except IndexError: return None - if batch.is_empty(): + if batch.empty(): self._message_batches.popleft() return message From 8c1092451b35013cac031281707cb8fab76e9c17 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 09:39:56 +0300 Subject: [PATCH 133/147] fix typo --- ydb/_topic_reader/datatypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index cff0fed8..5376c76d 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -192,4 +192,4 @@ def is_alive(self) -> bool: ) def pop_message(self) -> PublicMessage: - return self.messages.pop() + return self.messages.pop(0) From 4bf3ef8a8be437e77265686fdabba84766012c50 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 15:02:31 
+0300 Subject: [PATCH 134/147] Update requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fc27de0a..6a60288e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ grpcio>=1.42.0 packaging protobuf>=3.13.0,<5.0.0 -pytest==6.2.4 aiohttp==3.7.4 From cca8f47d73f5aec575f3f83e6c83b587645be46a Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 15:03:39 +0300 Subject: [PATCH 135/147] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 59f08976..f3cd5eef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +* fixed sqlalchemy get_columns method with not null columns +* fixed requirements.txt + ## 3.0.1b9 ## * Add function for global revert broken behaviour: ydb.global_allow_truncated_result, ydb.global_allow_split_transactions * Change argument names from deny_split_transactions to allow_split_transactions (with reverse value From a90394b3eb4a5f74ae52120316893ccfa14555cc Mon Sep 17 00:00:00 2001 From: robot Date: Tue, 21 Mar 2023 12:04:43 +0000 Subject: [PATCH 136/147] Release: 3.0.1b10 --- CHANGELOG.md | 1 + setup.py | 2 +- ydb/ydb_version.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f3cd5eef..3fa42abc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b10 ## * fixed sqlalchemy get_columns method with not null columns * fixed requirements.txt diff --git a/setup.py b/setup.py index cd9569cb..dd80a6b4 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="3.0.1b9", # AUTOVERSION + version="3.0.1b10", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", diff --git a/ydb/ydb_version.py b/ydb/ydb_version.py index ef5ee52a..92fd0d98 100644 --- a/ydb/ydb_version.py +++ b/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "3.0.1b9" +VERSION = "3.0.1b10" From 88defe365bfe0b104c5375132929fab3cf4144ba Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 15:36:47 +0300 Subject: [PATCH 137/147] fix default truncated response --- ydb/_topic_common/common.py | 2 +- ydb/aio/table.py | 5 +---- ydb/global_settings.py | 10 ++++++++-- ydb/table.py | 20 +++++++++++--------- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index ab9aae5c..9e8f1326 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -30,7 +30,7 @@ def wrapper(rpc_state, response_pb, driver=None): _shared_event_loop_lock = threading.Lock() -_shared_event_loop = None # type: Optional[asyncio.AbstractEventLoop] +_shared_event_loop: Optional[asyncio.AbstractEventLoop] = None def _get_shared_event_loop() -> asyncio.AbstractEventLoop: diff --git a/ydb/aio/table.py b/ydb/aio/table.py index 92ed9812..06f8ca7c 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -13,7 +13,6 @@ _scan_query_request_factory, _wrap_scan_query_response, BaseTxContext, - _allow_split_transaction, ) from . 
import _utilities from ydb import _apis, _session_impl @@ -121,9 +120,7 @@ async def alter_table( set_read_replicas_settings, ) - def transaction( - self, tx_mode=None, *, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, *, allow_split_transactions=None): return TxContext( self._driver, self._state, diff --git a/ydb/global_settings.py b/ydb/global_settings.py index f1fabcc2..8edac3f4 100644 --- a/ydb/global_settings.py +++ b/ydb/global_settings.py @@ -5,6 +5,9 @@ def global_allow_truncated_result(enabled: bool = True): + if convert._default_allow_truncated_result == enabled: + return + if enabled: warnings.warn("Global allow truncated response is deprecated behaviour.") @@ -12,7 +15,10 @@ def global_allow_truncated_result(enabled: bool = True): def global_allow_split_transactions(enabled: bool): + if table._default_allow_split_transaction == enabled: + return + if enabled: - warnings.warn("Global allow truncated response is deprecated behaviour.") + warnings.warn("Global allow split transaction is deprecated behaviour.") - table._allow_split_transaction = enabled + table._default_allow_split_transaction = enabled diff --git a/ydb/table.py b/ydb/table.py index eb3a9780..799a5426 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -27,7 +27,7 @@ except ImportError: interceptor = None -_allow_split_transaction = False +_default_allow_split_transaction = False logger = logging.getLogger(__name__) @@ -1181,9 +1181,7 @@ def execute_scheme(self, yql_text, settings=None): pass @abstractmethod - def transaction( - self, tx_mode=None, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, allow_split_transactions=None): pass @abstractmethod @@ -1687,9 +1685,7 @@ def execute_scheme(self, yql_text, settings=None): self._state.endpoint, ) - def transaction( - self, tx_mode=None, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, allow_split_transactions=None): return TxContext( self._driver, self._state, @@ -2226,7 +2222,7 @@ def __init__( session, tx_mode=None, *, - allow_split_transactions=_allow_split_transaction + allow_split_transactions=None ): """ An object that provides a simple transaction context manager that allows statements execution @@ -2413,7 +2409,13 @@ def _check_split(self, allow=""): Deny all operaions with transaction after commit/rollback. 
Exception: double commit and double rollbacks, because it is safe """ - if self._allow_split_transactions: + allow_split_transaction = ( + self._allow_split_transactions + if self._allow_split_transactions is not None + else _default_allow_split_transaction + ) + + if allow_split_transaction: return if self._finished != "" and self._finished != allow: From 3ebf21d1c1c9cffd8074a81fa78b8d9ff4817b3f Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 17:45:31 +0300 Subject: [PATCH 138/147] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fa42abc..fd75caad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +* Fixed global_allow_split_transactions +* Added reader.receive_message() method + ## 3.0.1b10 ## * fixed sqlalchemy get_columns method with not null columns * fixed requirements.txt From 677edfdcb8a1607ab90d6776d2eb29ec24b5af1e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 19:55:26 +0300 Subject: [PATCH 139/147] sync --- examples/topic/reader_async_example.py | 78 ++++++++++---------------- 1 file changed, 29 insertions(+), 49 deletions(-) diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py index fb7f4c26..448e3ded 100644 --- a/examples/topic/reader_async_example.py +++ b/examples/topic/reader_async_example.py @@ -15,44 +15,32 @@ async def connect(): async def create_reader_and_close_with_context_manager(db: ydb.aio.Driver): - with ydb.TopicClientAsyncIO(db).reader( + async with ydb.TopicClientAsyncIO(db).reader( "/database/topic/path", consumer="consumer" ) as reader: - async for message in reader.messages(): - pass + ... async def print_message_content(reader: ydb.TopicReaderAsyncIO): - async for message in reader.messages(): + while True: + message = await reader.receive_message() print("text", message.data.read().decode("utf-8")) # await and async_commit need only for sync commit mode - for wait ack from servr await reader.commit(message) -async def process_messages_batch_explicit_commit(reader: ydb.TopicReaderAsyncIO): +async def process_messages_batch_with_commit(reader: ydb.TopicReaderAsyncIO): # Explicit commit example - async for batch in reader.batches(max_messages=100, timeout=2): - async with asyncio.TaskGroup() as tg: - for message in batch.messages: - tg.create_task(_process(message)) - - # wait complete of process all messages from batch be taskgroup context manager - # and commit complete batch + while True: + batch = await reader.receive_batch() + ... 
await reader.commit(batch) -async def process_messages_batch_context_manager_commit(reader: ydb.TopicReaderAsyncIO): - # Commit with context manager - async for batch in reader.batches(): - async with reader.commit_on_exit(batch), asyncio.TaskGroup() as tg: - for message in batch.messages: - tg.create_task(_process(message)) - - async def get_message_with_timeout(reader: ydb.TopicReaderAsyncIO): try: message = await asyncio.wait_for(reader.receive_message(), timeout=1) - except TimeoutError: + except asyncio.TimeoutError: print("Have no new messages in a second") return @@ -60,16 +48,19 @@ async def get_message_with_timeout(reader: ydb.TopicReaderAsyncIO): async def get_all_messages_with_small_wait(reader: ydb.TopicReaderAsyncIO): - async for message in reader.messages(timeout=1): - await _process(message) - print("Have no new messages in a second") + while True: + try: + message = await reader.receive_message() + await _process(message) + except asyncio.TimeoutError: + print("Have no new messages in a second") async def get_a_message_from_external_loop(reader: ydb.TopicReaderAsyncIO): for i in range(10): try: message = await asyncio.wait_for(reader.receive_message(), timeout=1) - except TimeoutError: + except asyncio.TimeoutError: return await _process(message) @@ -78,7 +69,7 @@ async def get_one_batch_from_external_loop_async(reader: ydb.TopicReaderAsyncIO) for i in range(10): try: batch = await asyncio.wait_for(reader.receive_batch(), timeout=2) - except TimeoutError: + except asyncio.TimeoutError: return for message in batch.messages: @@ -92,27 +83,20 @@ async def auto_deserialize_message(db: ydb.aio.Driver): async with ydb.TopicClientAsyncIO(db).reader( "/database/topic/path", consumer="asd", deserializer=json.loads ) as reader: - async for message in reader.messages(): + while True: + message = await reader.receive_message() print( message.data.Name ) # message.data replaces by json.loads(message.data) of raw message reader.commit(message) -async def commit_batch_with_context(reader: ydb.TopicReaderAsyncIO): - async for batch in reader.batches(): - async with reader.commit_on_exit(batch): - for message in batch.messages: - if not batch.is_alive: - break - await _process(message) - - async def handle_partition_stop(reader: ydb.TopicReaderAsyncIO): - async for message in reader.messages(): - time.sleep(1) # some work + while True: + message = await reader.receive_message() + time.sleep(123) # some work if message.is_alive: - time.sleep(123) # some other work + time.sleep(1) # some other work await reader.commit(message) @@ -126,7 +110,8 @@ def process_batch(batch): _process(message) reader.commit(batch) - async for batch in reader.batches(): + while True: + batch = await reader.receive_batch() process_batch(batch) @@ -137,18 +122,12 @@ async def connect_and_read_few_topics(db: ydb.aio.Driver): ydb.TopicSelector("/database/second-topic", partitions=3), ] ) as reader: - async for message in reader.messages(): + while True: + message = await reader.receive_message() await _process(message) await reader.commit(message) -async def handle_partition_graceful_stop_batch(reader: ydb.TopicReaderAsyncIO): - # no special handle, but batch will contain less than prefer count messages - async for batch in reader.batches(): - await _process(batch) - reader.commit(batch) - - async def advanced_commit_notify(db: ydb.aio.Driver): def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: print(event.topic) @@ -157,7 +136,8 @@ def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: async with 
ydb.TopicClientAsyncIO(db).reader( "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit ) as reader: - async for message in reader.messages(): + while True: + message = await reader.receive_message() await _process(message) await reader.commit(message) From 3b5db9f06c4261134cec522d4a03346318d658f1 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 20:01:06 +0300 Subject: [PATCH 140/147] sync --- CHANGELOG.md | 1 + tests/topics/test_topic_reader.py | 24 ++++++++++++------------ ydb/topic.py | 4 ++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd75caad..e9f37abf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ * Fixed global_allow_split_transactions * Added reader.receive_message() method +* Swap topic_path and consumer arguments in topic_client.reader method ## 3.0.1b10 ## * fixed sqlalchemy get_columns method with not null columns diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 84d61a43..03731b76 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -8,7 +8,7 @@ class TestTopicReaderAsyncIO: async def test_read_batch( self, driver, topic_path, topic_with_messages, topic_consumer ): - reader = driver.topic_client.reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_path, topic_consumer) batch = await reader.receive_batch() assert batch is not None @@ -19,7 +19,7 @@ async def test_read_batch( async def test_read_message( self, driver, topic_path, topic_with_messages, topic_consumer ): - reader = driver.topic_client.reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_path, topic_consumer) msg = await reader.receive_message() assert msg is not None @@ -31,11 +31,11 @@ async def test_read_and_commit_message( self, driver, topic_path, topic_with_messages, topic_consumer ): - reader = driver.topic_client.reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_path, topic_consumer) batch = await reader.receive_batch() await reader.commit_with_ack(batch) - reader = driver.topic_client.reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_path, topic_consumer) batch2 = await reader.receive_batch() assert batch.messages[0] != batch2.messages[0] @@ -47,7 +47,7 @@ async def test_read_compressed_messages(self, driver, topic_path, topic_consumer ) as writer: await writer.write("123") - async with driver.topic_client.reader(topic_consumer, topic_path) as reader: + async with driver.topic_client.reader(topic_path, topic_consumer) as reader: batch = await reader.receive_batch() assert batch.messages[0].data.decode() == "123" @@ -66,7 +66,7 @@ def decode(b: bytes): await writer.write("123") async with driver.topic_client.reader( - topic_consumer, topic_path, decoders={codec: decode} + topic_path, topic_consumer, decoders={codec: decode} ) as reader: batch = await reader.receive_batch() assert batch.messages[0].data.decode() == "123" @@ -76,7 +76,7 @@ class TestTopicReaderSync: def test_read_batch( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): - reader = driver_sync.topic_client.reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) batch = reader.receive_batch() assert batch is not None @@ -87,7 +87,7 @@ def test_read_batch( def test_read_message( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): - reader = 
driver_sync.topic_client.reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) msg = reader.receive_message() assert msg is not None @@ -98,11 +98,11 @@ def test_read_message( def test_read_and_commit_message( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): - reader = driver_sync.topic_client.reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) batch = reader.receive_batch() reader.commit_with_ack(batch) - reader = driver_sync.topic_client.reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) batch2 = reader.receive_batch() assert batch.messages[0] != batch2.messages[0] @@ -112,7 +112,7 @@ def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer) ) as writer: writer.write("123") - with driver_sync.topic_client.reader(topic_consumer, topic_path) as reader: + with driver_sync.topic_client.reader(topic_path, topic_consumer) as reader: batch = reader.receive_batch() assert batch.messages[0].data.decode() == "123" @@ -131,7 +131,7 @@ def decode(b: bytes): writer.write("123") with driver_sync.topic_client.reader( - topic_consumer, topic_path, decoders={codec: decode} + topic_path, topic_consumer, decoders={codec: decode} ) as reader: batch = reader.receive_batch() assert batch.messages[0].data.decode() == "123" diff --git a/ydb/topic.py b/ydb/topic.py index efe62219..ae6b5a5b 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -137,8 +137,8 @@ async def drop_topic(self, path: str): def reader( self, - consumer: str, topic: str, + consumer: str, buffer_size_bytes: int = 50 * 1024 * 1024, # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, @@ -306,8 +306,8 @@ def drop_topic(self, path: str): def reader( self, - consumer: str, topic: str, + consumer: str, buffer_size_bytes: int = 50 * 1024 * 1024, # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None, From fa3693a25371b56253a4a38e2b1d5a8c2831ac4f Mon Sep 17 00:00:00 2001 From: robot Date: Tue, 21 Mar 2023 17:14:22 +0000 Subject: [PATCH 141/147] Release: 3.0.1b11 --- CHANGELOG.md | 1 + setup.py | 2 +- ydb/ydb_version.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9f37abf..c24d452d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +## 3.0.1b11 ## * Fixed global_allow_split_transactions * Added reader.receive_message() method * Swap topic_path and consumer arguments in topic_client.reader method diff --git a/setup.py b/setup.py index dd80a6b4..6d322334 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="3.0.1b10", # AUTOVERSION + version="3.0.1b11", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", diff --git a/ydb/ydb_version.py b/ydb/ydb_version.py index 92fd0d98..9b3d0a8c 100644 --- a/ydb/ydb_version.py +++ b/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "3.0.1b10" +VERSION = "3.0.1b11" From 1f94f7b5094048e7d96f89851232ff6c48611750 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 21 Mar 2023 20:23:14 +0300 Subject: [PATCH 142/147] fix examples --- examples/topic/reader_async_example.py | 17 ++++--- examples/topic/reader_example.py | 67 ++++++++++++++------------ examples/topic/writer_async_example.py | 28 +++++------ 
examples/topic/writer_example.py | 24 ++++----- 4 files changed, 71 insertions(+), 65 deletions(-) diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py index 448e3ded..41825ab4 100644 --- a/examples/topic/reader_async_example.py +++ b/examples/topic/reader_async_example.py @@ -10,14 +10,14 @@ async def connect(): connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials(), ) - reader = ydb.TopicClientAsyncIO(db).reader("/local/topic", consumer="consumer") + reader = db.topic_client.reader("/local/topic", consumer="consumer") return reader async def create_reader_and_close_with_context_manager(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).reader( + async with db.topic_client.reader( "/database/topic/path", consumer="consumer" - ) as reader: + ) as reader: # noqa ... @@ -80,7 +80,7 @@ async def get_one_batch_from_external_loop_async(reader: ydb.TopicReaderAsyncIO) async def auto_deserialize_message(db: ydb.aio.Driver): # async, batch work similar to this - async with ydb.TopicClientAsyncIO(db).reader( + async with db.topic_client.reader( "/database/topic/path", consumer="asd", deserializer=json.loads ) as reader: while True: @@ -116,7 +116,7 @@ def process_batch(batch): async def connect_and_read_few_topics(db: ydb.aio.Driver): - with ydb.TopicClientAsyncIO(db).reader( + with db.topic_client.reader( [ "/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3), @@ -133,7 +133,7 @@ def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: print(event.topic) print(event.offset) - async with ydb.TopicClientAsyncIO(db).reader( + async with db.topic_client.reader( "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit ) as reader: while True: @@ -151,12 +151,13 @@ async def on_get_partition_start_offset( resp.start_offset = 123 return resp - async with ydb.TopicClient(db).reader( + async with db.topic_client.reader( "/local/test", consumer="consumer", on_get_partition_start_offset=on_get_partition_start_offset, ) as reader: - async for mess in reader.messages(): + while True: + mess = reader.receive_message() await _process(mess) # save progress to own database diff --git a/examples/topic/reader_example.py b/examples/topic/reader_example.py index 183c51d6..8de33c7e 100644 --- a/examples/topic/reader_example.py +++ b/examples/topic/reader_example.py @@ -9,33 +9,37 @@ def connect(): connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials(), ) - reader = ydb.TopicClient(db).reader("/local/topic", consumer="consumer") + reader = db.topic_client.reader("/local/topic", consumer="consumer") return reader def create_reader_and_close_with_context_manager(db: ydb.Driver): - with ydb.TopicClient(db).reader( + with db.topic_client.reader( "/database/topic/path", consumer="consumer", buffer_size_bytes=123 ) as reader: - for message in reader: + while True: + message = reader.receive_message() # noqa pass def print_message_content(reader: ydb.TopicReader): - for message in reader.messages(): + while True: + message = reader.receive_message() print("text", message.data.read().decode("utf-8")) reader.commit(message) def process_messages_batch_explicit_commit(reader: ydb.TopicReader): - for batch in reader.batches(max_messages=100, timeout=2): + while True: + batch = reader.receive_batch() for message in batch.messages: _process(message) reader.commit(batch) def process_messages_batch_context_manager_commit(reader: 
ydb.TopicReader): - for batch in reader.batches(max_messages=100, timeout=2): + while True: + batch = reader.receive_batch() with reader.commit_on_exit(batch): for message in batch.messages: _process(message) @@ -52,9 +56,12 @@ def get_message_with_timeout(reader: ydb.TopicReader): def get_all_messages_with_small_wait(reader: ydb.TopicReader): - for message in reader.messages(timeout=1): - _process(message) - print("Have no new messages in a second") + while True: + try: + message = reader.receive_message(timeout=1) + _process(message) + except TimeoutError: + print("Have no new messages in a second") def get_a_message_from_external_loop(reader: ydb.TopicReader): @@ -81,30 +88,23 @@ def get_one_batch_from_external_loop(reader: ydb.TopicReader): def auto_deserialize_message(db: ydb.Driver): # async, batch work similar to this - reader = ydb.TopicClient(db).reader( + reader = db.topic_client.reader( "/database/topic/path", consumer="asd", deserializer=json.loads ) - for message in reader.messages(): + while True: + message = reader.receive_message() print( message.data.Name ) # message.data replaces by json.loads(message.data) of raw message reader.commit(message) -def commit_batch_with_context(reader: ydb.TopicReader): - for batch in reader.batches(): - with reader.commit_on_exit(batch): - for message in batch.messages: - if not batch.is_alive: - break - _process(message) - - def handle_partition_stop(reader: ydb.TopicReader): - for message in reader.messages(): - time.sleep(1) # some work + while True: + message = reader.receive_message() + time.sleep(123) # some work if message.is_alive: - time.sleep(123) # some other work + time.sleep(1) # some other work reader.commit(message) @@ -118,25 +118,28 @@ def process_batch(batch): _process(message) reader.commit(batch) - for batch in reader.batches(): + while True: + batch = reader.receive_batch() process_batch(batch) def connect_and_read_few_topics(db: ydb.Driver): - with ydb.TopicClient(db).reader( + with db.topic_client.reader( [ "/database/topic/path", ydb.TopicSelector("/database/second-topic", partitions=3), ] ) as reader: - for message in reader: + while True: + message = reader.receive_message() _process(message) reader.commit(message) def handle_partition_graceful_stop_batch(reader: ydb.TopicReader): # no special handle, but batch will contain less than prefer count messages - for batch in reader.batches(): + while True: + batch = reader.receive_batch() _process(batch) reader.commit(batch) @@ -146,10 +149,11 @@ def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: print(event.topic) print(event.offset) - with ydb.TopicClient(db).reader( + with db.topic_client.reader( "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit ) as reader: - for message in reader: + while True: + message = reader.receive_message() with reader.commit_on_exit(message): _process(message) @@ -164,12 +168,13 @@ def on_get_partition_start_offset( resp.start_offset = 123 return resp - with ydb.TopicClient(db).reader( + with db.topic_client.reader( "/local/test", consumer="consumer", on_get_partition_start_offset=on_get_partition_start_offset, ) as reader: - for mess in reader: + while True: + mess = reader.receive_message() _process(mess) # save progress to own database diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index c5144685..28a17f52 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -1,6 +1,6 @@ import asyncio +import datetime import json 
-import time from typing import Dict, List import ydb @@ -8,7 +8,7 @@ async def create_writer(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).writer( + async with db.topic_client.writer( "/database/topic/path", producer_id="producer-id", ) as writer: @@ -16,15 +16,16 @@ async def create_writer(db: ydb.aio.Driver): async def connect_and_wait(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).writer( + async with db.topic_client.writer( "/database/topic/path", producer_id="producer-id", ) as writer: - writer.wait_init() + info = await writer.wait_init() # noqa + ... async def connect_without_context_manager(db: ydb.aio.Driver): - writer = ydb.TopicClientAsyncIO(db).writer( + writer = db.topic_client.writer( "/database/topic/path", producer_id="producer-id", ) @@ -49,7 +50,7 @@ async def send_messages(writer: ydb.TopicWriterAsyncIO): # with meta await writer.write( - ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns()) + ydb.TopicWriterMessage("asd", seqno=123, created_at=datetime.datetime.now()) ) @@ -71,7 +72,7 @@ async def send_messages_with_manual_seqno(writer: ydb.TopicWriter): async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): # future wait - await writer.write_with_result( + await writer.write_with_ack( [ ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2), @@ -84,10 +85,10 @@ async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): async def send_json_message(db: ydb.aio.Driver): - async with ydb.TopicClientAsyncIO(db).writer( + async with db.topic_client.writer( "/database/path/topic", serializer=json.dumps ) as writer: - writer.write({"a": 123}) + await writer.write({"a": 123}) async def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriterAsyncIO): @@ -99,14 +100,11 @@ async def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriterAs async def send_messages_and_wait_all_commit_with_results( writer: ydb.TopicWriterAsyncIO, ): - last_future = None for i in range(10): content = "%s" % i - last_future = await writer.write_with_ack(content) + await writer.write(content) - await asyncio.wait(last_future) - if last_future.exception() is not None: - raise last_future.exception() + await writer.flush() async def switch_messages_with_many_producers( @@ -118,7 +116,7 @@ async def switch_messages_with_many_producers( # select writer for the msg writer_idx = msg[:1] writer = writers[writer_idx] - future = await writer.write_with_ack(msg) + future = await writer.write_with_ack_future(msg) futures.append(future) # wait acks from all writes diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 1465dba5..63f6a108 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -1,6 +1,6 @@ import concurrent.futures +import datetime import json -import time from typing import Dict, List from concurrent.futures import Future, wait @@ -8,20 +8,20 @@ from ydb import TopicWriterMessage -async def connect(): - db = ydb.aio.Driver( +def connect(): + db = ydb.Driver( connection_string="grpc://localhost:2135?database=/local", credentials=ydb.credentials.AnonymousCredentials(), ) - writer = ydb.TopicClientAsyncIO(db).writer( + writer = db.topic_client.writer( "/local/topic", producer_id="producer-id", ) - await writer.write(TopicWriterMessage("asd")) + writer.write(TopicWriterMessage("asd")) def create_writer(db: ydb.Driver): - with ydb.TopicClient(db).writer( + with db.topic_client.writer( "/database/topic/path", 
producer_id="producer-id", ) as writer: @@ -29,15 +29,15 @@ def create_writer(db: ydb.Driver): def connect_and_wait(db: ydb.Driver): - with ydb.TopicClient(db).writer( + with db.topic_client.writer( "/database/topic/path", producer_id="producer-id", ) as writer: - writer.wait() + info = writer.wait_init() # noqa def connect_without_context_manager(db: ydb.Driver): - writer = ydb.TopicClient(db).writer( + writer = db.topic_client.writer( "/database/topic/path", producer_id="producer-id", ) @@ -61,7 +61,9 @@ def send_messages(writer: ydb.TopicWriter): ) # send few messages by one call # with meta - writer.write(ydb.TopicWriterMessage("asd", seqno=123, created_at_ns=time.time_ns())) + writer.write( + ydb.TopicWriterMessage("asd", seqno=123, created_at=datetime.datetime.now()) + ) def send_message_without_block_if_internal_buffer_is_full( @@ -101,7 +103,7 @@ def send_messages_with_wait_ack(writer: ydb.TopicWriter): def send_json_message(db: ydb.Driver): - with ydb.TopicClient(db).writer( + with db.topic_client.writer( "/database/path/topic", serializer=json.dumps ) as writer: writer.write({"a": 123}) From a490c27ec6f381f371981ac60980d1c6ff273ec3 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Mar 2023 10:32:25 +0300 Subject: [PATCH 143/147] clean public interface --- examples/topic/reader_async_example.py | 69 +------------- examples/topic/reader_example.py | 89 +------------------ examples/topic/writer_async_example.py | 13 --- examples/topic/writer_example.py | 13 +-- ydb/_topic_reader/datatypes.py | 25 +++--- ydb/_topic_reader/topic_reader.py | 13 --- ydb/_topic_reader/topic_reader_asyncio.py | 55 ------------ .../topic_reader_asyncio_test.py | 18 ---- ydb/_topic_reader/topic_reader_sync.py | 70 ++------------- ydb/_topic_writer/topic_writer.py | 7 -- ydb/_topic_writer/topic_writer_asyncio.py | 4 - ydb/topic.py | 24 ----- 12 files changed, 25 insertions(+), 375 deletions(-) diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py index 41825ab4..848aeed2 100644 --- a/examples/topic/reader_async_example.py +++ b/examples/topic/reader_async_example.py @@ -1,5 +1,4 @@ import asyncio -import json import time import ydb @@ -77,25 +76,11 @@ async def get_one_batch_from_external_loop_async(reader: ydb.TopicReaderAsyncIO) await reader.commit(batch) -async def auto_deserialize_message(db: ydb.aio.Driver): - # async, batch work similar to this - - async with db.topic_client.reader( - "/database/topic/path", consumer="asd", deserializer=json.loads - ) as reader: - while True: - message = await reader.receive_message() - print( - message.data.Name - ) # message.data replaces by json.loads(message.data) of raw message - reader.commit(message) - - async def handle_partition_stop(reader: ydb.TopicReaderAsyncIO): while True: message = await reader.receive_message() time.sleep(123) # some work - if message.is_alive: + if message.alive: time.sleep(1) # some other work await reader.commit(message) @@ -103,7 +88,7 @@ async def handle_partition_stop(reader: ydb.TopicReaderAsyncIO): async def handle_partition_stop_batch(reader: ydb.TopicReaderAsyncIO): def process_batch(batch): for message in batch.messages: - if not batch.is_alive: + if not batch.alive: # no reason work with expired batch # go read next - good batch return @@ -115,55 +100,5 @@ def process_batch(batch): process_batch(batch) -async def connect_and_read_few_topics(db: ydb.aio.Driver): - with db.topic_client.reader( - [ - "/database/topic/path", - ydb.TopicSelector("/database/second-topic", 
partitions=3), - ] - ) as reader: - while True: - message = await reader.receive_message() - await _process(message) - await reader.commit(message) - - -async def advanced_commit_notify(db: ydb.aio.Driver): - def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: - print(event.topic) - print(event.offset) - - async with db.topic_client.reader( - "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit - ) as reader: - while True: - message = await reader.receive_message() - await _process(message) - await reader.commit(message) - - -async def advanced_read_with_own_progress_storage(db: ydb.TopicReaderAsyncIO): - async def on_get_partition_start_offset( - req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest, - ) -> ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse: - # read current progress from database - resp = ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse() - resp.start_offset = 123 - return resp - - async with db.topic_client.reader( - "/local/test", - consumer="consumer", - on_get_partition_start_offset=on_get_partition_start_offset, - ) as reader: - while True: - mess = reader.receive_message() - await _process(mess) - # save progress to own database - - # no commit progress to topic service - # reader.commit(mess) - - async def _process(msg): raise NotImplementedError() diff --git a/examples/topic/reader_example.py b/examples/topic/reader_example.py index 8de33c7e..f1c4eb73 100644 --- a/examples/topic/reader_example.py +++ b/examples/topic/reader_example.py @@ -1,4 +1,3 @@ -import json import time import ydb @@ -37,14 +36,6 @@ def process_messages_batch_explicit_commit(reader: ydb.TopicReader): reader.commit(batch) -def process_messages_batch_context_manager_commit(reader: ydb.TopicReader): - while True: - batch = reader.receive_batch() - with reader.commit_on_exit(batch): - for message in batch.messages: - _process(message) - - def get_message_with_timeout(reader: ydb.TopicReader): try: message = reader.receive_message(timeout=1) @@ -85,25 +76,11 @@ def get_one_batch_from_external_loop(reader: ydb.TopicReader): reader.commit(batch) -def auto_deserialize_message(db: ydb.Driver): - # async, batch work similar to this - - reader = db.topic_client.reader( - "/database/topic/path", consumer="asd", deserializer=json.loads - ) - while True: - message = reader.receive_message() - print( - message.data.Name - ) # message.data replaces by json.loads(message.data) of raw message - reader.commit(message) - - def handle_partition_stop(reader: ydb.TopicReader): while True: message = reader.receive_message() time.sleep(123) # some work - if message.is_alive: + if message.alive: time.sleep(1) # some other work reader.commit(message) @@ -111,7 +88,7 @@ def handle_partition_stop(reader: ydb.TopicReader): def handle_partition_stop_batch(reader: ydb.TopicReader): def process_batch(batch): for message in batch.messages: - if not batch.is_alive: + if not batch.alive: # no reason work with expired batch # go read next - good batch return @@ -123,19 +100,6 @@ def process_batch(batch): process_batch(batch) -def connect_and_read_few_topics(db: ydb.Driver): - with db.topic_client.reader( - [ - "/database/topic/path", - ydb.TopicSelector("/database/second-topic", partitions=3), - ] - ) as reader: - while True: - message = reader.receive_message() - _process(message) - reader.commit(message) - - def handle_partition_graceful_stop_batch(reader: ydb.TopicReader): # no special handle, but batch will contain less than prefer count messages while True: @@ -144,54 +108,5 
@@ def handle_partition_graceful_stop_batch(reader: ydb.TopicReader): reader.commit(batch) -def advanced_commit_notify(db: ydb.Driver): - def on_commit(event: ydb.TopicReaderEvents.OnCommit) -> None: - print(event.topic) - print(event.offset) - - with db.topic_client.reader( - "/local", consumer="consumer", commit_batch_time=4, on_commit=on_commit - ) as reader: - while True: - message = reader.receive_message() - with reader.commit_on_exit(message): - _process(message) - - -def advanced_read_with_own_progress_storage(db: ydb.TopicReader): - def on_get_partition_start_offset( - req: ydb.TopicReaderEvents.OnPartitionGetStartOffsetRequest, - ) -> ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse: - - # read current progress from database - resp = ydb.TopicReaderEvents.OnPartitionGetStartOffsetResponse() - resp.start_offset = 123 - return resp - - with db.topic_client.reader( - "/local/test", - consumer="consumer", - on_get_partition_start_offset=on_get_partition_start_offset, - ) as reader: - while True: - mess = reader.receive_message() - _process(mess) - # save progress to own database - - # no commit progress to topic service - # reader.commit(mess) - - -def get_current_statistics(reader: ydb.TopicReader): - # sync - stat = reader.sessions_stat() - print(stat) - - # with feature - f = reader.async_sessions_stat() - stat = f.result() - print(stat) - - def _process(msg): raise NotImplementedError() diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py index 28a17f52..4d11a86d 100644 --- a/examples/topic/writer_async_example.py +++ b/examples/topic/writer_async_example.py @@ -1,6 +1,5 @@ import asyncio import datetime -import json from typing import Dict, List import ydb @@ -84,13 +83,6 @@ async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): await writer.flush() -async def send_json_message(db: ydb.aio.Driver): - async with db.topic_client.writer( - "/database/path/topic", serializer=json.dumps - ) as writer: - await writer.write({"a": 123}) - - async def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriterAsyncIO): for i in range(10): await writer.write(ydb.TopicWriterMessage("%s" % i)) @@ -127,8 +119,3 @@ async def switch_messages_with_many_producers( # all ok, explicit return - for better return - - -async def get_current_statistics(reader: ydb.TopicReaderAsyncIO): - stat = await reader.sessions_stat() - print(stat) diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 63f6a108..14b401e0 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -1,6 +1,5 @@ import concurrent.futures import datetime -import json from typing import Dict, List from concurrent.futures import Future, wait @@ -85,7 +84,10 @@ def send_messages_with_manual_seqno(writer: ydb.TopicWriter): def send_messages_with_wait_ack(writer: ydb.TopicWriter): # Explicit future wait writer.async_write_with_ack( - ydb.TopicWriterMessage("mess", seqno=1), ydb.TopicWriterMessage("mess", seqno=2) + [ + ydb.TopicWriterMessage("mess", seqno=1), + ydb.TopicWriterMessage("mess", seqno=2), + ] ).result() # implicit, by sync call @@ -102,13 +104,6 @@ def send_messages_with_wait_ack(writer: ydb.TopicWriter): writer.flush() -def send_json_message(db: ydb.Driver): - with db.topic_client.writer( - "/database/path/topic", serializer=json.dumps - ) as writer: - writer.write({"a": 123}) - - def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriter): for i in range(10): content = "%s" % i 
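The example diffs above drop the iterator-style helpers (`messages()`, `batches()`, `commit_on_exit`, the serializer/deserializer hooks and `sessions_stat`) in favour of an explicit `receive_message()`/`receive_batch()` loop, and the datatypes diff below replaces `is_alive` with the `alive` property. A minimal consumer sketch against the API that survives this cleanup — the connection string, topic path and consumer name are placeholders, not part of the patch:

    import asyncio
    import ydb

    async def consume():
        # Placeholder endpoint/database; any running YDB with the topic created works.
        db = ydb.aio.Driver(
            connection_string="grpc://localhost:2136?database=/local",
            credentials=ydb.credentials.AnonymousCredentials(),
        )
        await db.wait(timeout=5)
        async with db.topic_client.reader("/local/topic", consumer="consumer") as reader:
            while True:
                batch = await reader.receive_batch()
                if not batch.alive:
                    continue  # partition session was stopped, skip the expired batch
                for message in batch.messages:
                    print(message.data.decode("utf-8"))
                await reader.commit_with_ack(batch)

    asyncio.run(consume())

Fire-and-forget `reader.commit(batch)` remains available when the caller does not need the server acknowledgement that `commit_with_ack()` waits for.
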
diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 5376c76d..434eff35 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -7,7 +7,7 @@ from collections import deque from dataclasses import dataclass, field import datetime -from typing import Mapping, Union, Any, List, Dict, Deque, Optional +from typing import Union, Any, List, Dict, Deque, Optional from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange, Codec from ydb._topic_reader import topic_reader_asyncio @@ -26,7 +26,7 @@ def _commit_get_offsets_range(self) -> OffsetsRange: class ISessionAlive(abc.ABC): @property @abc.abstractmethod - def is_alive(self) -> bool: + def alive(self) -> bool: pass @@ -54,8 +54,8 @@ def _commit_get_offsets_range(self) -> OffsetsRange: # ISessionAlive implementation @property - def is_alive(self) -> bool: - raise NotImplementedError() + def alive(self) -> bool: + return not self._partition_session.closed @dataclass @@ -127,9 +127,7 @@ def ack_notify(self, offset: int): break def close(self): - try: - self._ensure_not_closed() - except topic_reader_asyncio.TopicReaderCommitToExpiredPartition: + if self.closed: return self.state = PartitionSession.State.Stopped @@ -137,6 +135,10 @@ def close(self): for waiter in self._ack_waiters: waiter._finish_error(exception) + @property + def closed(self): + return self.state == PartitionSession.State.Stopped + def _ensure_not_closed(self): if self.state == PartitionSession.State.Stopped: raise topic_reader_asyncio.TopicReaderCommitToExpiredPartition() @@ -164,7 +166,6 @@ def _finish_error(self, error: Exception): @dataclass class PublicBatch(ICommittable, ISessionAlive): - session_metadata: Mapping[str, str] messages: List[PublicMessage] _partition_session: PartitionSession _bytes_size: int @@ -184,12 +185,8 @@ def empty(self) -> bool: # ISessionAlive implementation @property - def is_alive(self) -> bool: - state = self._partition_session.state - return ( - state == PartitionSession.State.Active - or state == PartitionSession.State.GracefulShutdown - ) + def alive(self) -> bool: + return not self._partition_session.closed def pop_message(self) -> PublicMessage: return self.messages.pop(0) diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py index 148d63b3..b3d5637d 100644 --- a/ydb/_topic_reader/topic_reader.py +++ b/ydb/_topic_reader/topic_reader.py @@ -36,19 +36,6 @@ class PublicReaderSettings: # decoder_executor, must be set for handle non raw messages decoder_executor: Optional[concurrent.futures.Executor] = None - - # on_commit: Callable[["Events.OnCommit"], None] = None - # on_get_partition_start_offset: Callable[ - # ["Events.OnPartitionGetStartOffsetRequest"], - # "Events.OnPartitionGetStartOffsetResponse", - # ] = None - # on_partition_session_start: Callable[["StubEvent"], None] = None - # on_partition_session_stop: Callable[["StubEvent"], None] = None - # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? 
- # deserializer: Union[Callable[[bytes], Any], None] = None - # one_attempt_connection_timeout: Union[float, None] = 1 - # connection_timeout: Union[float, None] = None - # retry_policy: Union["RetryPolicy", None] = None update_token_interval: Union[int, float] = 3600 def _init_message(self) -> StreamReadMessage.InitRequest: diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 0068e4ba..9eda2fbf 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -77,44 +77,8 @@ def __del__(self): if not self._closed: self._loop.create_task(self.close(), name="close reader") - async def sessions_stat(self) -> typing.List["topic_reader.SessionStat"]: - """ - Receive stat from the server - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - - def messages( - self, *, timeout: typing.Union[float, None] = None - ) -> typing.AsyncIterable[topic_reader.PublicMessage]: - """ - Block until receive new message - - if no new messages in timeout seconds: stop iteration by raise StopAsyncIteration - """ - raise NotImplementedError() - - def batches( - self, - *, - max_messages: typing.Union[int, None] = None, - max_bytes: typing.Union[int, None] = None, - timeout: typing.Union[float, None] = None, - ) -> typing.AsyncIterable[datatypes.PublicBatch]: - """ - Block until receive new batch. - All messages in a batch from same partition. - - if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - """ - raise NotImplementedError() - async def receive_batch( self, - *, - max_messages: typing.Union[int, None] = None, - max_bytes: typing.Union[int, None] = None, ) -> typing.Union[datatypes.PublicBatch, None]: """ Get one messages batch from reader. @@ -134,16 +98,6 @@ async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: await self._reconnector.wait_message() return self._reconnector.receive_message_nowait() - async def commit_on_exit( - self, mess: datatypes.ICommittable - ) -> typing.AsyncContextManager: - """ - commit the mess match/message if exit from context manager without exceptions - - reader will close if exit from context manager with exception - """ - raise NotImplementedError() - def commit( self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] ): @@ -166,14 +120,6 @@ async def commit_with_ack( waiter = self._reconnector.commit(batch) await waiter.future - async def flush(self): - """ - force send all commit messages from internal buffers to server and wait acks for all of them. - - use asyncio.wait_for for wait with timeout. 
- """ - raise NotImplementedError() - async def close(self): if self._closed: raise TopicReaderClosedError() @@ -642,7 +588,6 @@ def _read_response_to_batches( if messages: batch = datatypes.PublicBatch( - session_metadata=server_batch.write_session_meta, messages=messages, _partition_session=partition_session, _bytes_size=bytes_per_batch, diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index a310298e..0134f38b 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -413,7 +413,6 @@ async def test_commit_ranges_for_received_messages( [ ( PublicBatch( - session_metadata={}, messages=[ PublicMessage( seqno=1, @@ -437,7 +436,6 @@ async def test_commit_ranges_for_received_messages( ), ( PublicBatch( - session_metadata={}, messages=[ PublicMessage( seqno=1, @@ -461,7 +459,6 @@ async def test_commit_ranges_for_received_messages( ), ( PublicBatch( - session_metadata={}, messages=[ PublicMessage( seqno=1, @@ -498,7 +495,6 @@ async def test_commit_ranges_for_received_messages( ), ( PublicBatch( - session_metadata={}, messages=[ PublicMessage( seqno=1, @@ -794,7 +790,6 @@ def reader_batch_count(): last_batch = stream_reader._message_batches[-1] assert last_batch == PublicBatch( - session_metadata=session_meta, messages=[ PublicMessage( seqno=2, @@ -918,7 +913,6 @@ async def test_read_batches( last2 = batches[2] assert last0 == PublicBatch( - session_metadata=session_meta, messages=[ PublicMessage( seqno=3, @@ -939,7 +933,6 @@ async def test_read_batches( _codec=Codec.CODEC_RAW, ) assert last1 == PublicBatch( - session_metadata=session_meta, messages=[ PublicMessage( seqno=2, @@ -960,7 +953,6 @@ async def test_read_batches( _codec=Codec.CODEC_RAW, ) assert last2 == PublicBatch( - session_metadata=session_meta2, messages=[ PublicMessage( seqno=3, @@ -1001,7 +993,6 @@ async def test_read_batches( ( [ PublicBatch( - session_metadata={}, messages=[stub_message(1)], _partition_session=stub_partition_session(), _bytes_size=0, @@ -1014,14 +1005,12 @@ async def test_read_batches( ( [ PublicBatch( - session_metadata={}, messages=[stub_message(1), stub_message(2)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, ), PublicBatch( - session_metadata={}, messages=[stub_message(3), stub_message(4)], _partition_session=stub_partition_session(), _bytes_size=0, @@ -1031,14 +1020,12 @@ async def test_read_batches( stub_message(1), [ PublicBatch( - session_metadata={}, messages=[stub_message(2)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, ), PublicBatch( - session_metadata={}, messages=[stub_message(3), stub_message(4)], _partition_session=stub_partition_session(), _bytes_size=0, @@ -1049,14 +1036,12 @@ async def test_read_batches( ( [ PublicBatch( - session_metadata={}, messages=[stub_message(1)], _partition_session=stub_partition_session(), _bytes_size=0, _codec=Codec.CODEC_RAW, ), PublicBatch( - session_metadata={}, messages=[stub_message(2), stub_message(3)], _partition_session=stub_partition_session(), _bytes_size=0, @@ -1066,7 +1051,6 @@ async def test_read_batches( stub_message(1), [ PublicBatch( - session_metadata={}, messages=[stub_message(2), stub_message(3)], _partition_session=stub_partition_session(), _bytes_size=0, @@ -1102,7 +1086,6 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi received = stream_reader.receive_batch_nowait() assert received == PublicBatch( - 
session_metadata=mess1.session_metadata, messages=[mess1], _partition_session=mess1._partition_session, _bytes_size=self.default_batch_size, @@ -1111,7 +1094,6 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi received = stream_reader.receive_batch_nowait() assert received == PublicBatch( - mess2.session_metadata, messages=[mess2], _partition_session=mess2._partition_session, _bytes_size=self.default_batch_size, diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index ed9730fa..cea7e36c 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -1,7 +1,7 @@ import asyncio import concurrent.futures import typing -from typing import List, Union, Iterable, Optional +from typing import List, Union, Optional from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType from ydb._topic_common.common import ( @@ -10,10 +10,9 @@ TimeoutType, ) from ydb._topic_reader import datatypes -from ydb._topic_reader.datatypes import PublicMessage, PublicBatch, ICommittable +from ydb._topic_reader.datatypes import PublicBatch from ydb._topic_reader.topic_reader import ( PublicReaderSettings, - SessionStat, CommitResult, ) from ydb._topic_reader.topic_reader_asyncio import ( @@ -59,35 +58,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def async_sessions_stat(self) -> concurrent.futures.Future: - """ - Receive stat from the server, return feature. - """ - raise NotImplementedError() - - async def sessions_stat(self) -> List[SessionStat]: - """ - Receive stat from the server - - use async_sessions_stat for set explicit wait timeout - """ - raise NotImplementedError() - - def messages( - self, *, timeout: Union[float, None] = None - ) -> Iterable[PublicMessage]: - """ - todo? - - Block until receive new message - It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. - - if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, - get messages from internal buffer only. - """ - raise NotImplementedError() - def receive_message( self, *, timeout: TimeoutType = None ) -> datatypes.PublicMessage: @@ -120,22 +90,6 @@ def async_wait_message(self) -> concurrent.futures.Future: self._async_reader._reconnector.wait_message() ) - def batches( - self, - *, - max_messages: Union[int, None] = None, - max_bytes: Union[int, None] = None, - timeout: Union[float, None] = None, - ) -> Iterable[PublicBatch]: - """ - Block until receive new batch. - It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. - - if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. 
- """ - raise NotImplementedError() - def receive_batch( self, *, @@ -153,9 +107,7 @@ def receive_batch( self._check_closed() return self._caller.safe_call_with_result( - self._async_reader.receive_batch( - max_messages=max_messages, max_bytes=max_bytes - ), + self._async_reader.receive_batch(), timeout, ) @@ -173,7 +125,9 @@ def commit( self._caller.call_sync(self._async_reader.commit(mess)) def commit_with_ack( - self, mess: ICommittable, timeout: TimeoutType = None + self, + mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch], + timeout: TimeoutType = None, ) -> Union[CommitResult, List[CommitResult]]: """ write commit message to a buffer and wait ack from the server. @@ -198,18 +152,6 @@ def async_commit_with_ack( self._async_reader.commit_with_ack(mess) ) - def async_flush(self) -> concurrent.futures.Future: - """ - force send all commit messages from internal buffers to server and return Future for wait server acks. - """ - raise NotImplementedError() - - def flush(self): - """ - force send all commit messages from internal buffers to server and wait acks for all of them. - """ - raise NotImplementedError() - def close(self, *, timeout: TimeoutType = None): if self._closed: return diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 59ad74ff..b94ff46b 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -37,13 +37,6 @@ class PublicWriterSettings: encoders: Optional[ typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]] ] = None - # get_last_seqno: bool = False - # serializer: Union[Callable[[Any], bytes], None] = None - # send_buffer_count: Optional[int] = 10000 - # send_buffer_bytes: Optional[int] = 100 * 1024 * 1024 - # codec: Optional[int] = None - # codec_autoselect: bool = True - # retry_policy: Optional["RetryPolicy"] = None update_token_interval: Union[int, float] = 3600 def __post_init__(self): diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 7cb1f1db..666fc11b 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -44,10 +44,6 @@ class WriterAsyncIO: _reconnector: "WriterAsyncIOReconnector" _closed: bool - @property - def last_seqno(self) -> int: - raise NotImplementedError() - def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings): self._loop = asyncio.get_running_loop() self._closed = False diff --git a/ydb/topic.py b/ydb/topic.py index ae6b5a5b..b7a04dcc 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -145,18 +145,6 @@ def reader( decoder_executor: Optional[ concurrent.futures.Executor ] = None, # default shared client executor pool - # on_commit: Callable[["Events.OnCommit"], None] = None - # on_get_partition_start_offset: Callable[ - # ["Events.OnPartitionGetStartOffsetRequest"], - # "Events.OnPartitionGetStartOffsetResponse", - # ] = None - # on_partition_session_start: Callable[["StubEvent"], None] = None - # on_partition_session_stop: Callable[["StubEvent"], None] = None - # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? 
- # deserializer: Union[Callable[[bytes], Any], None] = None - # one_attempt_connection_timeout: Union[float, None] = 1 - # connection_timeout: Union[float, None] = None - # retry_policy: Union["RetryPolicy", None] = None ) -> TopicReaderAsyncIO: if not decoder_executor: @@ -314,18 +302,6 @@ def reader( decoder_executor: Optional[ concurrent.futures.Executor ] = None, # default shared client executor pool - # on_commit: Callable[["Events.OnCommit"], None] = None - # on_get_partition_start_offset: Callable[ - # ["Events.OnPartitionGetStartOffsetRequest"], - # "Events.OnPartitionGetStartOffsetResponse", - # ] = None - # on_partition_session_start: Callable[["StubEvent"], None] = None - # on_partition_session_stop: Callable[["StubEvent"], None] = None - # on_partition_session_close: Callable[["StubEvent"], None] = None # todo? - # deserializer: Union[Callable[[bytes], Any], None] = None - # one_attempt_connection_timeout: Union[float, None] = 1 - # connection_timeout: Union[float, None] = None - # retry_policy: Union["RetryPolicy", None] = None ) -> TopicReader: if not decoder_executor: decoder_executor = self._executor From fb39f2d1215f14ec2c1894d77838f16183bed8c8 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Mar 2023 11:18:59 +0300 Subject: [PATCH 144/147] add __all__ to topic module --- ydb/topic.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/ydb/topic.py b/ydb/topic.py index b7a04dcc..4a4cc05d 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,5 +1,24 @@ from __future__ import annotations +__all__ = [ + 'TopicClient', + 'TopicClientAsyncIO', + 'TopicClientSettings', + 'TopicCodec', + 'TopicConsumer', + 'TopicDescription', + 'TopicError', + 'TopicMeteringMode', + 'TopicReader', + 'TopicReaderAsyncIO', + 'TopicReaderSettings', + 'TopicStatWindow', + 'TopicWriter', + 'TopicWriterAsyncIO', + 'TopicWriterMessage', + 'TopicWriterSettings', +] + import concurrent.futures import datetime from dataclasses import dataclass From abb6876ad9589eee35ed8a210ddca0e38561fa5a Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Mar 2023 11:20:53 +0300 Subject: [PATCH 145/147] fix linter --- ydb/topic.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ydb/topic.py b/ydb/topic.py index 4a4cc05d..7dde70ff 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,22 +1,22 @@ from __future__ import annotations __all__ = [ - 'TopicClient', - 'TopicClientAsyncIO', - 'TopicClientSettings', - 'TopicCodec', - 'TopicConsumer', - 'TopicDescription', - 'TopicError', - 'TopicMeteringMode', - 'TopicReader', - 'TopicReaderAsyncIO', - 'TopicReaderSettings', - 'TopicStatWindow', - 'TopicWriter', - 'TopicWriterAsyncIO', - 'TopicWriterMessage', - 'TopicWriterSettings', + "TopicClient", + "TopicClientAsyncIO", + "TopicClientSettings", + "TopicCodec", + "TopicConsumer", + "TopicDescription", + "TopicError", + "TopicMeteringMode", + "TopicReader", + "TopicReaderAsyncIO", + "TopicReaderSettings", + "TopicStatWindow", + "TopicWriter", + "TopicWriterAsyncIO", + "TopicWriterMessage", + "TopicWriterSettings", ] import concurrent.futures From b4d102a0137dd89db9cbc1af104eacbb5ac88e9c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Mar 2023 17:42:13 +0300 Subject: [PATCH 146/147] Update docker-compose-tls.yml --- docker-compose-tls.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose-tls.yml b/docker-compose-tls.yml index c9d6fac9..83403ea1 100644 --- a/docker-compose-tls.yml +++ 
b/docker-compose-tls.yml @@ -1,7 +1,7 @@ version: "3.9" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:latest + image: cr.yandex/yc/yandex-docker-local-ydb@sha256:b569c23d6854564ec4d970bda86cddcf5b11c7c6362df62beb8ba8eafb8d54fd restart: always ports: - 2136:2136 From 12a4b7178208dfc1d2094de223a275f52b415bb7 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 22 Mar 2023 17:42:24 +0300 Subject: [PATCH 147/147] Update docker-compose.yml --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index d8b898ae..3223033b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.9" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:latest + image: cr.yandex/yc/yandex-docker-local-ydb@sha256:b569c23d6854564ec4d970bda86cddcf5b11c7c6362df62beb8ba8eafb8d54fd restart: always ports: - 2136:2136
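
With the compose files pinned and the 3.0 beta topic API settled, an end-to-end smoke test of the series — write one message, read it back — looks roughly like the sketch below. The endpoint matches the 2136 port mapping in the compose file above, while the topic path, producer id and consumer name are assumptions (the topic and consumer are presumed pre-created, e.g. via `topic_client.create_topic`):

    import asyncio
    import ydb

    async def main():
        db = ydb.aio.Driver(
            connection_string="grpc://localhost:2136?database=/local",
            credentials=ydb.credentials.AnonymousCredentials(),
        )
        await db.wait(timeout=15)

        # Write a single message and make sure it reached the server.
        async with db.topic_client.writer(
            "/local/topic", producer_id="producer-id"
        ) as writer:
            await writer.write(ydb.TopicWriterMessage("hello"))
            await writer.flush()

        # Read it back through the pre-created consumer.
        async with db.topic_client.reader("/local/topic", consumer="consumer") as reader:
            message = await asyncio.wait_for(reader.receive_message(), timeout=10)
            print(message.data.decode("utf-8"))

        await db.stop()

    asyncio.run(main())

Pinning the container image by digest, as the last two patches do, keeps such smoke tests reproducible when the `latest` tag moves.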