Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ def __del__(self):

self._loop.call_soon(self.close)

async def close(self):
async def close(self, *, flush: bool = True):
if self._closed:
return

self._closed = True

await self._reconnector.close()
await self._reconnector.close(flush)

async def write_with_ack(
self,
Expand Down Expand Up @@ -109,13 +109,13 @@ async def write_with_ack_future(
For wait with timeout use asyncio.wait_for.
"""
if isinstance(messages, PublicMessage):
futures = await self._reconnector.write_with_ack([messages])
futures = await self._reconnector.write_with_ack_future([messages])
return futures[0]
if isinstance(messages, list):
for m in messages:
if not isinstance(m, PublicMessage):
raise NotImplementedError()
return await self._reconnector.write_with_ack(messages)
return await self._reconnector.write_with_ack_future(messages)
raise NotImplementedError()

async def write(
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
asyncio.create_task(self._connection_loop(), name="connection_loop")
]

async def close(self, flush: bool = True):
async def close(self, flush: bool):
if self._closed:
return

Expand Down Expand Up @@ -223,7 +223,7 @@ async def wait_init(self) -> PublicWriterInitInfo:
async def wait_stop(self) -> Exception:
return await self._stop_reason

async def write_with_ack(
async def write_with_ack_future(
self, messages: List[PublicMessage]
) -> List[asyncio.Future]:
# todo check internal buffer limit
Expand Down
30 changes: 17 additions & 13 deletions ydb/_topic_writer/topic_writer_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error(
seqno=2,
created_at=now,
)
await reconnector.write_with_ack([message1, message2])
await reconnector.write_with_ack_future([message1, message2])

# sent to first stream
stream_writer = get_stream_writer()
Expand All @@ -317,7 +317,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error(
assert second_sent_msg == expected_messages

second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2))
await reconnector.close()
await reconnector.close(flush=True)

async def test_stop_on_unexpected_exception(
self, reconnector: WriterAsyncIOReconnector, get_stream_writer
Expand All @@ -337,7 +337,7 @@ class TestException(Exception):

async def wait_stop():
while True:
await reconnector.write_with_ack([message])
await reconnector.write_with_ack_future([message])
await asyncio.sleep(0.1)

await asyncio.wait_for(wait_stop(), 1)
Expand Down Expand Up @@ -380,7 +380,7 @@ async def test_write_message(
data="123",
seqno=3,
)
await reconnector.write_with_ack([message])
await reconnector.write_with_ack_future([message])

sent_messages = await asyncio.wait_for(stream_writer.from_client.get(), 1)
assert sent_messages == [InternalMessage(message)]
Expand All @@ -399,8 +399,8 @@ async def test_auto_seq_no(

reconnector = WriterAsyncIOReconnector(default_driver, settings)

await reconnector.write_with_ack([PublicMessage(data="123")])
await reconnector.write_with_ack([PublicMessage(data="456")])
await reconnector.write_with_ack_future([PublicMessage(data="123")])
await reconnector.write_with_ack_future([PublicMessage(data="456")])

stream_writer = get_stream_writer()

Expand All @@ -415,22 +415,26 @@ async def test_auto_seq_no(
] == sent

with pytest.raises(TopicWriterError):
await reconnector.write_with_ack(
await reconnector.write_with_ack_future(
[PublicMessage(seqno=last_seq_no + 3, data="123")]
)

await reconnector.close(flush=False)

async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector):
await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")])
await reconnector.write_with_ack_future([PublicMessage(seqno=10, data="123")])

with pytest.raises(TopicWriterError):
await reconnector.write_with_ack([PublicMessage(seqno=9, data="123")])
await reconnector.write_with_ack_future(
[PublicMessage(seqno=9, data="123")]
)

with pytest.raises(TopicWriterError):
await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")])
await reconnector.write_with_ack_future(
[PublicMessage(seqno=10, data="123")]
)

await reconnector.write_with_ack([PublicMessage(seqno=11, data="123")])
await reconnector.write_with_ack_future([PublicMessage(seqno=11, data="123")])

await reconnector.close(flush=False)

Expand All @@ -443,7 +447,7 @@ async def test_auto_created_at(
settings = copy.deepcopy(default_settings)
settings.auto_created_at = True
reconnector = WriterAsyncIOReconnector(default_driver, settings)
await reconnector.write_with_ack([PublicMessage(seqno=4, data="123")])
await reconnector.write_with_ack_future([PublicMessage(seqno=4, data="123")])

stream_writer = get_stream_writer()
sent = await stream_writer.from_client.get()
Expand All @@ -468,7 +472,7 @@ def __init__(self):
self.futures = []
self.messages_writted = asyncio.Event()

async def write_with_ack(self, messages: typing.List[InternalMessage]):
async def write_with_ack_future(self, messages: typing.List[InternalMessage]):
async with self.lock:
futures = [asyncio.Future() for _ in messages]
self.messages.extend(messages)
Expand Down