Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,26 @@ async def topic2_path(driver, topic_consumer, database) -> str:
return topic_path


@pytest.fixture()
@pytest.mark.asyncio()
async def topic_with_two_partitions_path(driver, topic_consumer, database) -> str:
topic_path = database + "/test-topic-two-partitions"

try:
await driver.topic_client.drop_topic(topic_path)
except issues.SchemeError:
pass

await driver.topic_client.create_topic(
path=topic_path,
consumers=[topic_consumer],
min_active_partitions=2,
partition_count_limit=2,
)

return topic_path


@pytest.fixture()
@pytest.mark.asyncio()
async def topic_with_messages(driver, topic_consumer, database):
Expand Down
44 changes: 44 additions & 0 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest

import ydb
Expand Down Expand Up @@ -161,3 +163,45 @@ def decode(b: bytes):
with driver_sync.topic_client.reader(topic_path, topic_consumer, decoders={codec: decode}) as reader:
batch = reader.receive_batch()
assert batch.messages[0].data.decode() == "123"


@pytest.mark.asyncio
class TestBugFixesAsync:
async def test_issue_297_bad_handle_stop_partition(
self, driver, topic_consumer, topic_with_two_partitions_path: str
):
async def wait(fut):
return await asyncio.wait_for(fut, timeout=10)

topic = topic_with_two_partitions_path # type: str

async with driver.topic_client.writer(topic, partition_id=0) as writer:
await writer.write_with_ack("00")

async with driver.topic_client.writer(topic, partition_id=1) as writer:
await writer.write_with_ack("01")

# Start first reader and receive messages from both partitions
reader0 = driver.topic_client.reader(topic, consumer=topic_consumer)
await wait(reader0.receive_message())
await wait(reader0.receive_message())

# Start second reader for same topic, same consumer, partition 1
reader1 = driver.topic_client.reader(topic, consumer=topic_consumer)

# receive uncommited message
await reader1.receive_message()

# write one message for every partition
async with driver.topic_client.writer(topic, partition_id=0) as writer:
await writer.write_with_ack("10")
async with driver.topic_client.writer(topic, partition_id=1) as writer:
await writer.write_with_ack("11")

msg0 = await wait(reader0.receive_message())
msg1 = await wait(reader1.receive_message())

datas = [msg0.data.decode(), msg1.data.decode()]
datas.sort()

assert datas == ["10", "11"]
69 changes: 61 additions & 8 deletions ydb/_grpc/grpcwrapper/ydb_topic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import enum
import typing
Expand All @@ -8,6 +10,7 @@

from . import ydb_topic_public_types
from ... import scheme
from ... import issues

# Workaround for good IDE and universal for runtime
if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -588,16 +591,32 @@ def from_proto(
)

@dataclass
class PartitionSessionStatusRequest:
class PartitionSessionStatusRequest(IToProto):
partition_session_id: int

def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest:
return ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest(
partition_session_id=self.partition_session_id
)

@dataclass
class PartitionSessionStatusResponse:
class PartitionSessionStatusResponse(IFromProto):
partition_session_id: int
partition_offsets: "OffsetsRange"
committed_offset: int
write_time_high_watermark: float

@staticmethod
def from_proto(
msg: ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusResponse,
) -> "StreamReadMessage.PartitionSessionStatusResponse":
return StreamReadMessage.PartitionSessionStatusResponse(
partition_session_id=msg.partition_session_id,
partition_offsets=OffsetsRange.from_proto(msg.partition_offsets),
committed_offset=msg.committed_offset,
write_time_high_watermark=msg.write_time_high_watermark,
)

@dataclass
class StartPartitionSessionRequest(IFromProto):
partition_session: "StreamReadMessage.PartitionSession"
Expand Down Expand Up @@ -632,15 +651,30 @@ def to_proto(
return res

@dataclass
class StopPartitionSessionRequest:
class StopPartitionSessionRequest(IFromProto):
partition_session_id: int
graceful: bool
committed_offset: int

@staticmethod
def from_proto(
msg: ydb_topic_pb2.StreamReadMessage.StopPartitionSessionRequest,
) -> StreamReadMessage.StopPartitionSessionRequest:
return StreamReadMessage.StopPartitionSessionRequest(
partition_session_id=msg.partition_session_id,
graceful=msg.graceful,
committed_offset=msg.committed_offset,
)

@dataclass
class StopPartitionSessionResponse:
class StopPartitionSessionResponse(IToProto):
partition_session_id: int

def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse:
return ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse(
partition_session_id=self.partition_session_id,
)

@dataclass
class FromClient(IToProto):
client_message: "ReaderMessagesFromClientToServer"
Expand All @@ -660,6 +694,10 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
res.update_token_request.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.StartPartitionSessionResponse):
res.start_partition_session_response.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.StopPartitionSessionResponse):
res.stop_partition_session_response.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.PartitionSessionStatusRequest):
res.start_partition_session_response.CopyFrom(self.client_message.to_proto())
else:
raise NotImplementedError("Unknown message type: %s" % type(self.client_message))
return res
Expand Down Expand Up @@ -694,17 +732,32 @@ def from_proto(
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto(
msg.start_partition_session_request
msg.start_partition_session_request,
),
)
elif mess_type == "stop_partition_session_request":
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=StreamReadMessage.StopPartitionSessionRequest.from_proto(
msg.stop_partition_session_request
),
)
elif mess_type == "update_token_response":
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=UpdateTokenResponse.from_proto(msg.update_token_response),
)

# todo replace exception to log
raise NotImplementedError()
elif mess_type == "partition_session_status_response":
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=StreamReadMessage.PartitionSessionStatusResponse.from_proto(
msg.partition_session_status_response
),
)
else:
raise issues.UnexpectedGrpcMessage(
"Unexpected message while parse ReaderMessagesFromServerToClient: '%s'" % mess_type
)


ReaderMessagesFromClientToServer = Union[
Expand Down
53 changes: 29 additions & 24 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
Codec,
)
from .._errors import check_retriable_error
import logging

logger = logging.getLogger(__name__)


class TopicReaderError(YdbError):
Expand Down Expand Up @@ -146,7 +149,6 @@ class ReaderReconnector:

def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
self._id = self._static_reader_reconnector_counter.inc_and_get()

self._settings = settings
self._driver = driver
self._background_tasks = set()
Expand Down Expand Up @@ -395,39 +397,42 @@ async def _read_messages_loop(self):
)
)
while True:
message = await self._stream.receive() # type: StreamReadMessage.FromServer
_process_response(message.server_status)
try:
message = await self._stream.receive() # type: StreamReadMessage.FromServer
_process_response(message.server_status)

if isinstance(message.server_message, StreamReadMessage.ReadResponse):
self._on_read_response(message.server_message)
if isinstance(message.server_message, StreamReadMessage.ReadResponse):
self._on_read_response(message.server_message)

elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse):
self._on_commit_response(message.server_message)
elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse):
self._on_commit_response(message.server_message)

elif isinstance(
message.server_message,
StreamReadMessage.StartPartitionSessionRequest,
):
self._on_start_partition_session(message.server_message)
elif isinstance(
message.server_message,
StreamReadMessage.StartPartitionSessionRequest,
):
self._on_start_partition_session(message.server_message)

elif isinstance(
message.server_message,
StreamReadMessage.StopPartitionSessionRequest,
):
self._on_partition_session_stop(message.server_message)
elif isinstance(
message.server_message,
StreamReadMessage.StopPartitionSessionRequest,
):
self._on_partition_session_stop(message.server_message)

elif isinstance(message.server_message, UpdateTokenResponse):
self._update_token_event.set()
elif isinstance(message.server_message, UpdateTokenResponse):
self._update_token_event.set()

else:
raise NotImplementedError(
"Unexpected type of StreamReadMessage.FromServer message: %s" % message.server_message
)
else:
raise issues.UnexpectedGrpcMessage(
"Unexpected message in _read_messages_loop: %s" % type(message.server_message)
)
except issues.UnexpectedGrpcMessage as e:
logger.exception("unexpected message in stream reader: %s" % e)

self._state_changed.set()
except Exception as e:
self._set_first_error(e)
raise
return

async def _update_token_loop(self):
while True:
Expand Down
23 changes: 23 additions & 0 deletions ydb/_topic_reader/topic_reader_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,29 @@ async def test_update_token(self, stream):

await reader.close()

async def test_read_unknown_message(self, stream, stream_reader, caplog):
class TestMessage:
pass

# noinspection PyTypeChecker
stream.from_server.put_nowait(
StreamReadMessage.FromServer(
server_status=ServerStatus(
status=issues.StatusCode.SUCCESS,
issues=[],
),
server_message=TestMessage(),
)
)

def logged():
for rec in caplog.records:
if TestMessage.__name__ in rec.message:
return True
return False

await wait_condition(logged)


@pytest.mark.asyncio
class TestReaderReconnector:
Expand Down
5 changes: 5 additions & 0 deletions ydb/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ class SessionPoolEmpty(Error, queue.Empty):
status = StatusCode.SESSION_POOL_EMPTY


class UnexpectedGrpcMessage(Error):
def __init__(self, message: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useless __init__

super().__init__(message)


def _format_issues(issues):
if not issues:
return ""
Expand Down
4 changes: 2 additions & 2 deletions ydb/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def reader(
if not decoder_executor:
decoder_executor = self._executor

args = locals()
args = locals().copy()
del args["self"]

settings = TopicReaderSettings(**args)
Expand All @@ -188,7 +188,7 @@ def writer(
encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None,
encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool
) -> TopicWriterAsyncIO:
args = locals()
args = locals().copy()
del args["self"]

settings = TopicWriterSettings(**args)
Expand Down