diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 2a451baf..84d61a43 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -5,12 +5,26 @@ @pytest.mark.asyncio class TestTopicReaderAsyncIO: + async def test_read_batch( + self, driver, topic_path, topic_with_messages, topic_consumer + ): + reader = driver.topic_client.reader(topic_consumer, topic_path) + batch = await reader.receive_batch() + + assert batch is not None + assert len(batch.messages) > 0 + + await reader.close() + async def test_read_message( self, driver, topic_path, topic_with_messages, topic_consumer ): reader = driver.topic_client.reader(topic_consumer, topic_path) + msg = await reader.receive_message() + + assert msg is not None + assert msg.seqno - assert await reader.receive_batch() is not None await reader.close() async def test_read_and_commit_message( @@ -59,12 +73,26 @@ def decode(b: bytes): class TestTopicReaderSync: + def test_read_batch( + self, driver_sync, topic_path, topic_with_messages, topic_consumer + ): + reader = driver_sync.topic_client.reader(topic_consumer, topic_path) + batch = reader.receive_batch() + + assert batch is not None + assert len(batch.messages) > 0 + + reader.close() + def test_read_message( self, driver_sync, topic_path, topic_with_messages, topic_consumer ): reader = driver_sync.topic_client.reader(topic_consumer, topic_path) + msg = reader.receive_message() + + assert msg is not None + assert msg.seqno - assert reader.receive_batch() is not None reader.close() def test_read_and_commit_message( diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 3845995f..5376c76d 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -179,6 +179,9 @@ def _commit_get_offsets_range(self) -> OffsetsRange: self.messages[-1]._commit_get_offsets_range().end, ) + def empty(self) -> bool: + return len(self.messages) == 0 + # ISessionAlive implementation @property def is_alive(self) -> bool: @@ -187,3 +190,6 @@ def is_alive(self) -> bool: state == PartitionSession.State.Active or state == PartitionSession.State.GracefulShutdown ) + + def pop_message(self) -> PublicMessage: + return self.messages.pop(0) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index bb87d3cc..0068e4ba 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -95,14 +95,6 @@ def messages( """ raise NotImplementedError() - async def receive_message(self) -> typing.Union[topic_reader.PublicMessage, None]: - """ - Block until receive new message - - use asyncio.wait_for for wait with timeout. - """ - raise NotImplementedError() - def batches( self, *, @@ -133,6 +125,15 @@ async def receive_batch( await self._reconnector.wait_message() return self._reconnector.receive_batch_nowait() + async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: + """ + Block until receive new message + + use asyncio.wait_for for wait with timeout. + """ + await self._reconnector.wait_message() + return self._reconnector.receive_message_nowait() + async def commit_on_exit( self, mess: datatypes.ICommittable ) -> typing.AsyncContextManager: @@ -244,6 +245,9 @@ async def wait_message(self): def receive_batch_nowait(self): return self._stream_reader.receive_batch_nowait() + def receive_message_nowait(self): + return self._stream_reader.receive_message_nowait() + def commit( self, batch: datatypes.ICommittable ) -> datatypes.PartitionSession.CommitAckWaiter: @@ -397,12 +401,24 @@ def receive_batch_nowait(self): raise self._get_first_error() if not self._message_batches: - return + return None batch = self._message_batches.popleft() self._buffer_release_bytes(batch._bytes_size) return batch + def receive_message_nowait(self): + try: + batch = self._message_batches[0] + message = batch.pop_message() + except IndexError: + return None + + if batch.empty(): + self._message_batches.popleft() + + return message + def commit( self, batch: datatypes.ICommittable ) -> datatypes.PartitionSession.CommitAckWaiter: diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 2924cb4d..a310298e 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -4,6 +4,7 @@ import datetime import gzip import typing +from collections import deque from dataclasses import dataclass from unittest import mock @@ -53,6 +54,34 @@ def default_executor(): executor.shutdown() +def stub_partition_session(): + return datatypes.PartitionSession( + id=0, + state=datatypes.PartitionSession.State.Active, + topic_path="asd", + partition_id=1, + committed_offset=0, + reader_reconnector_id=415, + reader_stream_id=513, + ) + + +def stub_message(id: int): + return PublicMessage( + seqno=id, + created_at=datetime.datetime(2023, 3, 18, 14, 15), + message_group_id="", + session_metadata={}, + offset=0, + written_at=datetime.datetime(2023, 3, 18, 14, 15), + producer_id="", + data=bytes(), + _partition_session=stub_partition_session(), + _commit_start_offset=0, + _commit_end_offset=1, + ) + + @pytest.fixture() def default_reader_settings(default_executor): return PublicReaderSettings( @@ -179,7 +208,9 @@ async def stream_reader_finish_with_error( @staticmethod def create_message( - partition_session: datatypes.PartitionSession, seqno: int, offset_delta: int + partition_session: typing.Optional[datatypes.PartitionSession], + seqno: int, + offset_delta: int, ): return PublicMessage( seqno=seqno, @@ -963,6 +994,101 @@ async def test_read_batches( _codec=Codec.CODEC_RAW, ) + @pytest.mark.parametrize( + "batches_before,expected_message,batches_after", + [ + ([], None, []), + ( + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ) + ], + stub_message(1), + [], + ), + ( + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(1), stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + stub_message(1), + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + ), + ( + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + session_metadata={}, + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + stub_message(1), + [ + PublicBatch( + session_metadata={}, + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ) + ], + ), + ], + ) + async def test_read_message( + self, + stream_reader, + batches_before: typing.List[datatypes.PublicBatch], + expected_message: PublicMessage, + batches_after: typing.List[datatypes.PublicBatch], + ): + stream_reader._message_batches = deque(batches_before) + mess = stream_reader.receive_message_nowait() + + assert mess == expected_message + assert list(stream_reader._message_batches) == batches_after + async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): assert stream_reader.receive_batch_nowait() is None diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index 30bf92a1..ed9730fa 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -83,19 +83,28 @@ def messages( It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, + get messages from internal buffer only. """ raise NotImplementedError() - def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage: + def receive_message( + self, *, timeout: TimeoutType = None + ) -> datatypes.PublicMessage: """ Block until receive new message It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + receive_message(timeout=0) may return None even right after async_wait_message() is ok - because lost of partition + or connection to server lost if no new message in timeout seconds (default - infinite): raise TimeoutError() - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. """ - raise NotImplementedError() + self._check_closed() + + return self._caller.safe_call_with_result( + self._async_reader.receive_message(), timeout + ) def async_wait_message(self) -> concurrent.futures.Future: """ @@ -105,7 +114,11 @@ def async_wait_message(self) -> concurrent.futures.Future: Possible situation when receive signal about message available, but no messages when try to receive a message. If message expired between send event and try to retrieve message (for example connection broken). """ - raise NotImplementedError() + self._check_closed() + + return self._caller.unsafe_call_with_future( + self._async_reader._reconnector.wait_message() + ) def batches( self, @@ -119,7 +132,7 @@ def batches( It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. """ raise NotImplementedError() @@ -135,7 +148,7 @@ def receive_batch( It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. if no new message in timeout seconds (default - infinite): raise TimeoutError() - if timeout <= 0 - it will fast non block method, get messages from internal buffer only. + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. """ self._check_closed()