From 1e74b11be42ecd9e7fdfb1e0d3cdf3b8d48fc521 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 14 Feb 2023 14:39:37 +0300 Subject: [PATCH 1/3] simplify topic writer --- ydb/_topic_writer/topic_writer_asyncio.py | 113 +++++++++++----------- 1 file changed, 56 insertions(+), 57 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f1e6c455..217249a1 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -40,7 +40,6 @@ class WriterAsyncIO: _loop: asyncio.AbstractEventLoop _reconnector: "WriterAsyncIOReconnector" - _lock: asyncio.Lock _closed: bool @property @@ -48,7 +47,6 @@ def last_seqno(self) -> int: raise NotImplementedError() def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings): - self._lock = asyncio.Lock() self._loop = asyncio.get_running_loop() self._closed = False self._reconnector = WriterAsyncIOReconnector( @@ -68,10 +66,10 @@ def __del__(self): self._loop.call_soon(self.close) async def close(self): - async with self._lock: - if self._closed: - return - self._closed = True + if self._closed: + return + + self._closed = True await self._reconnector.close() @@ -164,65 +162,81 @@ class WriterAsyncIOReconnector: _update_token_interval: int _token_get_function: TokenGetterFuncType _init_message: StreamWriteMessage.InitRequest - _new_messages: asyncio.Queue _init_info: asyncio.Future _stream_connected: asyncio.Event _settings: WriterSettings - _lock: asyncio.Lock _last_known_seq_no: int _messages: Deque[InternalMessage] _messages_future: Deque[asyncio.Future] - _stop_reason: Optional[Exception] + _new_messages: asyncio.Queue + _stop_reason: asyncio.Future _background_tasks: List[asyncio.Task] def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._driver = driver self._credentials = driver._credentials self._init_message = settings.create_init_request() - self._new_messages = asyncio.Queue() self._init_info = asyncio.Future() self._stream_connected = asyncio.Event() self._settings = settings - self._lock = asyncio.Lock() self._last_known_seq_no = 0 self._messages = deque() self._messages_future = deque() - self._stop_reason = None + self._new_messages = asyncio.Queue() + self._stop_reason = asyncio.Future() self._background_tasks = [ asyncio.create_task(self._connection_loop(), name="connection_loop") ] async def close(self): - await self._check_stop() - await self._stop(TopicWriterStopped()) + self._check_stop() + self._stop(TopicWriterStopped()) + + background_tasks = self._background_tasks + + for task in background_tasks: + task.cancel() + + await asyncio.wait(self._background_tasks) async def wait_init(self) -> PublicWriterInitInfo: - return await self._init_info + done, _ = await asyncio.wait( + [self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED + ) + res = done.pop() # type: asyncio.Future + res_val = res.result() + + if isinstance(res_val, Exception): + raise res_val + + return res_val + + async def wait_stop(self) -> Exception: + return await self._stop_reason async def write_with_ack( self, messages: List[PublicMessage] ) -> List[asyncio.Future]: # todo check internal buffer limit - await self._check_stop() + self._check_stop() if self._settings.auto_seqno: await self.wait_init() - async with self._lock: - internal_messages = self._prepare_internal_messages_locked(messages) - messages_future = [asyncio.Future() for _ in internal_messages] + internal_messages = self._prepare_internal_messages(messages) + messages_future = [asyncio.Future() for _ in internal_messages] - self._messages.extend(internal_messages) - self._messages_future.extend(messages_future) + self._messages.extend(internal_messages) + self._messages_future.extend(messages_future) for m in internal_messages: self._new_messages.put_nowait(m) return messages_future - def _prepare_internal_messages_locked(self, messages: List[PublicMessage]): + def _prepare_internal_messages(self, messages: List[PublicMessage]): if self._settings.auto_created_at: now = datetime.datetime.now() else: @@ -263,10 +277,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]): return res - async def _check_stop(self): - async with self._lock: - if self._stop_reason is not None: - raise self._stop_reason + def _check_stop(self): + if self._stop_reason.done(): + raise self._stop_reason.result() async def _connection_loop(self): retry_settings = RetrySettings() # todo @@ -275,23 +288,16 @@ async def _connection_loop(self): attempt = 0 # todo calc and reset pending = [] - async def on_stop(e): - for t in pending: - self._background_tasks.append(t) - pending.clear() - await self._stop(e) - # noinspection PyBroadException try: stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._get_token ) try: - async with self._lock: - self._last_known_seq_no = stream_writer.last_seqno - self._init_info.set_result( - PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) - ) + self._last_known_seq_no = stream_writer.last_seqno + self._init_info.set_result( + PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) + ) except asyncio.InvalidStateError: pass @@ -316,13 +322,13 @@ async def on_stop(e): err_info = check_retriable_error(err, retry_settings, attempt) if not err_info.is_retriable: - await on_stop(err) + self._stop(err) return await asyncio.sleep(err_info.sleep_timeout_seconds) - except Exception as e: - await on_stop(e) + except (asyncio.CancelledError, Exception) as err: + self._stop(err) return finally: if len(pending) > 0: @@ -333,11 +339,11 @@ async def on_stop(e): async def _read_loop(self, writer: "WriterAsyncIOStream"): while True: resp = await writer.receive() - async with self._lock: - for ack in resp.acks: - self._handle_receive_ack_need_lock(ack) - def _handle_receive_ack_need_lock(self, ack): + for ack in resp.acks: + self._handle_receive_ack(ack) + + def _handle_receive_ack(self, ack): current_message = self._messages.popleft() message_future = self._messages_future.popleft() if current_message.seq_no != ack.seq_no: @@ -351,8 +357,7 @@ def _handle_receive_ack_need_lock(self, ack): async def _send_loop(self, writer: "WriterAsyncIOStream"): try: - async with self._lock: - messages = list(self._messages) + messages = list(self._messages) last_seq_no = 0 for m in messages: @@ -364,24 +369,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): if m.seq_no > last_seq_no: writer.write([m]) except Exception as e: - await self._stop(e) + self._stop(e) finally: pass - async def _stop(self, reason: Exception): + def _stop(self, reason: Exception): if reason is None: raise Exception("writer stop reason can not be None") - async with self._lock: - if self._stop_reason is not None: - return - self._stop_reason = reason - background_tasks = self._background_tasks - - for task in background_tasks: - task.cancel() + if self._stop_reason.done(): + return - await asyncio.wait(self._background_tasks) + self._stop_reason.set_result(reason) def _get_token(self) -> str: raise NotImplementedError() From 1b95a9760b72eabec0dc3b89d49718959af05b34 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Feb 2023 14:10:30 +0300 Subject: [PATCH 2/3] fix wait_init for raise exception better. --- ydb/_topic_writer/topic_writer_asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 217249a1..24cde408 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -208,7 +208,7 @@ async def wait_init(self) -> PublicWriterInitInfo: res = done.pop() # type: asyncio.Future res_val = res.result() - if isinstance(res_val, Exception): + if isinstance(res_val, BaseException): raise res_val return res_val From a4146ffab0f983080cc135437b3046ed7739f59c Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Wed, 15 Feb 2023 17:25:43 +0300 Subject: [PATCH 3/3] fix closed for internal writer --- ydb/_topic_writer/topic_writer_asyncio.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 24cde408..5d4583fc 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -157,6 +157,7 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: + _closed: bool _credentials: Union[ydb.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int @@ -174,6 +175,7 @@ class WriterAsyncIOReconnector: _background_tasks: List[asyncio.Task] def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + self._closed = False self._driver = driver self._credentials = driver._credentials self._init_message = settings.create_init_request() @@ -191,7 +193,11 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): ] async def close(self): - self._check_stop() + if self._closed: + return + + self._closed = True + self._stop(TopicWriterStopped()) background_tasks = self._background_tasks