diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index 96a812ab..084146f8 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -68,7 +68,7 @@ async def wait_condition( async def wait_for_fast( - awaitable: typing.Awaitable, + awaitable: typing.Union[typing.Awaitable, typing.Coroutine], timeout: typing.Optional[typing.Union[float, int]] = None, ): fut = asyncio.ensure_future(awaitable) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 02667d37..860676d0 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -732,6 +732,75 @@ def session_count(): with pytest.raises(asyncio.QueueEmpty): stream.from_client.get_nowait() + @pytest.mark.parametrize( + "graceful", + ( + [True], + [False], + ), + ) + async def test_free_buffer_after_partition_stop(self, stream, stream_reader, partition_session, graceful): + initial_buffer_size = stream_reader._buffer_size_bytes + message_size = initial_buffer_size - 1 + + t = datetime.datetime.now() + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(issues.StatusCode.SUCCESS, []), + server_message=StreamReadMessage.ReadResponse( + bytes_size=message_size, + partition_data=[ + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + partition_session.committed_offset + 1, + seq_no=123, + created_at=t, + data=bytes(), + uncompresed_size=message_size, + message_group_id="test-message-group", + ) + ], + producer_id="asd", + write_session_meta={}, + codec=Codec.CODEC_RAW, + written_at=t, + ) + ], + ) + ], + ), + ) + ) + + def message_received(): + return len(stream_reader._message_batches) > 0 + + await wait_condition(message_received) + + assert stream_reader._buffer_size_bytes == initial_buffer_size - message_size + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(issues.StatusCode.SUCCESS, []), + server_message=StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=partition_session.id, + graceful=graceful, + committed_offset=partition_session.committed_offset, + ), + ) + ) + + await wait_condition(lambda: partition_session.closed) + + batch = stream_reader.receive_batch_nowait() + assert not batch.alive + assert stream_reader._buffer_size_bytes == initial_buffer_size + async def test_receive_message_from_server( self, stream_reader,