Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ async def topic_with_messages(driver, topic_path):
ydb.TopicWriterMessage(data="123".encode()),
ydb.TopicWriterMessage(data="456".encode()),
)
await writer.write_with_ack(
ydb.TopicWriterMessage(data="789".encode()),
ydb.TopicWriterMessage(data="0".encode()),
)
await writer.close()


Expand Down
23 changes: 23 additions & 0 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ async def test_read_message(
assert await reader.receive_batch() is not None
await reader.close()

async def test_read_and_commit_message(
self, driver, topic_path, topic_with_messages, topic_consumer
):

reader = driver.topic_client.reader(topic_consumer, topic_path)
batch = await reader.receive_batch()
await reader.commit_with_ack(batch)

reader = driver.topic_client.reader(topic_consumer, topic_path)
batch2 = await reader.receive_batch()
assert batch.messages[0] != batch2.messages[0]


class TestTopicReaderSync:
def test_read_message(
Expand All @@ -20,3 +32,14 @@ def test_read_message(

assert reader.receive_batch() is not None
reader.close()

def test_read_and_commit_message(
self, driver_sync, topic_path, topic_with_messages, topic_consumer
):
reader = driver_sync.topic_client.reader(topic_consumer, topic_path)
batch = reader.receive_batch()
reader.commit_with_ack(batch)

reader = driver_sync.topic_client.reader(topic_consumer, topic_path)
batch2 = reader.receive_batch()
assert batch.messages[0] != batch2.messages[0]
112 changes: 98 additions & 14 deletions ydb/_grpc/grpcwrapper/ydb_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,23 @@ def to_public(self) -> List[ydb_topic_public_types.PublicCodec]:
return list(map(Codec.to_public, self.codecs))


@dataclass
class OffsetsRange(IFromProto):
start: int
end: int
@dataclass(order=True)
class OffsetsRange(IFromProto, IToProto):
"""
half-opened interval, include [start, end) offsets
"""

__slots__ = ("start", "end")

start: int # first offset
end: int # offset after last, included to range

def __post_init__(self):
if self.end < self.start:
raise ValueError(
"offset end must be not less then start. Got start=%s end=%s"
% (self.start, self.end)
)

@staticmethod
def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
Expand All @@ -79,6 +92,20 @@ def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
end=msg.end,
)

def to_proto(self) -> ydb_topic_pb2.OffsetsRange:
return ydb_topic_pb2.OffsetsRange(
start=self.start,
end=self.end,
)

def is_intersected_with(self, other: "OffsetsRange") -> bool:
return (
self.start <= other.start < self.end
or self.start < other.end <= self.end
or other.start <= self.start < other.end
or other.start < self.end <= other.end
)


@dataclass
class UpdateTokenRequest(IToProto):
Expand Down Expand Up @@ -527,23 +554,67 @@ def from_proto(
)

@dataclass
class CommitOffsetRequest:
class CommitOffsetRequest(IToProto):
commit_offsets: List["PartitionCommitOffset"]

def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest:
res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest(
commit_offsets=list(
map(
StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset.to_proto,
self.commit_offsets,
)
),
)
return res

@dataclass
class PartitionCommitOffset:
class PartitionCommitOffset(IToProto):
partition_session_id: int
offsets: List["OffsetsRange"]

def to_proto(
self,
) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset:
res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset(
partition_session_id=self.partition_session_id,
offsets=list(map(OffsetsRange.to_proto, self.offsets)),
)
return res

@dataclass
class CommitOffsetResponse:
partitions_committed_offsets: List["PartitionCommittedOffset"]
class CommitOffsetResponse(IFromProto):
partitions_committed_offsets: List[
"StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset"
]

@staticmethod
def from_proto(
msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse,
) -> "StreamReadMessage.CommitOffsetResponse":
return StreamReadMessage.CommitOffsetResponse(
partitions_committed_offsets=list(
map(
StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset.from_proto,
msg.partitions_committed_offsets,
)
)
)

@dataclass
class PartitionCommittedOffset:
class PartitionCommittedOffset(IFromProto):
partition_session_id: int
committed_offset: int

@staticmethod
def from_proto(
msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset,
) -> "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset":
return StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset(
partition_session_id=msg.partition_session_id,
committed_offset=msg.committed_offset,
)

@dataclass
class PartitionSessionStatusRequest:
partition_session_id: int
Expand Down Expand Up @@ -576,16 +647,18 @@ def from_proto(
@dataclass
class StartPartitionSessionResponse(IToProto):
partition_session_id: int
read_offset: int
commit_offset: int
read_offset: Optional[int]
commit_offset: Optional[int]

def to_proto(
self,
) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse:
res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse()
res.partition_session_id = self.partition_session_id
res.read_offset = self.read_offset
res.commit_offset = self.commit_offset
if self.read_offset is not None:
res.read_offset = self.read_offset
if self.commit_offset is not None:
res.commit_offset = self.commit_offset
return res

@dataclass
Expand All @@ -609,6 +682,8 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
res = ydb_topic_pb2.StreamReadMessage.FromClient()
if isinstance(self.client_message, StreamReadMessage.ReadRequest):
res.read_request.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.CommitOffsetRequest):
res.commit_offset_request.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.InitRequest):
res.init_request.CopyFrom(self.client_message.to_proto())
elif isinstance(
Expand All @@ -618,7 +693,9 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
self.client_message.to_proto()
)
else:
raise NotImplementedError()
raise NotImplementedError(
"Unknown message type: %s" % type(self.client_message)
)
return res

@dataclass
Expand All @@ -639,6 +716,13 @@ def from_proto(
msg.read_response
),
)
elif mess_type == "commit_offset_response":
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=StreamReadMessage.CommitOffsetResponse.from_proto(
msg.commit_offset_response
),
)
elif mess_type == "init_response":
return StreamReadMessage.FromServer(
server_status=server_status,
Expand Down
27 changes: 27 additions & 0 deletions ydb/_grpc/grpcwrapper/ydb_topic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange


def test_offsets_range_intersected():
# not intersected
for test in [(0, 1, 1, 2), (1, 2, 3, 5)]:
assert not OffsetsRange(test[0], test[1]).is_intersected_with(
OffsetsRange(test[2], test[3])
)
assert not OffsetsRange(test[2], test[3]).is_intersected_with(
OffsetsRange(test[0], test[1])
)

# intersected
for test in [
(1, 2, 1, 2),
(1, 10, 1, 2),
(1, 10, 2, 3),
(1, 10, 5, 15),
(10, 20, 5, 15),
]:
assert OffsetsRange(test[0], test[1]).is_intersected_with(
OffsetsRange(test[2], test[3])
)
assert OffsetsRange(test[2], test[3]).is_intersected_with(
OffsetsRange(test[0], test[1])
)
31 changes: 26 additions & 5 deletions ydb/_topic_common/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,38 @@ def close(self):
self.from_server.put_nowait(None)


async def wait_condition(f: typing.Callable[[], bool], timeout=1):
class WaitConditionError(Exception):
pass


async def wait_condition(
f: typing.Callable[[], bool],
timeout: typing.Optional[typing.Union[float, int]] = None,
):
"""
timeout default is 1 second
if timeout is 0 - only counter work. It userful if test need fast timeout for condition (without wait full timeout)
"""
if timeout is None:
timeout = 1

minimal_loop_count_for_wait = 1000

start = time.monotonic()
counter = 0
while (time.monotonic() - start < timeout) or counter < 1000:
while (time.monotonic() - start < timeout) or counter < minimal_loop_count_for_wait:
counter += 1
if f():
return
await asyncio.sleep(0)

raise Exception("Bad condition in test")
raise WaitConditionError("Bad condition in test")


async def wait_for_fast(fut):
return await asyncio.wait_for(fut, 1)
async def wait_for_fast(
awaitable: typing.Awaitable,
timeout: typing.Optional[typing.Union[float, int]] = None,
):
fut = asyncio.ensure_future(awaitable)
await wait_condition(lambda: fut.done(), timeout)
return fut.result()
Loading