diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f1e6c455..5d4583fc 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -40,7 +40,6 @@ class WriterAsyncIO: _loop: asyncio.AbstractEventLoop _reconnector: "WriterAsyncIOReconnector" - _lock: asyncio.Lock _closed: bool @property @@ -48,7 +47,6 @@ def last_seqno(self) -> int: raise NotImplementedError() def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings): - self._lock = asyncio.Lock() self._loop = asyncio.get_running_loop() self._closed = False self._reconnector = WriterAsyncIOReconnector( @@ -68,10 +66,10 @@ def __del__(self): self._loop.call_soon(self.close) async def close(self): - async with self._lock: - if self._closed: - return - self._closed = True + if self._closed: + return + + self._closed = True await self._reconnector.close() @@ -159,70 +157,92 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: + _closed: bool _credentials: Union[ydb.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int _token_get_function: TokenGetterFuncType _init_message: StreamWriteMessage.InitRequest - _new_messages: asyncio.Queue _init_info: asyncio.Future _stream_connected: asyncio.Event _settings: WriterSettings - _lock: asyncio.Lock _last_known_seq_no: int _messages: Deque[InternalMessage] _messages_future: Deque[asyncio.Future] - _stop_reason: Optional[Exception] + _new_messages: asyncio.Queue + _stop_reason: asyncio.Future _background_tasks: List[asyncio.Task] def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + self._closed = False self._driver = driver self._credentials = driver._credentials self._init_message = settings.create_init_request() - self._new_messages = asyncio.Queue() self._init_info = asyncio.Future() self._stream_connected = asyncio.Event() self._settings = settings - self._lock = asyncio.Lock() self._last_known_seq_no = 0 self._messages = deque() self._messages_future = deque() - self._stop_reason = None + self._new_messages = asyncio.Queue() + self._stop_reason = asyncio.Future() self._background_tasks = [ asyncio.create_task(self._connection_loop(), name="connection_loop") ] async def close(self): - await self._check_stop() - await self._stop(TopicWriterStopped()) + if self._closed: + return + + self._closed = True + + self._stop(TopicWriterStopped()) + + background_tasks = self._background_tasks + + for task in background_tasks: + task.cancel() + + await asyncio.wait(self._background_tasks) async def wait_init(self) -> PublicWriterInitInfo: - return await self._init_info + done, _ = await asyncio.wait( + [self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED + ) + res = done.pop() # type: asyncio.Future + res_val = res.result() + + if isinstance(res_val, BaseException): + raise res_val + + return res_val + + async def wait_stop(self) -> Exception: + return await self._stop_reason async def write_with_ack( self, messages: List[PublicMessage] ) -> List[asyncio.Future]: # todo check internal buffer limit - await self._check_stop() + self._check_stop() if self._settings.auto_seqno: await self.wait_init() - async with self._lock: - internal_messages = self._prepare_internal_messages_locked(messages) - messages_future = [asyncio.Future() for _ in internal_messages] + internal_messages = self._prepare_internal_messages(messages) + messages_future = [asyncio.Future() for _ in internal_messages] - self._messages.extend(internal_messages) - self._messages_future.extend(messages_future) + self._messages.extend(internal_messages) + self._messages_future.extend(messages_future) for m in internal_messages: self._new_messages.put_nowait(m) return messages_future - def _prepare_internal_messages_locked(self, messages: List[PublicMessage]): + def _prepare_internal_messages(self, messages: List[PublicMessage]): if self._settings.auto_created_at: now = datetime.datetime.now() else: @@ -263,10 +283,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]): return res - async def _check_stop(self): - async with self._lock: - if self._stop_reason is not None: - raise self._stop_reason + def _check_stop(self): + if self._stop_reason.done(): + raise self._stop_reason.result() async def _connection_loop(self): retry_settings = RetrySettings() # todo @@ -275,23 +294,16 @@ async def _connection_loop(self): attempt = 0 # todo calc and reset pending = [] - async def on_stop(e): - for t in pending: - self._background_tasks.append(t) - pending.clear() - await self._stop(e) - # noinspection PyBroadException try: stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._get_token ) try: - async with self._lock: - self._last_known_seq_no = stream_writer.last_seqno - self._init_info.set_result( - PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) - ) + self._last_known_seq_no = stream_writer.last_seqno + self._init_info.set_result( + PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) + ) except asyncio.InvalidStateError: pass @@ -316,13 +328,13 @@ async def on_stop(e): err_info = check_retriable_error(err, retry_settings, attempt) if not err_info.is_retriable: - await on_stop(err) + self._stop(err) return await asyncio.sleep(err_info.sleep_timeout_seconds) - except Exception as e: - await on_stop(e) + except (asyncio.CancelledError, Exception) as err: + self._stop(err) return finally: if len(pending) > 0: @@ -333,11 +345,11 @@ async def on_stop(e): async def _read_loop(self, writer: "WriterAsyncIOStream"): while True: resp = await writer.receive() - async with self._lock: - for ack in resp.acks: - self._handle_receive_ack_need_lock(ack) - def _handle_receive_ack_need_lock(self, ack): + for ack in resp.acks: + self._handle_receive_ack(ack) + + def _handle_receive_ack(self, ack): current_message = self._messages.popleft() message_future = self._messages_future.popleft() if current_message.seq_no != ack.seq_no: @@ -351,8 +363,7 @@ def _handle_receive_ack_need_lock(self, ack): async def _send_loop(self, writer: "WriterAsyncIOStream"): try: - async with self._lock: - messages = list(self._messages) + messages = list(self._messages) last_seq_no = 0 for m in messages: @@ -364,24 +375,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): if m.seq_no > last_seq_no: writer.write([m]) except Exception as e: - await self._stop(e) + self._stop(e) finally: pass - async def _stop(self, reason: Exception): + def _stop(self, reason: Exception): if reason is None: raise Exception("writer stop reason can not be None") - async with self._lock: - if self._stop_reason is not None: - return - self._stop_reason = reason - background_tasks = self._background_tasks - - for task in background_tasks: - task.cancel() + if self._stop_reason.done(): + return - await asyncio.wait(self._background_tasks) + self._stop_reason.set_result(reason) def _get_token(self) -> str: raise NotImplementedError()