diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index c0ef2491..1b175e8e 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -65,13 +65,13 @@ def __del__(self): self._loop.call_soon(self.close) - async def close(self): + async def close(self, *, flush: bool = True): if self._closed: return self._closed = True - await self._reconnector.close() + await self._reconnector.close(flush) async def write_with_ack( self, @@ -109,13 +109,13 @@ async def write_with_ack_future( For wait with timeout use asyncio.wait_for. """ if isinstance(messages, PublicMessage): - futures = await self._reconnector.write_with_ack([messages]) + futures = await self._reconnector.write_with_ack_future([messages]) return futures[0] if isinstance(messages, list): for m in messages: if not isinstance(m, PublicMessage): raise NotImplementedError() - return await self._reconnector.write_with_ack(messages) + return await self._reconnector.write_with_ack_future(messages) raise NotImplementedError() async def write( @@ -185,7 +185,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): asyncio.create_task(self._connection_loop(), name="connection_loop") ] - async def close(self, flush: bool = True): + async def close(self, flush: bool): if self._closed: return @@ -223,7 +223,7 @@ async def wait_init(self) -> PublicWriterInitInfo: async def wait_stop(self) -> Exception: return await self._stop_reason - async def write_with_ack( + async def write_with_ack_future( self, messages: List[PublicMessage] ) -> List[asyncio.Future]: # todo check internal buffer limit diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 1c96097f..7f19a4dd 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -295,7 +295,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( seqno=2, created_at=now, ) - await reconnector.write_with_ack([message1, message2]) + await reconnector.write_with_ack_future([message1, message2]) # sent to first stream stream_writer = get_stream_writer() @@ -317,7 +317,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( assert second_sent_msg == expected_messages second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) - await reconnector.close() + await reconnector.close(flush=True) async def test_stop_on_unexpected_exception( self, reconnector: WriterAsyncIOReconnector, get_stream_writer @@ -337,7 +337,7 @@ class TestException(Exception): async def wait_stop(): while True: - await reconnector.write_with_ack([message]) + await reconnector.write_with_ack_future([message]) await asyncio.sleep(0.1) await asyncio.wait_for(wait_stop(), 1) @@ -380,7 +380,7 @@ async def test_write_message( data="123", seqno=3, ) - await reconnector.write_with_ack([message]) + await reconnector.write_with_ack_future([message]) sent_messages = await asyncio.wait_for(stream_writer.from_client.get(), 1) assert sent_messages == [InternalMessage(message)] @@ -399,8 +399,8 @@ async def test_auto_seq_no( reconnector = WriterAsyncIOReconnector(default_driver, settings) - await reconnector.write_with_ack([PublicMessage(data="123")]) - await reconnector.write_with_ack([PublicMessage(data="456")]) + await reconnector.write_with_ack_future([PublicMessage(data="123")]) + await reconnector.write_with_ack_future([PublicMessage(data="456")]) stream_writer = get_stream_writer() @@ -415,22 +415,26 @@ async def test_auto_seq_no( ] == sent with pytest.raises(TopicWriterError): - await reconnector.write_with_ack( + await reconnector.write_with_ack_future( [PublicMessage(seqno=last_seq_no + 3, data="123")] ) await reconnector.close(flush=False) async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector): - await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")]) + await reconnector.write_with_ack_future([PublicMessage(seqno=10, data="123")]) with pytest.raises(TopicWriterError): - await reconnector.write_with_ack([PublicMessage(seqno=9, data="123")]) + await reconnector.write_with_ack_future( + [PublicMessage(seqno=9, data="123")] + ) with pytest.raises(TopicWriterError): - await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")]) + await reconnector.write_with_ack_future( + [PublicMessage(seqno=10, data="123")] + ) - await reconnector.write_with_ack([PublicMessage(seqno=11, data="123")]) + await reconnector.write_with_ack_future([PublicMessage(seqno=11, data="123")]) await reconnector.close(flush=False) @@ -443,7 +447,7 @@ async def test_auto_created_at( settings = copy.deepcopy(default_settings) settings.auto_created_at = True reconnector = WriterAsyncIOReconnector(default_driver, settings) - await reconnector.write_with_ack([PublicMessage(seqno=4, data="123")]) + await reconnector.write_with_ack_future([PublicMessage(seqno=4, data="123")]) stream_writer = get_stream_writer() sent = await stream_writer.from_client.get() @@ -468,7 +472,7 @@ def __init__(self): self.futures = [] self.messages_writted = asyncio.Event() - async def write_with_ack(self, messages: typing.List[InternalMessage]): + async def write_with_ack_future(self, messages: typing.List[InternalMessage]): async with self.lock: futures = [asyncio.Future() for _ in messages] self.messages.extend(messages)