From d41b6f417cb2529f6785dc696387b47e792c3551 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Feb 2023 18:55:24 +0300 Subject: [PATCH 1/8] add commit for reader --- .github/workflows/style.yaml | 1 + tests/conftest.py | 4 + tests/topics/test_topic_reader.py | 23 + ydb/_grpc/grpcwrapper/ydb_topic.py | 112 ++++- ydb/_grpc/grpcwrapper/ydb_topic_test.py | 27 ++ ydb/_topic_common/test_helpers.py | 27 +- ydb/_topic_reader/datatypes.py | 183 +++++++- ydb/_topic_reader/datatypes_test.py | 315 ++++++++++++++ ydb/_topic_reader/topic_reader_asyncio.py | 180 ++++++-- .../topic_reader_asyncio_test.py | 408 ++++++++++++++---- ydb/_topic_reader/topic_reader_sync.py | 16 +- ydb/_topic_writer/topic_writer_asyncio.py | 9 +- ydb/_utilities.py | 15 + 13 files changed, 1153 insertions(+), 167 deletions(-) create mode 100644 ydb/_grpc/grpcwrapper/ydb_topic_test.py create mode 100644 ydb/_topic_reader/datatypes_test.py diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index 8723d8f2..c280042b 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -2,6 +2,7 @@ name: Style checks on: push: + - main pull_request: jobs: diff --git a/tests/conftest.py b/tests/conftest.py index 6fa1f174..62f486cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -136,6 +136,10 @@ async def topic_with_messages(driver, topic_path): ydb.TopicWriterMessage(data="123".encode()), ydb.TopicWriterMessage(data="456".encode()), ) + await writer.write_with_ack( + ydb.TopicWriterMessage(data="789".encode()), + ydb.TopicWriterMessage(data="0".encode()), + ) await writer.close() diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 21675eb2..734b64c7 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -11,6 +11,18 @@ async def test_read_message( assert await reader.receive_batch() is not None await reader.close() + async def test_read_and_commit_message( + self, driver, topic_path, topic_with_messages, topic_consumer + ): + + reader = driver.topic_client.topic_reader(topic_consumer, topic_path) + batch = await reader.receive_batch() + await reader.commit_with_ack(batch) + + reader = driver.topic_client.topic_reader(topic_consumer, topic_path) + batch2 = await reader.receive_batch() + assert batch.messages[0] != batch2.messages[0] + class TestTopicReaderSync: def test_read_message( @@ -20,3 +32,14 @@ def test_read_message( assert reader.receive_batch() is not None reader.close() + + def test_read_and_commit_message( + self, driver_sync, topic_path, topic_with_messages, topic_consumer + ): + reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + batch = reader.receive_batch() + reader.commit_with_ack(batch) + + reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + batch2 = reader.receive_batch() + assert batch.messages[0] != batch2.messages[0] diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index e6a5a8e3..ad8a8e72 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -67,10 +67,23 @@ def to_public(self) -> List[ydb_topic_public_types.PublicCodec]: return list(map(Codec.to_public, self.codecs)) -@dataclass -class OffsetsRange(IFromProto): - start: int - end: int +@dataclass(order=True) +class OffsetsRange(IFromProto, IToProto): + """ + half-opened interval, include [start, end) offsets + """ + + __slots__ = ("start", "end") + + start: int # first offset + end: int # offset after last, included to range + + def __post_init__(self): + if self.end < self.start: + raise ValueError( + "offset end must be not less then start. Got start=%s end=%s" + % (self.start, self.end) + ) @staticmethod def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange": @@ -79,6 +92,20 @@ def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange": end=msg.end, ) + def to_proto(self) -> ydb_topic_pb2.OffsetsRange: + return ydb_topic_pb2.OffsetsRange( + start=self.start, + end=self.end, + ) + + def is_intersected_with(self, other: "OffsetsRange") -> bool: + return ( + self.start <= other.start < self.end + or self.start < other.end <= self.end + or other.start <= self.start < other.end + or other.start < self.end <= other.end + ) + @dataclass class UpdateTokenRequest(IToProto): @@ -527,23 +554,67 @@ def from_proto( ) @dataclass - class CommitOffsetRequest: + class CommitOffsetRequest(IToProto): commit_offsets: List["PartitionCommitOffset"] + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest: + res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest( + commit_offsets=list( + map( + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset.to_proto, + self.commit_offsets, + ) + ), + ) + return res + @dataclass - class PartitionCommitOffset: + class PartitionCommitOffset(IToProto): partition_session_id: int offsets: List["OffsetsRange"] + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset: + res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=self.partition_session_id, + offsets=list(map(OffsetsRange.to_proto, self.offsets)), + ) + return res + @dataclass - class CommitOffsetResponse: - partitions_committed_offsets: List["PartitionCommittedOffset"] + class CommitOffsetResponse(IFromProto): + partitions_committed_offsets: List[ + "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset" + ] + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse, + ) -> "StreamReadMessage.CommitOffsetResponse": + return StreamReadMessage.CommitOffsetResponse( + partitions_committed_offsets=list( + map( + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset.from_proto, + msg.partitions_committed_offsets, + ) + ) + ) @dataclass - class PartitionCommittedOffset: + class PartitionCommittedOffset(IFromProto): partition_session_id: int committed_offset: int + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset, + ) -> "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset": + return StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=msg.partition_session_id, + committed_offset=msg.committed_offset, + ) + @dataclass class PartitionSessionStatusRequest: partition_session_id: int @@ -576,16 +647,18 @@ def from_proto( @dataclass class StartPartitionSessionResponse(IToProto): partition_session_id: int - read_offset: int - commit_offset: int + read_offset: Optional[int] + commit_offset: Optional[int] def to_proto( self, ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() res.partition_session_id = self.partition_session_id - res.read_offset = self.read_offset - res.commit_offset = self.commit_offset + if self.read_offset is not None: + res.read_offset = self.read_offset + if self.commit_offset is not None: + res.commit_offset = self.commit_offset return res @dataclass @@ -609,6 +682,8 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: res = ydb_topic_pb2.StreamReadMessage.FromClient() if isinstance(self.client_message, StreamReadMessage.ReadRequest): res.read_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.CommitOffsetRequest): + res.commit_offset_request.CopyFrom(self.client_message.to_proto()) elif isinstance(self.client_message, StreamReadMessage.InitRequest): res.init_request.CopyFrom(self.client_message.to_proto()) elif isinstance( @@ -618,7 +693,9 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: self.client_message.to_proto() ) else: - raise NotImplementedError() + raise NotImplementedError( + "Unknown message type: %s" % type(self.client_message) + ) return res @dataclass @@ -639,6 +716,13 @@ def from_proto( msg.read_response ), ) + elif mess_type == "commit_offset_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.CommitOffsetResponse.from_proto( + msg.commit_offset_response + ), + ) elif mess_type == "init_response": return StreamReadMessage.FromServer( server_status=server_status, diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_test.py b/ydb/_grpc/grpcwrapper/ydb_topic_test.py new file mode 100644 index 00000000..bff9b43d --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_topic_test.py @@ -0,0 +1,27 @@ +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange + + +def test_offsets_range_intersected(): + # not intersected + for test in [(0, 1, 1, 2), (1, 2, 3, 5)]: + assert not OffsetsRange(test[0], test[1]).is_intersected_with( + OffsetsRange(test[2], test[3]) + ) + assert not OffsetsRange(test[2], test[3]).is_intersected_with( + OffsetsRange(test[0], test[1]) + ) + + # intersected + for test in [ + (1, 2, 1, 2), + (1, 10, 1, 2), + (1, 10, 2, 3), + (1, 10, 5, 15), + (10, 20, 5, 15), + ]: + assert OffsetsRange(test[0], test[1]).is_intersected_with( + OffsetsRange(test[2], test[3]) + ) + assert OffsetsRange(test[2], test[3]).is_intersected_with( + OffsetsRange(test[0], test[1]) + ) diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index 9023f759..60166d0d 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -39,7 +39,21 @@ def close(self): self.from_server.put_nowait(None) -async def wait_condition(f: typing.Callable[[], bool], timeout=1): +class WaitConditionException(Exception): + pass + + +async def wait_condition( + f: typing.Callable[[], bool], + timeout: typing.Optional[typing.Union[float, int]] = None, +): + """ + timeout default is 1 second + if timeout is 0 - only counter work. It userful if test need fast timeout for condition (without wait full timeout) + """ + if timeout is None: + timeout = 1 + start = time.monotonic() counter = 0 while (time.monotonic() - start < timeout) or counter < 1000: @@ -48,8 +62,13 @@ async def wait_condition(f: typing.Callable[[], bool], timeout=1): return await asyncio.sleep(0) - raise Exception("Bad condition in test") + raise WaitConditionException("Bad condition in test") -async def wait_for_fast(fut): - return await asyncio.wait_for(fut, 1) +async def wait_for_fast( + awaitable: typing.Awaitable, + timeout: typing.Optional[typing.Union[float, int]] = None, +): + fut = asyncio.ensure_future(awaitable) + await wait_condition(lambda: fut.done(), timeout) + return fut.result() diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 9b2ab31a..06b8d690 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -1,20 +1,26 @@ +from __future__ import annotations + import abc +import asyncio +import bisect import enum -from dataclasses import dataclass +from collections import deque +from dataclasses import dataclass, field import datetime -from typing import Mapping, Union, Any, List, Dict +from typing import Mapping, Union, Any, List, Dict, Deque, Optional + +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange +from ydb._topic_reader import topic_reader_asyncio class ICommittable(abc.ABC): - @property @abc.abstractmethod - def start_offset(self) -> int: - pass + def _commit_get_partition_session(self) -> PartitionSession: + ... - @property @abc.abstractmethod - def end_offset(self) -> int: - pass + def _commit_get_offsets_range(self) -> OffsetsRange: + ... class ISessionAlive(abc.ABC): @@ -36,15 +42,15 @@ class PublicMessage(ICommittable, ISessionAlive): data: Union[ bytes, Any ] # set as original decompressed bytes or deserialized object if deserializer set in reader - _partition_session: "PartitionSession" + _partition_session: PartitionSession + _commit_start_offset: int + _commit_end_offset: int - @property - def start_offset(self) -> int: - raise NotImplementedError() + def _commit_get_partition_session(self) -> PartitionSession: + return self._partition_session - @property - def end_offset(self) -> int: - raise NotImplementedError() + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange(self._commit_start_offset, self._commit_end_offset) # ISessionAlive implementation @property @@ -58,15 +64,147 @@ class PartitionSession: state: "PartitionSession.State" topic_path: str partition_id: int + committed_offset: int # last commit offset, acked from server. Processed messages up to the field-1 offset. + reader_reconnector_id: int + reader_stream_id: int + _next_message_start_commit_offset: int = field(init=False) + _send_commit_window_start: int = field(init=False) + + # todo: check if deque is optimal + _pending_commits: Deque[OffsetsRange] = field( + init=False, default_factory=lambda: deque() + ) + + # todo: check if deque is optimal + _ack_waiters: Deque["PartitionSession.CommitAckWaiter"] = field( + init=False, default_factory=lambda: deque() + ) + + _state_changed: asyncio.Event = field( + init=False, default_factory=lambda: asyncio.Event(), compare=False + ) + _loop: Optional[asyncio.AbstractEventLoop] = field( + init=False + ) # may be None in tests + + def __post_init__(self): + self._next_message_start_commit_offset = self.committed_offset + self._send_commit_window_start = self.committed_offset + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = None + + def add_commit( + self, new_commit: OffsetsRange + ) -> "PartitionSession.CommitAckWaiter": + self._ensure_not_closed() + + self._add_to_commits(new_commit) + return self._add_waiter(new_commit.end) + + def _add_to_commits(self, new_commit: OffsetsRange): + index = bisect.bisect_left(self._pending_commits, new_commit) + + prev_commit = self._pending_commits[index - 1] if index > 0 else None + commit = ( + self._pending_commits[index] if index < len(self._pending_commits) else None + ) + + for c in (prev_commit, commit): + if c is not None and new_commit.is_intersected_with(c): + raise ValueError( + "new commit intersected with existed. New range: %s, existed: %s" + % (new_commit, c) + ) + + if commit is not None and commit.start == new_commit.end: + commit.start = new_commit.start + elif prev_commit is not None and prev_commit.end == new_commit.start: + prev_commit.end = new_commit.end + else: + self._pending_commits.insert(index, new_commit) + + def _add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": + waiter = PartitionSession.CommitAckWaiter(end_offset, self._create_future()) + + # fast way + if len(self._ack_waiters) > 0 and self._ack_waiters[-1].end_offset < end_offset: + self._ack_waiters.append(waiter) + else: + bisect.insort(self._ack_waiters, waiter) + + return waiter + + def _create_future(self) -> asyncio.Future: + if self._loop: + return self._loop.create_future() + else: + return asyncio.Future() + + def pop_commit_range(self) -> Optional[OffsetsRange]: + self._ensure_not_closed() + + if len(self._pending_commits) == 0: + return None + + if self._pending_commits[0].start != self._send_commit_window_start: + return None + + res = self._pending_commits.popleft() + while ( + len(self._pending_commits) > 0 and self._pending_commits[0].start == res.end + ): + commit = self._pending_commits.popleft() + res.end = commit.end + + self._send_commit_window_start = res.end + + return res + + def ack_notify(self, offset: int): + self._ensure_not_closed() + + self.committed_offset = offset + + if len(self._ack_waiters) == 0: + # todo log warning + # must be never receive ack for not sended request + return + + while len(self._ack_waiters) > 0: + if self._ack_waiters[0].end_offset <= offset: + waiter = self._ack_waiters.popleft() + waiter.future.set_result(None) + else: + break + + def close(self): + try: + self._ensure_not_closed() + except topic_reader_asyncio.TopicReaderCommitToExpiredPartition: + return - def stop(self): self.state = PartitionSession.State.Stopped + exception = topic_reader_asyncio.TopicReaderCommitToExpiredPartition() + for waiter in self._ack_waiters: + waiter.future.set_exception(exception) + + def _ensure_not_closed(self): + if self.state == PartitionSession.State.Stopped: + raise topic_reader_asyncio.TopicReaderCommitToExpiredPartition() class State(enum.Enum): Active = 1 GracefulShutdown = 2 Stopped = 3 + @dataclass(order=True) + class CommitAckWaiter: + end_offset: int + future: asyncio.Future = field(compare=False) + @dataclass class PublicBatch(ICommittable, ISessionAlive): @@ -75,13 +213,14 @@ class PublicBatch(ICommittable, ISessionAlive): _partition_session: PartitionSession _bytes_size: int - @property - def start_offset(self) -> int: - raise NotImplementedError() + def _commit_get_partition_session(self) -> PartitionSession: + return self.messages[0]._commit_get_partition_session() - @property - def end_offset(self) -> int: - raise NotImplementedError() + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange( + self.messages[0]._commit_get_offsets_range().start, + self.messages[-1]._commit_get_offsets_range().end, + ) # ISessionAlive implementation @property diff --git a/ydb/_topic_reader/datatypes_test.py b/ydb/_topic_reader/datatypes_test.py new file mode 100644 index 00000000..6ead9a88 --- /dev/null +++ b/ydb/_topic_reader/datatypes_test.py @@ -0,0 +1,315 @@ +import asyncio +import bisect +import copy +import functools +from collections import deque +from typing import List, Optional, Type, Union + +import pytest + +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange +from ydb._topic_common.test_helpers import wait_condition +from ydb._topic_reader import topic_reader_asyncio +from ydb._topic_reader.datatypes import PartitionSession + + +@pytest.mark.asyncio +class TestPartitionSession: + session_comitted_offset = 10 + + @pytest.fixture + def session(self) -> PartitionSession: + return PartitionSession( + id=1, + state=PartitionSession.State.Active, + topic_path="", + partition_id=1, + committed_offset=self.session_comitted_offset, + reader_reconnector_id=1, + reader_stream_id=1, + ) + + @pytest.mark.parametrize( + "offsets_waited,notify_offset,offsets_notified,offsets_waited_rest", + [ + ([1], 1, [1], []), + ([1], 10, [1], []), + ([1, 2, 3], 10, [1, 2, 3], []), + ([1, 2, 10, 20], 10, [1, 2, 10], [20]), + ([10, 20], 1, [], [10, 20]), + ], + ) + async def test_ack_notify( + self, + session, + offsets_waited: List[int], + notify_offset: int, + offsets_notified: List[int], + offsets_waited_rest: List[int], + ): + notified = [] + + for offset in offsets_waited: + fut = asyncio.Future() + + def add_notify(future, notified_offset): + notified.append(notified_offset) + + fut.add_done_callback(functools.partial(add_notify, notified_offset=offset)) + waiter = PartitionSession.CommitAckWaiter(offset, fut) + session._ack_waiters.append(waiter) + + session.ack_notify(notify_offset) + assert session._ack_waiters == deque( + [ + PartitionSession.CommitAckWaiter(offset, asyncio.Future()) + for offset in offsets_waited_rest + ] + ) + + await wait_condition(lambda: len(notified) == len(offsets_notified)) + + notified.sort() + assert notified == offsets_notified + assert session.committed_offset == notify_offset + + def test_add_commit(self, session): + commit = OffsetsRange( + self.session_comitted_offset, self.session_comitted_offset + 5 + ) + waiter = session.add_commit(commit) + assert waiter.end_offset == commit.end + + @pytest.mark.parametrize( + "original,add,result", + [ + ( + [], + OffsetsRange(1, 10), + [OffsetsRange(1, 10)], + ), + ( + [OffsetsRange(1, 10)], + OffsetsRange(15, 20), + [OffsetsRange(1, 10), OffsetsRange(15, 20)], + ), + ( + [OffsetsRange(15, 20)], + OffsetsRange(1, 10), + [OffsetsRange(1, 10), OffsetsRange(15, 20)], + ), + ( + [OffsetsRange(1, 10)], + OffsetsRange(10, 20), + [OffsetsRange(1, 20)], + ), + ( + [OffsetsRange(10, 20)], + OffsetsRange(1, 10), + [OffsetsRange(1, 20)], + ), + ( + [OffsetsRange(1, 2), OffsetsRange(3, 4)], + OffsetsRange(2, 3), + [OffsetsRange(1, 2), OffsetsRange(2, 4)], + ), + ( + [OffsetsRange(1, 10)], + OffsetsRange(5, 6), + ValueError, + ), + ], + ) + def test_add_to_commits( + self, + session, + original: List[OffsetsRange], + add: OffsetsRange, + result: Union[List[OffsetsRange], Type[Exception]], + ): + session._pending_commits = copy.deepcopy(original) + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + session._add_to_commits(add) + else: + session._add_to_commits(add) + assert session._pending_commits == result + + # noinspection PyTypeChecker + @pytest.mark.parametrize( + "original,add,result", + [ + ( + [], + 5, + [PartitionSession.CommitAckWaiter(5, None)], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 6, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(6, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 4, + [ + PartitionSession.CommitAckWaiter(4, None), + PartitionSession.CommitAckWaiter(5, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 0, + [ + PartitionSession.CommitAckWaiter(0, None), + PartitionSession.CommitAckWaiter(5, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(5, None)], + 100, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + 50, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(50, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(7, None), + ], + 6, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(6, None), + PartitionSession.CommitAckWaiter(7, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + 6, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(6, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(100, None), + ], + 99, + [ + PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(99, None), + PartitionSession.CommitAckWaiter(100, None), + ], + ), + ], + ) + def test_add_waiter( + self, + session, + original: List[PartitionSession.CommitAckWaiter], + add: int, + result: List[PartitionSession.CommitAckWaiter], + ): + session._ack_waiters = copy.deepcopy(original) + res = session._add_waiter(add) + assert result == session._ack_waiters + + index = bisect.bisect_left(session._ack_waiters, res) + assert res is session._ack_waiters[index] + + def test_close_notify_waiters(self, session): + waiter = session._add_waiter(session.committed_offset + 1) + session.close() + + with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): + waiter.future.result() + + def test_close_twice(self, session): + session.close() + session.close() + + @pytest.mark.parametrize( + "commits,result,rest", + [ + ([], None, []), + ( + [OffsetsRange(session_comitted_offset + 1, 20)], + None, + [OffsetsRange(session_comitted_offset + 1, 20)], + ), + ( + [OffsetsRange(session_comitted_offset, session_comitted_offset + 1)], + OffsetsRange(session_comitted_offset, session_comitted_offset + 1), + [], + ), + ( + [ + OffsetsRange(session_comitted_offset, session_comitted_offset + 1), + OffsetsRange( + session_comitted_offset + 1, session_comitted_offset + 2 + ), + ], + OffsetsRange(session_comitted_offset, session_comitted_offset + 2), + [], + ), + ( + [ + OffsetsRange(session_comitted_offset, session_comitted_offset + 1), + OffsetsRange( + session_comitted_offset + 1, session_comitted_offset + 2 + ), + OffsetsRange( + session_comitted_offset + 10, session_comitted_offset + 20 + ), + ], + OffsetsRange(session_comitted_offset, session_comitted_offset + 2), + [ + OffsetsRange( + session_comitted_offset + 10, session_comitted_offset + 20 + ) + ], + ), + ], + ) + def test_get_commit_range( + self, + session, + commits: List[OffsetsRange], + result: Optional[OffsetsRange], + rest: List[OffsetsRange], + ): + send_commit_window_start = session._send_commit_window_start + + session._pending_commits = deque(commits) + res = session.pop_commit_range() + assert res == result + assert session._pending_commits == deque(rest) + + if res is None: + assert session._send_commit_window_start == send_commit_window_start + else: + assert session._send_commit_window_start != send_commit_window_start + assert session._send_commit_window_start == res.end diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index a3f792de..cc0839f7 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -6,12 +6,12 @@ from collections import deque from typing import Optional, Set, Dict - from .. import _apis, issues, RetrySettings +from .._utilities import AtomicCounter from ..aio import Driver from ..issues import Error as YdbError, _process_response -from .datatypes import PartitionSession, PublicMessage, PublicBatch, ICommittable -from .topic_reader import PublicReaderSettings, CommitResult, SessionStat +from . import datatypes +from . import topic_reader from .._topic_common.common import ( TokenGetterFuncType, ) @@ -28,6 +28,17 @@ class TopicReaderError(YdbError): pass +class TopicReaderCommitToExpiredPartition(TopicReaderError): + """ + Commit message when partition read session are dropped. + It is ok - the message/batch will not commit to server and will receive in other read session + (with this or other reader). + """ + + def __init__(self, message: str = "Topic reader partition session is closed"): + super().__init__(message) + + class TopicReaderStreamClosedError(TopicReaderError): def __init__(self): super().__init__("Topic reader stream is closed") @@ -43,7 +54,7 @@ class PublicAsyncIOReader: _closed: bool _reconnector: ReaderReconnector - def __init__(self, driver: Driver, settings: PublicReaderSettings): + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._loop = asyncio.get_running_loop() self._closed = False self._reconnector = ReaderReconnector(driver, settings) @@ -58,7 +69,7 @@ def __del__(self): if not self._closed: self._loop.create_task(self.close(), name="close reader") - async def sessions_stat(self) -> typing.List["SessionStat"]: + async def sessions_stat(self) -> typing.List["topic_reader.SessionStat"]: """ Receive stat from the server @@ -68,7 +79,7 @@ async def sessions_stat(self) -> typing.List["SessionStat"]: def messages( self, *, timeout: typing.Union[float, None] = None - ) -> typing.AsyncIterable["PublicMessage"]: + ) -> typing.AsyncIterable[topic_reader.PublicMessage]: """ Block until receive new message @@ -76,7 +87,7 @@ def messages( """ raise NotImplementedError() - async def receive_message(self) -> typing.Union["PublicMessage", None]: + async def receive_message(self) -> typing.Union[topic_reader.PublicMessage, None]: """ Block until receive new message @@ -90,7 +101,7 @@ def batches( max_messages: typing.Union[int, None] = None, max_bytes: typing.Union[int, None] = None, timeout: typing.Union[float, None] = None, - ) -> typing.AsyncIterable["PublicBatch"]: + ) -> typing.AsyncIterable[datatypes.PublicBatch]: """ Block until receive new batch. All messages in a batch from same partition. @@ -104,7 +115,7 @@ async def receive_batch( *, max_messages: typing.Union[int, None] = None, max_bytes: typing.Union[int, None] = None, - ) -> typing.Union["PublicBatch", None]: + ) -> typing.Union[topic_reader.PublicBatch, None]: """ Get one messages batch from reader. All messages in a batch from same partition. @@ -114,7 +125,9 @@ async def receive_batch( await self._reconnector.wait_message() return self._reconnector.receive_batch_nowait() - async def commit_on_exit(self, mess: ICommittable) -> typing.AsyncContextManager: + async def commit_on_exit( + self, mess: datatypes.ICommittable + ) -> typing.AsyncContextManager: """ commit the mess match/message if exit from context manager without exceptions @@ -122,24 +135,27 @@ async def commit_on_exit(self, mess: ICommittable) -> typing.AsyncContextManager """ raise NotImplementedError() - def commit(self, mess: ICommittable): + def commit( + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): """ Write commit message to a buffer. For the method no way check the commit result (for example if lost connection - commits will not re-send and committed messages will receive again) """ - raise NotImplementedError() + self._reconnector.commit(batch) async def commit_with_ack( - self, mess: ICommittable - ) -> typing.Union[CommitResult, typing.List[CommitResult]]: + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): """ write commit message to a buffer and wait ack from the server. use asyncio.wait_for for wait with timeout. """ - raise NotImplementedError() + waiter = self._reconnector.commit(batch) + await waiter.future async def flush(self): """ @@ -158,7 +174,10 @@ async def close(self): class ReaderReconnector: - _settings: PublicReaderSettings + _static_reader_reconnector_counter = AtomicCounter() + + _id: int + _settings: topic_reader.PublicReaderSettings _driver: Driver _background_tasks: Set[Task] @@ -166,7 +185,9 @@ class ReaderReconnector: _stream_reader: Optional["ReaderStream"] _first_error: asyncio.Future[YdbError] - def __init__(self, driver: Driver, settings: PublicReaderSettings): + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + self._id = self._static_reader_reconnector_counter.inc_and_get() + self._settings = settings self._driver = driver self._background_tasks = set() @@ -182,7 +203,7 @@ async def _connection_loop(self): while True: try: self._stream_reader = await ReaderStream.create( - self._driver, self._settings + self._id, self._driver, self._settings ) attempt = 0 self._state_changed.set() @@ -216,6 +237,11 @@ async def wait_message(self): def receive_batch_nowait(self): return self._stream_reader.receive_batch_nowait() + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + return self._stream_reader.commit(batch) + async def close(self): await self._stream_reader.close() for task in self._background_tasks: @@ -233,20 +259,28 @@ def _set_first_error(self, err: issues.Error): class ReaderStream: + _static_id_counter = AtomicCounter() + + _id: int + _reader_reconnector_id: int _token_getter: Optional[TokenGetterFuncType] _session_id: str _stream: Optional[IGrpcWrapperAsyncIO] _started: bool _background_tasks: Set[asyncio.Task] - _partition_sessions: Dict[int, PartitionSession] + _partition_sessions: Dict[int, datatypes.PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only _state_changed: asyncio.Event _closed: bool - _message_batches: typing.Deque[PublicBatch] + _message_batches: typing.Deque[datatypes.PublicBatch] _first_error: asyncio.Future[YdbError] - def __init__(self, settings: PublicReaderSettings): + def __init__( + self, reader_reconnector_id: int, settings: topic_reader.PublicReaderSettings + ): + self._id = ReaderStream._static_id_counter.inc_and_get() + self._reader_reconnector_id = reader_reconnector_id self._token_getter = settings._token_getter self._session_id = "not initialized" self._stream = None @@ -262,8 +296,9 @@ def __init__(self, settings: PublicReaderSettings): @staticmethod async def create( + reader_reconnector_id: int, driver: SupportedDriverType, - settings: PublicReaderSettings, + settings: topic_reader.PublicReaderSettings, ) -> "ReaderStream": stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) @@ -271,7 +306,7 @@ async def create( driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead ) - reader = ReaderStream(settings) + reader = ReaderStream(reader_reconnector_id, settings) await reader._start(stream, settings._init_message()) return reader @@ -321,6 +356,45 @@ def receive_batch_nowait(self): except IndexError: return None + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + partition_session = batch._commit_get_partition_session() + + if ( + partition_session.reader_reconnector_id + != partition_session.reader_reconnector_id + ): + raise TopicReaderError("reader can commit only self-produced messages") + + if partition_session.reader_stream_id != self._id: + raise TopicReaderCommitToExpiredPartition( + "commit messages after reconnect to server" + ) + + if partition_session.id not in self._partition_sessions: + raise TopicReaderCommitToExpiredPartition( + "commit messages after server stop the partition read session" + ) + + waiter = partition_session.add_commit(batch._commit_get_offsets_range()) + + send_range = partition_session.pop_commit_range() + if send_range: + client_message = StreamReadMessage.CommitOffsetRequest( + commit_offsets=[ + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=partition_session.id, + offsets=[send_range], + ) + ] + ) + self._stream.write( + StreamReadMessage.FromClient(client_message=client_message) + ) + + return waiter + async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): try: self._stream.write( @@ -335,11 +409,15 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): _process_response(message.server_status) if isinstance(message.server_message, StreamReadMessage.ReadResponse): self._on_read_response(message.server_message) + elif isinstance( + message.server_message, StreamReadMessage.CommitOffsetResponse + ): + self._on_commit_response(message.server_message) elif isinstance( message.server_message, StreamReadMessage.StartPartitionSessionRequest, ): - self._on_start_partition_session_start(message.server_message) + self._on_start_partition_session(message.server_message) elif isinstance( message.server_message, StreamReadMessage.StopPartitionSessionRequest, @@ -356,7 +434,7 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO): self._set_first_error(e) raise e - def _on_start_partition_session_start( + def _on_start_partition_session( self, message: StreamReadMessage.StartPartitionSessionRequest ): try: @@ -371,18 +449,21 @@ def _on_start_partition_session_start( self._partition_sessions[ message.partition_session.partition_session_id - ] = PartitionSession( + ] = datatypes.PartitionSession( id=message.partition_session.partition_session_id, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, topic_path=message.partition_session.path, partition_id=message.partition_session.partition_id, + committed_offset=message.committed_offset, + reader_reconnector_id=self._reader_reconnector_id, + reader_stream_id=self._id, ) self._stream.write( StreamReadMessage.FromClient( client_message=StreamReadMessage.StartPartitionSessionResponse( partition_session_id=message.partition_session.partition_session_id, - read_offset=0, - commit_offset=0, + read_offset=None, + commit_offset=None, ) ), ) @@ -399,7 +480,7 @@ def _on_partition_session_stop( return del self._partition_sessions[message.partition_session_id] - partition.stop() + partition.close() if message.graceful: self._stream.write( @@ -415,6 +496,16 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse): self._message_batches.extend(batches) self._buffer_consume_bytes(message.bytes_size) + def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): + for partition_offset in message.partitions_committed_offsets: + try: + session = self._partition_sessions[ + partition_offset.partition_session_id + ] + except KeyError: + continue + session.ack_notify(partition_offset.committed_offset) + def _buffer_consume_bytes(self, bytes_size): self._buffer_size_bytes -= bytes_size @@ -430,7 +521,7 @@ def _buffer_release_bytes(self, bytes_size): def _read_response_to_batches( self, message: StreamReadMessage.ReadResponse - ) -> typing.List[PublicBatch]: + ) -> typing.List[datatypes.PublicBatch]: batches = [] batch_count = 0 @@ -452,7 +543,7 @@ def _read_response_to_batches( for server_batch in partition_data.batches: messages = [] for message_data in server_batch.message_data: - mess = PublicMessage( + mess = datatypes.PublicMessage( seqno=message_data.seq_no, created_at=message_data.created_at, message_group_id=message_data.message_group_id, @@ -462,15 +553,23 @@ def _read_response_to_batches( producer_id=server_batch.producer_id, data=message_data.data, _partition_session=partition_session, + _commit_start_offset=partition_session._next_message_start_commit_offset, + _commit_end_offset=message_data.offset + 1, ) messages.append(mess) - batch = PublicBatch( - session_metadata=server_batch.write_session_meta, - messages=messages, - _partition_session=partition_session, - _bytes_size=bytes_per_batch, - ) - batches.append(batch) + + partition_session._next_message_start_commit_offset = ( + mess._commit_end_offset + ) + + if len(messages) > 0: + batch = datatypes.PublicBatch( + session_metadata=server_batch.write_session_meta, + messages=messages, + _partition_session=partition_session, + _bytes_size=bytes_per_batch, + ) + batches.append(batch) batches[-1]._bytes_size += additional_bytes_to_last_batch return batches @@ -498,6 +597,9 @@ async def close(self): self._state_changed.set() self._stream.close() + for session in self._partition_sessions.values(): + session.close() + for task in self._background_tasks: task.cancel() diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index f761a315..c73be69f 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1,16 +1,21 @@ import asyncio import datetime import typing +from collections import deque +from dataclasses import dataclass +from typing import List, Optional from unittest import mock import pytest from ydb import issues +from . import datatypes, topic_reader_asyncio from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings -from .topic_reader_asyncio import ReaderStream, PartitionSession, ReaderReconnector +from .topic_reader_asyncio import ReaderStream, ReaderReconnector from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec, OffsetsRange +from .._topic_common import test_helpers from .._topic_common.test_helpers import StreamMock, wait_condition, wait_for_fast # Workaround for good IDE and universal for runtime @@ -47,38 +52,67 @@ def write(self, message: StreamReadMessage.FromClient): @pytest.mark.asyncio class TestReaderStream: default_batch_size = 1 + partition_session_id = 2 + partition_session_committed_offset = 10 + second_partition_session_id = 12 + second_partition_session_offset = 50 + default_reader_reconnector_id = 4 @pytest.fixture() def stream(self): return StreamMock() @pytest.fixture() - def partition_session(self, default_reader_settings): - return PartitionSession( + def partition_session( + self, default_reader_settings, stream_reader_started: ReaderStream + ) -> datatypes.PartitionSession: + partition_session = datatypes.PartitionSession( id=2, topic_path=default_reader_settings.topic, partition_id=4, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, + committed_offset=self.partition_session_committed_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader_started._id, ) + assert partition_session.id not in stream_reader_started._partition_sessions + stream_reader_started._partition_sessions[ + partition_session.id + ] = partition_session + + return stream_reader_started._partition_sessions[partition_session.id] + @pytest.fixture() - def second_partition_session(self, default_reader_settings): - return PartitionSession( + def second_partition_session( + self, default_reader_settings, stream_reader_started: ReaderStream + ): + partition_session = datatypes.PartitionSession( id=12, topic_path=default_reader_settings.topic, partition_id=10, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, + committed_offset=self.second_partition_session_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader_started._id, ) + assert partition_session.id not in stream_reader_started._partition_sessions + stream_reader_started._partition_sessions[ + partition_session.id + ] = partition_session + + return stream_reader_started._partition_sessions[partition_session.id] + @pytest.fixture() async def stream_reader_started( self, stream, default_reader_settings, - partition_session, - second_partition_session, ) -> ReaderStream: - reader = ReaderStream(default_reader_settings) + reader = ReaderStream( + self.default_reader_reconnector_id, default_reader_settings + ) init_message = object() # noinspection PyTypeChecker @@ -99,54 +133,8 @@ async def stream_reader_started( read_request = await wait_for_fast(stream.from_client.get()) assert isinstance(read_request.client_message, StreamReadMessage.ReadRequest) - stream.from_server.put_nowait( - StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession( - partition_session_id=partition_session.id, - path=partition_session.topic_path, - partition_id=partition_session.partition_id, - ), - committed_offset=0, - partition_offsets=OffsetsRange( - start=0, - end=0, - ), - ), - ) - ) await start - start_partition_resp = await wait_for_fast(stream.from_client.get()) - assert isinstance( - start_partition_resp.client_message, - StreamReadMessage.StartPartitionSessionResponse, - ) - - stream.from_server.put_nowait( - StreamReadMessage.FromServer( - server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), - server_message=StreamReadMessage.StartPartitionSessionRequest( - partition_session=StreamReadMessage.PartitionSession( - partition_session_id=second_partition_session.id, - path=second_partition_session.topic_path, - partition_id=second_partition_session.partition_id, - ), - committed_offset=0, - partition_offsets=OffsetsRange( - start=0, - end=0, - ), - ), - ) - ) - start_partition_resp = await wait_for_fast(stream.from_client.get()) - assert isinstance( - start_partition_resp.client_message, - StreamReadMessage.StartPartitionSessionResponse, - ) - await asyncio.sleep(0) with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() @@ -170,17 +158,26 @@ async def stream_reader_finish_with_error( await stream_reader_started.close() @staticmethod - def create_message(partition_session: PartitionSession, seqno: int): + def create_message( + partition_session: datatypes.PartitionSession, seqno: int, offset_delta: int + ): return PublicMessage( seqno=seqno, created_at=datetime.datetime(2023, 2, 3, 14, 15), message_group_id="test-message-group", session_metadata={}, - offset=seqno + 1, + offset=partition_session._next_message_start_commit_offset + + offset_delta + - 1, written_at=datetime.datetime(2023, 2, 3, 14, 16), producer_id="test-producer-id", data=bytes(), _partition_session=partition_session, + _commit_start_offset=partition_session._next_message_start_commit_offset + + offset_delta + - 1, + _commit_end_offset=partition_session._next_message_start_commit_offset + + offset_delta, ) async def send_message(self, stream_reader, message: PublicMessage): @@ -236,6 +233,231 @@ class TestError(Exception): with pytest.raises(TestError): stream_reader_finish_with_error.receive_batch_nowait() + @pytest.mark.parametrize( + "pending_ranges,commit,send_range,rest_ranges", + [ + ( + [], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + [], + ), + ( + [], + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + None, + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ) + ], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 5, + partition_session_committed_offset + 10, + ) + ], + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + None, + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + OffsetsRange( + partition_session_committed_offset + 5, + partition_session_committed_offset + 10, + ), + ], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ) + ], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 2, + ), + [], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + OffsetsRange( + partition_session_committed_offset + 2, + partition_session_committed_offset + 3, + ), + ], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 3, + ), + [], + ), + ( + [ + OffsetsRange( + partition_session_committed_offset + 1, + partition_session_committed_offset + 2, + ), + OffsetsRange( + partition_session_committed_offset + 2, + partition_session_committed_offset + 3, + ), + OffsetsRange( + partition_session_committed_offset + 4, + partition_session_committed_offset + 5, + ), + ], + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 3, + ), + [ + OffsetsRange( + partition_session_committed_offset + 4, + partition_session_committed_offset + 5, + ) + ], + ), + ], + ) + async def test_send_commit_messages( + self, + stream, + stream_reader: ReaderStream, + partition_session, + pending_ranges: List[OffsetsRange], + commit: OffsetsRange, + send_range: Optional[OffsetsRange], + rest_ranges: List[OffsetsRange], + ): + @dataclass + class Commitable(datatypes.ICommittable): + start: int + end: int + + def _commit_get_partition_session(self) -> datatypes.PartitionSession: + return partition_session + + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange(self.start, self.end) + + partition_session._pending_commits = deque(pending_ranges) + + stream_reader.commit(Commitable(commit.start, commit.end)) + + async def wait_message(): + return await wait_for_fast(stream.from_client.get(), timeout=0) + + if send_range: + msg = await wait_message() # type: StreamReadMessage.FromClient + assert msg.client_message == StreamReadMessage.CommitOffsetRequest( + commit_offsets=[ + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=partition_session.id, + offsets=[send_range], + ) + ] + ) + else: + with pytest.raises(test_helpers.WaitConditionException): + await wait_message() + + assert partition_session._pending_commits == deque(rest_ranges) + + async def test_commit_ack_received( + self, stream_reader, stream, partition_session, second_partition_session + ): + offset1 = self.partition_session_committed_offset + 1 + waiter1 = partition_session._add_waiter(offset1) + + offset2 = self.second_partition_session_offset + 2 + waiter2 = second_partition_session._add_waiter(offset2) + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.CommitOffsetResponse( + partitions_committed_offsets=[ + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=partition_session.id, + committed_offset=offset1, + ), + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=second_partition_session.id, + committed_offset=offset2, + ), + ] + ), + ) + ) + + await wait_for_fast(waiter1.future) + await wait_for_fast(waiter2.future) + + async def test_close_ack_waiters_when_close_stream_reader( + self, stream_reader_started: ReaderStream, partition_session + ): + waiter = partition_session._add_waiter( + self.partition_session_committed_offset + 1 + ) + await wait_for_fast(stream_reader_started.close()) + + with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): + waiter.future.result() + + async def test_commit_ranges_for_received_messages( + self, stream, stream_reader_started: ReaderStream, partition_session + ): + m1 = self.create_message(partition_session, 1, 1) + m2 = self.create_message(partition_session, 2, 10) + m2._commit_start_offset = m1.offset + 1 + + await self.send_message(stream_reader_started, m1) + await self.send_message(stream_reader_started, m2) + + await stream_reader_started.wait_messages() + received = stream_reader_started.receive_batch_nowait().messages + assert received == [m1] + + await stream_reader_started.wait_messages() + received = stream_reader_started.receive_batch_nowait().messages + assert received == [m2] + async def test_error_from_status_code( self, stream, stream_reader_finish_with_error ): @@ -257,7 +479,9 @@ async def test_error_from_status_code( stream_reader_finish_with_error.receive_batch_nowait() async def test_init_reader(self, stream, default_reader_settings): - reader = ReaderStream(default_reader_settings) + reader = ReaderStream( + self.default_reader_reconnector_id, default_reader_settings + ) init_message = StreamReadMessage.InitRequest( consumer="test-consumer", topics_read_settings=[ @@ -309,6 +533,7 @@ def session_count(): test_partition_id = partition_session.partition_id + 1 test_partition_session_id = partition_session.id + 1 test_topic_path = default_reader_settings.topic + "-asd" + test_partition_committed_offset = 18 stream.from_server.put_nowait( StreamReadMessage.FromServer( @@ -319,7 +544,7 @@ def session_count(): path=test_topic_path, partition_id=test_partition_id, ), - committed_offset=0, + committed_offset=test_partition_committed_offset, partition_offsets=OffsetsRange( start=0, end=0, @@ -331,19 +556,22 @@ def session_count(): assert response == StreamReadMessage.FromClient( client_message=StreamReadMessage.StartPartitionSessionResponse( partition_session_id=test_partition_session_id, - read_offset=0, - commit_offset=0, + read_offset=None, + commit_offset=None, ) ) assert len(stream_reader._partition_sessions) == initial_session_count + 1 assert stream_reader._partition_sessions[ test_partition_session_id - ] == PartitionSession( + ] == datatypes.PartitionSession( id=test_partition_session_id, - state=PartitionSession.State.Active, + state=datatypes.PartitionSession.State.Active, topic_path=test_topic_path, partition_id=test_partition_id, + committed_offset=test_partition_committed_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader._id, ) async def test_partition_stop_force(self, stream, stream_reader, partition_session): @@ -414,7 +642,11 @@ def session_count(): stream.from_client.get_nowait() async def test_receive_message_from_server( - self, stream_reader, stream, partition_session, second_partition_session + self, + stream_reader, + stream, + partition_session: datatypes.PartitionSession, + second_partition_session, ): def reader_batch_count(): return len(stream_reader._message_batches) @@ -430,6 +662,8 @@ def reader_batch_count(): session_meta = {"a": "b"} message_group_id = "test-message-group-id" + expected_message_offset = partition_session.committed_offset + stream.from_server.put_nowait( StreamReadMessage.FromServer( server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), @@ -442,7 +676,7 @@ def reader_batch_count(): StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=1, + offset=expected_message_offset, seq_no=2, created_at=created_at, data=data, @@ -475,11 +709,13 @@ def reader_batch_count(): created_at=created_at, message_group_id=message_group_id, session_metadata=session_meta, - offset=1, + offset=expected_message_offset, written_at=written_at, producer_id=producer_id, data=data, _partition_session=partition_session, + _commit_start_offset=expected_message_offset, + _commit_end_offset=expected_message_offset + 1, ) ], _partition_session=partition_session, @@ -505,6 +741,11 @@ async def test_read_batches( message_group_id = "test-message-group-id" message_group_id2 = "test-message-group-id-2" + partition1_mess1_expected_offset = partition_session.committed_offset + partition2_mess1_expected_offset = second_partition_session.committed_offset + partition2_mess2_expected_offset = second_partition_session.committed_offset + 1 + partition2_mess3_expected_offset = second_partition_session.committed_offset + 2 + batches = stream_reader._read_response_to_batches( StreamReadMessage.ReadResponse( bytes_size=3, @@ -515,7 +756,7 @@ async def test_read_batches( StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=2, + offset=partition1_mess1_expected_offset, seq_no=3, created_at=created_at, data=data, @@ -536,7 +777,7 @@ async def test_read_batches( StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=1, + offset=partition2_mess1_expected_offset, seq_no=2, created_at=created_at2, data=data, @@ -552,7 +793,7 @@ async def test_read_batches( StreamReadMessage.ReadResponse.Batch( message_data=[ StreamReadMessage.ReadResponse.MessageData( - offset=2, + offset=partition2_mess2_expected_offset, seq_no=3, created_at=created_at3, data=data2, @@ -560,7 +801,7 @@ async def test_read_batches( message_group_id=message_group_id, ), StreamReadMessage.ReadResponse.MessageData( - offset=4, + offset=partition2_mess3_expected_offset, seq_no=5, created_at=created_at4, data=data, @@ -591,11 +832,13 @@ async def test_read_batches( created_at=created_at, message_group_id=message_group_id, session_metadata=session_meta, - offset=2, + offset=partition1_mess1_expected_offset, written_at=written_at, producer_id=producer_id, data=data, _partition_session=partition_session, + _commit_start_offset=partition1_mess1_expected_offset, + _commit_end_offset=partition1_mess1_expected_offset + 1, ) ], _partition_session=partition_session, @@ -609,11 +852,13 @@ async def test_read_batches( created_at=created_at2, message_group_id=message_group_id, session_metadata=session_meta, - offset=1, + offset=partition2_mess1_expected_offset, written_at=written_at2, producer_id=producer_id, data=data, _partition_session=second_partition_session, + _commit_start_offset=partition2_mess1_expected_offset, + _commit_end_offset=partition2_mess1_expected_offset + 1, ) ], _partition_session=second_partition_session, @@ -627,22 +872,26 @@ async def test_read_batches( created_at=created_at3, message_group_id=message_group_id, session_metadata=session_meta2, - offset=2, + offset=partition2_mess2_expected_offset, written_at=written_at2, producer_id=producer_id2, data=data2, _partition_session=second_partition_session, + _commit_start_offset=partition2_mess2_expected_offset, + _commit_end_offset=partition2_mess2_expected_offset + 1, ), PublicMessage( seqno=5, created_at=created_at4, message_group_id=message_group_id2, session_metadata=session_meta2, - offset=4, + offset=partition2_mess3_expected_offset, written_at=written_at2, producer_id=producer_id, data=data, _partition_session=second_partition_session, + _commit_start_offset=partition2_mess3_expected_offset, + _commit_end_offset=partition2_mess3_expected_offset + 1, ), ], _partition_session=second_partition_session, @@ -652,17 +901,17 @@ async def test_read_batches( async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): assert stream_reader.receive_batch_nowait() is None - mess1 = self.create_message(partition_session, 1) + mess1 = self.create_message(partition_session, 1, 1) await self.send_message(stream_reader, mess1) - mess2 = self.create_message(partition_session, 2) + mess2 = self.create_message(partition_session, 2, 1) await self.send_message(stream_reader, mess2) initial_buffer_size = stream_reader._buffer_size_bytes received = stream_reader.receive_batch_nowait() assert received == PublicBatch( - mess1.session_metadata, + session_metadata=mess1.session_metadata, messages=[mess1], _partition_session=mess1._partition_session, _bytes_size=self.default_batch_size, @@ -721,6 +970,7 @@ async def wait_messages(): stream_index = 0 async def stream_create( + reader_reconnector_id: int, driver: SupportedDriverType, settings: PublicReaderSettings, ): @@ -735,7 +985,7 @@ async def stream_create( with mock.patch.object(ReaderStream, "create", stream_create): reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", "")) - await reconnector.wait_message() + await wait_for_fast(reconnector.wait_message()) reader_stream_mock_with_error.wait_error.assert_any_await() reader_stream_mock_with_error.wait_messages.assert_any_await() diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index b30b547a..9652cb84 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -46,13 +46,19 @@ async def create_reader(): def __del__(self): self.close() - def _call(self, coro): + def _call(self, coro) -> concurrent.futures.Future: + """ + Call async function and return future fow wait result + """ if self._closed: raise TopicReaderClosedError() return asyncio.run_coroutine_threadsafe(coro, self._loop) def _call_sync(self, coro: Coroutine, timeout): + """ + Call async function, wait and return result + """ f = self._call(coro) try: return f.result(timeout) @@ -162,15 +168,13 @@ def commit_with_ack( if receive in timeout seconds (default - infinite): raise TimeoutError() """ - raise NotImplementedError() + return self._call_sync(self._async_reader.commit_with_ack(mess), None) - def async_commit_with_ack( - self, mess: ICommittable - ) -> Union[CommitResult, List[CommitResult]]: + def async_commit_with_ack(self, mess: ICommittable) -> concurrent.futures.Future: """ write commit message to a buffer and return Future for wait result. """ - raise NotImplementedError() + return self._call(self._async_reader.commit_with_ack(mess), None) def async_flush(self) -> concurrent.futures.Future: """ diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 4724ab2f..b46a13b8 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -151,6 +151,7 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: _closed: bool + _loop: asyncio.AbstractEventLoop _credentials: Union[ydb.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int @@ -169,10 +170,12 @@ class WriterAsyncIOReconnector: def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._closed = False + self._loop = asyncio.get_running_loop() self._driver = driver self._credentials = driver._credentials self._init_message = settings.create_init_request() - self._init_info = asyncio.Future() + self._new_messages = asyncio.Queue() + self._init_info = self._loop.create_future() self._stream_connected = asyncio.Event() self._settings = settings @@ -180,7 +183,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._messages = deque() self._messages_future = deque() self._new_messages = asyncio.Queue() - self._stop_reason = asyncio.Future() + self._stop_reason = self._loop.create_future() self._background_tasks = [ asyncio.create_task(self._connection_loop(), name="connection_loop") ] @@ -233,7 +236,7 @@ async def write_with_ack_future( await self.wait_init() internal_messages = self._prepare_internal_messages(messages) - messages_future = [asyncio.Future() for _ in internal_messages] + messages_future = [self._loop.create_future() for _ in internal_messages] self._messages.extend(internal_messages) self._messages_future.extend(messages_future) diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 544b154c..0b72a198 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import threading import codecs from concurrent import futures import functools @@ -157,3 +158,17 @@ def next(self): def __next__(self): return self._next() + + +class AtomicCounter: + _lock: threading.Lock + _value: int + + def __init__(self, initial_value: int = 0): + self._lock = threading.Lock() + self._value = initial_value + + def inc_and_get(self) -> int: + with self._lock: + self._value += 1 + return self._value From af833fff9a8043f4ae9928cbb4bf604d10d65e7e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 1 Mar 2023 18:13:00 +0300 Subject: [PATCH 2/8] remove logic for reorder commits --- ydb/_topic_common/test_helpers.py | 4 +- ydb/_topic_reader/datatypes.py | 76 ++---- ydb/_topic_reader/datatypes_test.py | 254 ++++++------------ ydb/_topic_reader/topic_reader_asyncio.py | 8 +- .../topic_reader_asyncio_test.py | 146 ++-------- 5 files changed, 125 insertions(+), 363 deletions(-) diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index 60166d0d..d70cd9f1 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -39,7 +39,7 @@ def close(self): self.from_server.put_nowait(None) -class WaitConditionException(Exception): +class WaitConditionError(Exception): pass @@ -62,7 +62,7 @@ async def wait_condition( return await asyncio.sleep(0) - raise WaitConditionException("Bad condition in test") + raise WaitConditionError("Bad condition in test") async def wait_for_fast( diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 06b8d690..6ca7681c 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -68,12 +68,6 @@ class PartitionSession: reader_reconnector_id: int reader_stream_id: int _next_message_start_commit_offset: int = field(init=False) - _send_commit_window_start: int = field(init=False) - - # todo: check if deque is optimal - _pending_commits: Deque[OffsetsRange] = field( - init=False, default_factory=lambda: deque() - ) # todo: check if deque is optimal _ack_waiters: Deque["PartitionSession.CommitAckWaiter"] = field( @@ -89,45 +83,17 @@ class PartitionSession: def __post_init__(self): self._next_message_start_commit_offset = self.committed_offset - self._send_commit_window_start = self.committed_offset try: self._loop = asyncio.get_running_loop() except RuntimeError: self._loop = None - def add_commit( - self, new_commit: OffsetsRange - ) -> "PartitionSession.CommitAckWaiter": - self._ensure_not_closed() - - self._add_to_commits(new_commit) - return self._add_waiter(new_commit.end) - - def _add_to_commits(self, new_commit: OffsetsRange): - index = bisect.bisect_left(self._pending_commits, new_commit) - - prev_commit = self._pending_commits[index - 1] if index > 0 else None - commit = ( - self._pending_commits[index] if index < len(self._pending_commits) else None - ) - - for c in (prev_commit, commit): - if c is not None and new_commit.is_intersected_with(c): - raise ValueError( - "new commit intersected with existed. New range: %s, existed: %s" - % (new_commit, c) - ) - - if commit is not None and commit.start == new_commit.end: - commit.start = new_commit.start - elif prev_commit is not None and prev_commit.end == new_commit.start: - prev_commit.end = new_commit.end - else: - self._pending_commits.insert(index, new_commit) - - def _add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": + def add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": waiter = PartitionSession.CommitAckWaiter(end_offset, self._create_future()) + if end_offset <= self.committed_offset: + waiter._finish_ok() + return waiter # fast way if len(self._ack_waiters) > 0 and self._ack_waiters[-1].end_offset < end_offset: @@ -143,26 +109,6 @@ def _create_future(self) -> asyncio.Future: else: return asyncio.Future() - def pop_commit_range(self) -> Optional[OffsetsRange]: - self._ensure_not_closed() - - if len(self._pending_commits) == 0: - return None - - if self._pending_commits[0].start != self._send_commit_window_start: - return None - - res = self._pending_commits.popleft() - while ( - len(self._pending_commits) > 0 and self._pending_commits[0].start == res.end - ): - commit = self._pending_commits.popleft() - res.end = commit.end - - self._send_commit_window_start = res.end - - return res - def ack_notify(self, offset: int): self._ensure_not_closed() @@ -176,7 +122,7 @@ def ack_notify(self, offset: int): while len(self._ack_waiters) > 0: if self._ack_waiters[0].end_offset <= offset: waiter = self._ack_waiters.popleft() - waiter.future.set_result(None) + waiter._finish_ok() else: break @@ -189,7 +135,7 @@ def close(self): self.state = PartitionSession.State.Stopped exception = topic_reader_asyncio.TopicReaderCommitToExpiredPartition() for waiter in self._ack_waiters: - waiter.future.set_exception(exception) + waiter._finish_error(exception) def _ensure_not_closed(self): if self.state == PartitionSession.State.Stopped: @@ -204,6 +150,16 @@ class State(enum.Enum): class CommitAckWaiter: end_offset: int future: asyncio.Future = field(compare=False) + _done: bool = field(default=False, init=False) + _exception: Optional[Exception] = field(default=None, init=False) + + def _finish_ok(self): + self._done = True + self.future.set_result(None) + + def _finish_error(self, error: Exception): + self._exception = error + self.future.set_exception(error) @dataclass diff --git a/ydb/_topic_reader/datatypes_test.py b/ydb/_topic_reader/datatypes_test.py index 6ead9a88..2ec1229f 100644 --- a/ydb/_topic_reader/datatypes_test.py +++ b/ydb/_topic_reader/datatypes_test.py @@ -1,13 +1,11 @@ import asyncio -import bisect import copy import functools from collections import deque -from typing import List, Optional, Type, Union +from typing import List import pytest -from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange from ydb._topic_common.test_helpers import wait_condition from ydb._topic_reader import topic_reader_asyncio from ydb._topic_reader.datatypes import PartitionSession @@ -73,155 +71,127 @@ def add_notify(future, notified_offset): assert notified == offsets_notified assert session.committed_offset == notify_offset - def test_add_commit(self, session): - commit = OffsetsRange( - self.session_comitted_offset, self.session_comitted_offset + 5 - ) - waiter = session.add_commit(commit) - assert waiter.end_offset == commit.end - + # noinspection PyTypeChecker @pytest.mark.parametrize( - "original,add,result", + "original,add,is_done,result", [ ( [], - OffsetsRange(1, 10), - [OffsetsRange(1, 10)], - ), - ( - [OffsetsRange(1, 10)], - OffsetsRange(15, 20), - [OffsetsRange(1, 10), OffsetsRange(15, 20)], - ), - ( - [OffsetsRange(15, 20)], - OffsetsRange(1, 10), - [OffsetsRange(1, 10), OffsetsRange(15, 20)], - ), - ( - [OffsetsRange(1, 10)], - OffsetsRange(10, 20), - [OffsetsRange(1, 20)], - ), - ( - [OffsetsRange(10, 20)], - OffsetsRange(1, 10), - [OffsetsRange(1, 20)], + session_comitted_offset - 5, + True, + [], ), ( - [OffsetsRange(1, 2), OffsetsRange(3, 4)], - OffsetsRange(2, 3), - [OffsetsRange(1, 2), OffsetsRange(2, 4)], + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 0, + True, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + ], ), - ( - [OffsetsRange(1, 10)], - OffsetsRange(5, 6), - ValueError, - ), - ], - ) - def test_add_to_commits( - self, - session, - original: List[OffsetsRange], - add: OffsetsRange, - result: Union[List[OffsetsRange], Type[Exception]], - ): - session._pending_commits = copy.deepcopy(original) - if isinstance(result, type) and issubclass(result, Exception): - with pytest.raises(result): - session._add_to_commits(add) - else: - session._add_to_commits(add) - assert session._pending_commits == result - - # noinspection PyTypeChecker - @pytest.mark.parametrize( - "original,add,result", - [ ( [], - 5, - [PartitionSession.CommitAckWaiter(5, None)], - ), - ( - [PartitionSession.CommitAckWaiter(5, None)], - 6, + session_comitted_offset + 5, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(6, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), ], ), ( - [PartitionSession.CommitAckWaiter(5, None)], - 4, + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 6, + False, [ - PartitionSession.CommitAckWaiter(4, None), - PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), ], ), ( - [PartitionSession.CommitAckWaiter(5, None)], - 0, + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 4, + False, [ - PartitionSession.CommitAckWaiter(0, None), - PartitionSession.CommitAckWaiter(5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 4, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), ], ), ( - [PartitionSession.CommitAckWaiter(5, None)], - 100, + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 100, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], - 50, + session_comitted_offset + 50, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(50, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 50, None + ), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(7, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 7, None), ], - 6, + session_comitted_offset + 6, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(6, None), - PartitionSession.CommitAckWaiter(7, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 7, None), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], - 6, + session_comitted_offset + 6, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(6, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ( [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], - 99, + session_comitted_offset + 99, + False, [ - PartitionSession.CommitAckWaiter(5, None), - PartitionSession.CommitAckWaiter(99, None), - PartitionSession.CommitAckWaiter(100, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 99, None + ), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), ], ), ], @@ -231,17 +201,16 @@ def test_add_waiter( session, original: List[PartitionSession.CommitAckWaiter], add: int, + is_done: bool, result: List[PartitionSession.CommitAckWaiter], ): session._ack_waiters = copy.deepcopy(original) - res = session._add_waiter(add) + res = session.add_waiter(add) assert result == session._ack_waiters - - index = bisect.bisect_left(session._ack_waiters, res) - assert res is session._ack_waiters[index] + assert res.future.done() == is_done def test_close_notify_waiters(self, session): - waiter = session._add_waiter(session.committed_offset + 1) + waiter = session.add_waiter(session.committed_offset + 1) session.close() with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): @@ -250,66 +219,3 @@ def test_close_notify_waiters(self, session): def test_close_twice(self, session): session.close() session.close() - - @pytest.mark.parametrize( - "commits,result,rest", - [ - ([], None, []), - ( - [OffsetsRange(session_comitted_offset + 1, 20)], - None, - [OffsetsRange(session_comitted_offset + 1, 20)], - ), - ( - [OffsetsRange(session_comitted_offset, session_comitted_offset + 1)], - OffsetsRange(session_comitted_offset, session_comitted_offset + 1), - [], - ), - ( - [ - OffsetsRange(session_comitted_offset, session_comitted_offset + 1), - OffsetsRange( - session_comitted_offset + 1, session_comitted_offset + 2 - ), - ], - OffsetsRange(session_comitted_offset, session_comitted_offset + 2), - [], - ), - ( - [ - OffsetsRange(session_comitted_offset, session_comitted_offset + 1), - OffsetsRange( - session_comitted_offset + 1, session_comitted_offset + 2 - ), - OffsetsRange( - session_comitted_offset + 10, session_comitted_offset + 20 - ), - ], - OffsetsRange(session_comitted_offset, session_comitted_offset + 2), - [ - OffsetsRange( - session_comitted_offset + 10, session_comitted_offset + 20 - ) - ], - ), - ], - ) - def test_get_commit_range( - self, - session, - commits: List[OffsetsRange], - result: Optional[OffsetsRange], - rest: List[OffsetsRange], - ): - send_commit_window_start = session._send_commit_window_start - - session._pending_commits = deque(commits) - res = session.pop_commit_range() - assert res == result - assert session._pending_commits == deque(rest) - - if res is None: - assert session._send_commit_window_start == send_commit_window_start - else: - assert session._send_commit_window_start != send_commit_window_start - assert session._send_commit_window_start == res.end diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index cc0839f7..835fc786 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -377,15 +377,15 @@ def commit( "commit messages after server stop the partition read session" ) - waiter = partition_session.add_commit(batch._commit_get_offsets_range()) + commit_range = batch._commit_get_offsets_range() + waiter = partition_session.add_waiter(commit_range.end) - send_range = partition_session.pop_commit_range() - if send_range: + if not waiter.future.done(): client_message = StreamReadMessage.CommitOffsetRequest( commit_offsets=[ StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( partition_session_id=partition_session.id, - offsets=[send_range], + offsets=[commit_range], ) ] ) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index c73be69f..e4609ea0 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1,9 +1,7 @@ import asyncio import datetime import typing -from collections import deque from dataclasses import dataclass -from typing import List, Optional from unittest import mock import pytest @@ -15,8 +13,12 @@ from .topic_reader_asyncio import ReaderStream, ReaderReconnector from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec, OffsetsRange -from .._topic_common import test_helpers -from .._topic_common.test_helpers import StreamMock, wait_condition, wait_for_fast +from .._topic_common.test_helpers import ( + StreamMock, + wait_condition, + wait_for_fast, + WaitConditionError, +) # Workaround for good IDE and universal for runtime if typing.TYPE_CHECKING: @@ -234,124 +236,21 @@ class TestError(Exception): stream_reader_finish_with_error.receive_batch_nowait() @pytest.mark.parametrize( - "pending_ranges,commit,send_range,rest_ranges", + "commit,send_range", [ ( - [], OffsetsRange( partition_session_committed_offset, partition_session_committed_offset + 1, ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - [], - ), - ( - [], - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - None, - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ) - ], - ), - ( - [ - OffsetsRange( - partition_session_committed_offset + 5, - partition_session_committed_offset + 10, - ) - ], - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - None, - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - OffsetsRange( - partition_session_committed_offset + 5, - partition_session_committed_offset + 10, - ), - ], - ), - ( - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ) - ], - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 2, - ), - [], + True, ), ( - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - OffsetsRange( - partition_session_committed_offset + 2, - partition_session_committed_offset + 3, - ), - ], OffsetsRange( + partition_session_committed_offset - 1, partition_session_committed_offset, - partition_session_committed_offset + 1, ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 3, - ), - [], - ), - ( - [ - OffsetsRange( - partition_session_committed_offset + 1, - partition_session_committed_offset + 2, - ), - OffsetsRange( - partition_session_committed_offset + 2, - partition_session_committed_offset + 3, - ), - OffsetsRange( - partition_session_committed_offset + 4, - partition_session_committed_offset + 5, - ), - ], - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 1, - ), - OffsetsRange( - partition_session_committed_offset, - partition_session_committed_offset + 3, - ), - [ - OffsetsRange( - partition_session_committed_offset + 4, - partition_session_committed_offset + 5, - ) - ], + False, ), ], ) @@ -360,10 +259,8 @@ async def test_send_commit_messages( stream, stream_reader: ReaderStream, partition_session, - pending_ranges: List[OffsetsRange], commit: OffsetsRange, - send_range: Optional[OffsetsRange], - rest_ranges: List[OffsetsRange], + send_range: bool, ): @dataclass class Commitable(datatypes.ICommittable): @@ -376,9 +273,9 @@ def _commit_get_partition_session(self) -> datatypes.PartitionSession: def _commit_get_offsets_range(self) -> OffsetsRange: return OffsetsRange(self.start, self.end) - partition_session._pending_commits = deque(pending_ranges) + start_ack_waiters = partition_session._ack_waiters.copy() - stream_reader.commit(Commitable(commit.start, commit.end)) + waiter = stream_reader.commit(Commitable(commit.start, commit.end)) async def wait_message(): return await wait_for_fast(stream.from_client.get(), timeout=0) @@ -389,24 +286,27 @@ async def wait_message(): commit_offsets=[ StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( partition_session_id=partition_session.id, - offsets=[send_range], + offsets=[commit], ) ] ) + assert partition_session._ack_waiters[-1].end_offset == commit.end else: - with pytest.raises(test_helpers.WaitConditionException): - await wait_message() + assert waiter.future.done() - assert partition_session._pending_commits == deque(rest_ranges) + with pytest.raises(WaitConditionError): + msg = await wait_message() + pass + assert start_ack_waiters == partition_session._ack_waiters async def test_commit_ack_received( self, stream_reader, stream, partition_session, second_partition_session ): offset1 = self.partition_session_committed_offset + 1 - waiter1 = partition_session._add_waiter(offset1) + waiter1 = partition_session.add_waiter(offset1) offset2 = self.second_partition_session_offset + 2 - waiter2 = second_partition_session._add_waiter(offset2) + waiter2 = second_partition_session.add_waiter(offset2) stream.from_server.put_nowait( StreamReadMessage.FromServer( @@ -432,7 +332,7 @@ async def test_commit_ack_received( async def test_close_ack_waiters_when_close_stream_reader( self, stream_reader_started: ReaderStream, partition_session ): - waiter = partition_session._add_waiter( + waiter = partition_session.add_waiter( self.partition_session_committed_offset + 1 ) await wait_for_fast(stream_reader_started.close()) From d13d85d318099bc781302e06624d0de4073d969c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 1 Mar 2023 18:51:16 +0300 Subject: [PATCH 3/8] fix style --- .github/workflows/style.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index c280042b..8723d8f2 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -2,7 +2,6 @@ name: Style checks on: push: - - main pull_request: jobs: From 9e0e44d8d59114f91a36c12e7617bad747fea034 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 17:35:01 +0300 Subject: [PATCH 4/8] fix typos --- tests/topics/test_topic_reader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 734b64c7..a874c743 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -15,11 +15,11 @@ async def test_read_and_commit_message( self, driver, topic_path, topic_with_messages, topic_consumer ): - reader = driver.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_consumer, topic_path) batch = await reader.receive_batch() await reader.commit_with_ack(batch) - reader = driver.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver.topic_client.reader(topic_consumer, topic_path) batch2 = await reader.receive_batch() assert batch.messages[0] != batch2.messages[0] @@ -36,10 +36,10 @@ def test_read_message( def test_read_and_commit_message( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): - reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_consumer, topic_path) batch = reader.receive_batch() reader.commit_with_ack(batch) - reader = driver_sync.topic_client.topic_reader(topic_consumer, topic_path) + reader = driver_sync.topic_client.reader(topic_consumer, topic_path) batch2 = reader.receive_batch() assert batch.messages[0] != batch2.messages[0] From 9b746fc4b1a08ea51e51103383216d5ed1ea1504 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 18:21:34 +0300 Subject: [PATCH 5/8] typo while check reconnector_id and style --- ydb/_topic_common/test_helpers.py | 4 +++- ydb/_topic_reader/topic_reader_asyncio.py | 12 +++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index d70cd9f1..96a812ab 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -54,9 +54,11 @@ async def wait_condition( if timeout is None: timeout = 1 + minimal_loop_count_for_wait = 1000 + start = time.monotonic() counter = 0 - while (time.monotonic() - start < timeout) or counter < 1000: + while (time.monotonic() - start < timeout) or counter < minimal_loop_count_for_wait: counter += 1 if f(): return diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 835fc786..303f4c91 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -361,10 +361,7 @@ def commit( ) -> datatypes.PartitionSession.CommitAckWaiter: partition_session = batch._commit_get_partition_session() - if ( - partition_session.reader_reconnector_id - != partition_session.reader_reconnector_id - ): + if partition_session.reader_reconnector_id != self._reader_reconnector_id: raise TopicReaderError("reader can commit only self-produced messages") if partition_session.reader_stream_id != self._id: @@ -498,11 +495,8 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse): def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): for partition_offset in message.partitions_committed_offsets: - try: - session = self._partition_sessions[ - partition_offset.partition_session_id - ] - except KeyError: + session = self._partition_sessions.get(partition_offset.partition_session_id) + if session is None: continue session.ack_notify(partition_offset.committed_offset) From 37e117c2acec5e5b2cb60eb696f24c297042897e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 2 Mar 2023 19:10:47 +0300 Subject: [PATCH 6/8] style --- ydb/_topic_reader/topic_reader_asyncio.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 303f4c91..44125e54 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -495,7 +495,9 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse): def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): for partition_offset in message.partitions_committed_offsets: - session = self._partition_sessions.get(partition_offset.partition_session_id) + session = self._partition_sessions.get( + partition_offset.partition_session_id + ) if session is None: continue session.ack_notify(partition_offset.committed_offset) From 41b33569991232dfe8f29459fa8c2d32575f94c4 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Mar 2023 16:00:27 +0300 Subject: [PATCH 7/8] style --- ydb/_topic_reader/topic_reader_asyncio.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 44125e54..fa940136 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -470,8 +470,9 @@ def _on_start_partition_session( def _on_partition_session_stop( self, message: StreamReadMessage.StopPartitionSessionRequest ): - partition = self._partition_sessions.get(message.partition_session_id) - if partition is None: + try: + partition = self._partition_sessions.get(message.partition_session_id) + except KeyError: # may if receive stop partition with graceful=false after response on stop partition # with graceful=true and remove partition from internal dictionary return From 2334f3591095a407a37307213e2d4905e3cc530e Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Fri, 3 Mar 2023 16:17:08 +0300 Subject: [PATCH 8/8] typo --- ydb/_topic_reader/topic_reader_asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index fa940136..ab0981f6 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -471,7 +471,7 @@ def _on_partition_session_stop( self, message: StreamReadMessage.StopPartitionSessionRequest ): try: - partition = self._partition_sessions.get(message.partition_session_id) + partition = self._partition_sessions[message.partition_session_id] except KeyError: # may if receive stop partition with graceful=false after response on stop partition # with graceful=true and remove partition from internal dictionary