Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,14 +348,17 @@ def receive_batch_nowait(self):
return batch

def receive_message_nowait(self):
if self._get_first_error():
raise self._get_first_error()

try:
batch = self._message_batches[0]
message = batch.pop_message()
except IndexError:
return None

if batch.empty():
self._message_batches.popleft()
self.receive_batch_nowait()

return message

Expand Down
48 changes: 42 additions & 6 deletions ydb/_topic_reader/topic_reader_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ def create_message(
)

async def send_message(self, stream_reader, message: PublicMessage):
await self.send_batch(stream_reader, [message])

async def send_batch(self, stream_reader, batch: typing.List[PublicMessage]):
if len(batch) == 0:
return

first_message = batch[0]
for message in batch:
assert message._partition_session is first_message._partition_session

def batch_count():
return len(stream_reader._message_batches)

Expand All @@ -225,7 +235,7 @@ def batch_count():
server_message=StreamReadMessage.ReadResponse(
partition_data=[
StreamReadMessage.ReadResponse.PartitionData(
partition_session_id=message._partition_session.id,
partition_session_id=first_message._partition_session.id,
batches=[
StreamReadMessage.ReadResponse.Batch(
message_data=[
Expand All @@ -237,11 +247,12 @@ def batch_count():
uncompresed_size=len(message.data),
message_group_id=message.message_group_id,
)
for message in batch
],
producer_id=message.producer_id,
write_session_meta=message.session_metadata,
producer_id=first_message.producer_id,
write_session_meta=first_message.session_metadata,
codec=Codec.CODEC_RAW,
written_at=message.written_at,
written_at=first_message.written_at,
)
],
)
Expand Down Expand Up @@ -1066,13 +1077,15 @@ async def test_read_message(
async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
assert stream_reader.receive_batch_nowait() is None

initial_buffer_size = stream_reader._buffer_size_bytes

mess1 = self.create_message(partition_session, 1, 1)
await self.send_message(stream_reader, mess1)

mess2 = self.create_message(partition_session, 2, 1)
await self.send_message(stream_reader, mess2)

initial_buffer_size = stream_reader._buffer_size_bytes
assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size

received = stream_reader.receive_batch_nowait()
assert received == PublicBatch(
Expand All @@ -1090,14 +1103,37 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi
_codec=Codec.CODEC_RAW,
)

assert stream_reader._buffer_size_bytes == initial_buffer_size + 2 * self.default_batch_size
assert stream_reader._buffer_size_bytes == initial_buffer_size

assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message
assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message

with pytest.raises(asyncio.QueueEmpty):
stream.from_client.get_nowait()

async def test_receive_message_nowait(self, stream, stream_reader, partition_session):
assert stream_reader.receive_batch_nowait() is None

initial_buffer_size = stream_reader._buffer_size_bytes

await self.send_batch(
stream_reader, [self.create_message(partition_session, 1, 1), self.create_message(partition_session, 2, 1)]
)
await self.send_batch(
stream_reader,
[
self.create_message(partition_session, 10, 1),
],
)

assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size

for expected_seqno in [1, 2, 10]:
mess = stream_reader.receive_message_nowait()
assert mess.seqno == expected_seqno

assert stream_reader._buffer_size_bytes == initial_buffer_size

async def test_update_token(self, stream):
settings = PublicReaderSettings(
consumer="test-consumer",
Expand Down