diff --git a/tests/conftest.py b/tests/conftest.py index c2ff8f39..b1de747d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -176,6 +176,26 @@ async def topic2_path(driver, topic_consumer, database) -> str: return topic_path +@pytest.fixture() +@pytest.mark.asyncio() +async def topic_with_two_partitions_path(driver, topic_consumer, database) -> str: + topic_path = database + "/test-topic-two-partitions" + + try: + await driver.topic_client.drop_topic(topic_path) + except issues.SchemeError: + pass + + await driver.topic_client.create_topic( + path=topic_path, + consumers=[topic_consumer], + min_active_partitions=2, + partition_count_limit=2, + ) + + return topic_path + + @pytest.fixture() @pytest.mark.asyncio() async def topic_with_messages(driver, topic_consumer, database): diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 3f85662b..46bc0faa 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -1,3 +1,5 @@ +import asyncio + import pytest import ydb @@ -161,3 +163,45 @@ def decode(b: bytes): with driver_sync.topic_client.reader(topic_path, topic_consumer, decoders={codec: decode}) as reader: batch = reader.receive_batch() assert batch.messages[0].data.decode() == "123" + + +@pytest.mark.asyncio +class TestBugFixesAsync: + async def test_issue_297_bad_handle_stop_partition( + self, driver, topic_consumer, topic_with_two_partitions_path: str + ): + async def wait(fut): + return await asyncio.wait_for(fut, timeout=10) + + topic = topic_with_two_partitions_path # type: str + + async with driver.topic_client.writer(topic, partition_id=0) as writer: + await writer.write_with_ack("00") + + async with driver.topic_client.writer(topic, partition_id=1) as writer: + await writer.write_with_ack("01") + + # Start first reader and receive messages from both partitions + reader0 = driver.topic_client.reader(topic, consumer=topic_consumer) + await wait(reader0.receive_message()) + await wait(reader0.receive_message()) + + # Start second reader for same topic, same consumer, partition 1 + reader1 = driver.topic_client.reader(topic, consumer=topic_consumer) + + # receive uncommited message + await reader1.receive_message() + + # write one message for every partition + async with driver.topic_client.writer(topic, partition_id=0) as writer: + await writer.write_with_ack("10") + async with driver.topic_client.writer(topic, partition_id=1) as writer: + await writer.write_with_ack("11") + + msg0 = await wait(reader0.receive_message()) + msg1 = await wait(reader1.receive_message()) + + datas = [msg0.data.decode(), msg1.data.decode()] + datas.sort() + + assert datas == ["10", "11"] diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index f20b80a9..5b5e294a 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import enum import typing @@ -8,6 +10,7 @@ from . import ydb_topic_public_types from ... import scheme +from ... import issues # Workaround for good IDE and universal for runtime if typing.TYPE_CHECKING: @@ -588,16 +591,32 @@ def from_proto( ) @dataclass - class PartitionSessionStatusRequest: + class PartitionSessionStatusRequest(IToProto): partition_session_id: int + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest: + return ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest( + partition_session_id=self.partition_session_id + ) + @dataclass - class PartitionSessionStatusResponse: + class PartitionSessionStatusResponse(IFromProto): partition_session_id: int partition_offsets: "OffsetsRange" committed_offset: int write_time_high_watermark: float + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusResponse, + ) -> "StreamReadMessage.PartitionSessionStatusResponse": + return StreamReadMessage.PartitionSessionStatusResponse( + partition_session_id=msg.partition_session_id, + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), + committed_offset=msg.committed_offset, + write_time_high_watermark=msg.write_time_high_watermark, + ) + @dataclass class StartPartitionSessionRequest(IFromProto): partition_session: "StreamReadMessage.PartitionSession" @@ -632,15 +651,30 @@ def to_proto( return res @dataclass - class StopPartitionSessionRequest: + class StopPartitionSessionRequest(IFromProto): partition_session_id: int graceful: bool committed_offset: int + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.StopPartitionSessionRequest, + ) -> StreamReadMessage.StopPartitionSessionRequest: + return StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=msg.partition_session_id, + graceful=msg.graceful, + committed_offset=msg.committed_offset, + ) + @dataclass - class StopPartitionSessionResponse: + class StopPartitionSessionResponse(IToProto): partition_session_id: int + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse: + return ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse( + partition_session_id=self.partition_session_id, + ) + @dataclass class FromClient(IToProto): client_message: "ReaderMessagesFromClientToServer" @@ -660,6 +694,10 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: res.update_token_request.CopyFrom(self.client_message.to_proto()) elif isinstance(self.client_message, StreamReadMessage.StartPartitionSessionResponse): res.start_partition_session_response.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.StopPartitionSessionResponse): + res.stop_partition_session_response.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.PartitionSessionStatusRequest): + res.start_partition_session_response.CopyFrom(self.client_message.to_proto()) else: raise NotImplementedError("Unknown message type: %s" % type(self.client_message)) return res @@ -694,7 +732,14 @@ def from_proto( return StreamReadMessage.FromServer( server_status=server_status, server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto( - msg.start_partition_session_request + msg.start_partition_session_request, + ), + ) + elif mess_type == "stop_partition_session_request": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.StopPartitionSessionRequest.from_proto( + msg.stop_partition_session_request ), ) elif mess_type == "update_token_response": @@ -702,9 +747,17 @@ def from_proto( server_status=server_status, server_message=UpdateTokenResponse.from_proto(msg.update_token_response), ) - - # todo replace exception to log - raise NotImplementedError() + elif mess_type == "partition_session_status_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.PartitionSessionStatusResponse.from_proto( + msg.partition_session_status_response + ), + ) + else: + raise issues.UnexpectedGrpcMessage( + "Unexpected message while parse ReaderMessagesFromServerToClient: '%s'" % mess_type + ) ReaderMessagesFromClientToServer = Union[ diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 539d6831..ebe7bd6b 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -26,6 +26,9 @@ Codec, ) from .._errors import check_retriable_error +import logging + +logger = logging.getLogger(__name__) class TopicReaderError(YdbError): @@ -146,7 +149,6 @@ class ReaderReconnector: def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._id = self._static_reader_reconnector_counter.inc_and_get() - self._settings = settings self._driver = driver self._background_tasks = set() @@ -395,39 +397,42 @@ async def _read_messages_loop(self): ) ) while True: - message = await self._stream.receive() # type: StreamReadMessage.FromServer - _process_response(message.server_status) + try: + message = await self._stream.receive() # type: StreamReadMessage.FromServer + _process_response(message.server_status) - if isinstance(message.server_message, StreamReadMessage.ReadResponse): - self._on_read_response(message.server_message) + if isinstance(message.server_message, StreamReadMessage.ReadResponse): + self._on_read_response(message.server_message) - elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse): - self._on_commit_response(message.server_message) + elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse): + self._on_commit_response(message.server_message) - elif isinstance( - message.server_message, - StreamReadMessage.StartPartitionSessionRequest, - ): - self._on_start_partition_session(message.server_message) + elif isinstance( + message.server_message, + StreamReadMessage.StartPartitionSessionRequest, + ): + self._on_start_partition_session(message.server_message) - elif isinstance( - message.server_message, - StreamReadMessage.StopPartitionSessionRequest, - ): - self._on_partition_session_stop(message.server_message) + elif isinstance( + message.server_message, + StreamReadMessage.StopPartitionSessionRequest, + ): + self._on_partition_session_stop(message.server_message) - elif isinstance(message.server_message, UpdateTokenResponse): - self._update_token_event.set() + elif isinstance(message.server_message, UpdateTokenResponse): + self._update_token_event.set() - else: - raise NotImplementedError( - "Unexpected type of StreamReadMessage.FromServer message: %s" % message.server_message - ) + else: + raise issues.UnexpectedGrpcMessage( + "Unexpected message in _read_messages_loop: %s" % type(message.server_message) + ) + except issues.UnexpectedGrpcMessage as e: + logger.exception("unexpected message in stream reader: %s" % e) self._state_changed.set() except Exception as e: self._set_first_error(e) - raise + return async def _update_token_loop(self): while True: diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index a8e59dfc..c1019b02 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -1127,6 +1127,29 @@ async def test_update_token(self, stream): await reader.close() + async def test_read_unknown_message(self, stream, stream_reader, caplog): + class TestMessage: + pass + + # noinspection PyTypeChecker + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus( + status=issues.StatusCode.SUCCESS, + issues=[], + ), + server_message=TestMessage(), + ) + ) + + def logged(): + for rec in caplog.records: + if TestMessage.__name__ in rec.message: + return True + return False + + await wait_condition(logged) + @pytest.mark.asyncio class TestReaderReconnector: diff --git a/ydb/issues.py b/ydb/issues.py index f15c475c..a489d4e0 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -156,6 +156,11 @@ class SessionPoolEmpty(Error, queue.Empty): status = StatusCode.SESSION_POOL_EMPTY +class UnexpectedGrpcMessage(Error): + def __init__(self, message: str): + super().__init__(message) + + def _format_issues(issues): if not issues: return "" diff --git a/ydb/topic.py b/ydb/topic.py index abf93903..190f5329 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -168,7 +168,7 @@ def reader( if not decoder_executor: decoder_executor = self._executor - args = locals() + args = locals().copy() del args["self"] settings = TopicReaderSettings(**args) @@ -188,7 +188,7 @@ def writer( encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None, encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool ) -> TopicWriterAsyncIO: - args = locals() + args = locals().copy() del args["self"] settings = TopicWriterSettings(**args)