Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 62 additions & 57 deletions ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,13 @@
class WriterAsyncIO:
_loop: asyncio.AbstractEventLoop
_reconnector: "WriterAsyncIOReconnector"
_lock: asyncio.Lock
_closed: bool

@property
def last_seqno(self) -> int:
raise NotImplementedError()

def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
self._lock = asyncio.Lock()
self._loop = asyncio.get_running_loop()
self._closed = False
self._reconnector = WriterAsyncIOReconnector(
Expand All @@ -68,10 +66,10 @@ def __del__(self):
self._loop.call_soon(self.close)

async def close(self):
async with self._lock:
if self._closed:
return
self._closed = True
if self._closed:
return

self._closed = True

await self._reconnector.close()

Expand Down Expand Up @@ -159,70 +157,92 @@ async def wait_init(self) -> PublicWriterInitInfo:


class WriterAsyncIOReconnector:
_closed: bool
_credentials: Union[ydb.Credentials, None]
_driver: ydb.aio.Driver
_update_token_interval: int
_token_get_function: TokenGetterFuncType
_init_message: StreamWriteMessage.InitRequest
_new_messages: asyncio.Queue
_init_info: asyncio.Future
_stream_connected: asyncio.Event
_settings: WriterSettings

_lock: asyncio.Lock
_last_known_seq_no: int
_messages: Deque[InternalMessage]
_messages_future: Deque[asyncio.Future]
_stop_reason: Optional[Exception]
_new_messages: asyncio.Queue
_stop_reason: asyncio.Future
_background_tasks: List[asyncio.Task]

def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
self._closed = False
self._driver = driver
self._credentials = driver._credentials
self._init_message = settings.create_init_request()
self._new_messages = asyncio.Queue()
self._init_info = asyncio.Future()
self._stream_connected = asyncio.Event()
self._settings = settings

self._lock = asyncio.Lock()
self._last_known_seq_no = 0
self._messages = deque()
self._messages_future = deque()
self._stop_reason = None
self._new_messages = asyncio.Queue()
self._stop_reason = asyncio.Future()
self._background_tasks = [
asyncio.create_task(self._connection_loop(), name="connection_loop")
]

async def close(self):
await self._check_stop()
await self._stop(TopicWriterStopped())
if self._closed:
return

self._closed = True

self._stop(TopicWriterStopped())

background_tasks = self._background_tasks

for task in background_tasks:
task.cancel()

await asyncio.wait(self._background_tasks)

async def wait_init(self) -> PublicWriterInitInfo:
return await self._init_info
done, _ = await asyncio.wait(
[self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED
)
res = done.pop() # type: asyncio.Future
res_val = res.result()

if isinstance(res_val, BaseException):
raise res_val

return res_val

async def wait_stop(self) -> Exception:
return await self._stop_reason

async def write_with_ack(
self, messages: List[PublicMessage]
) -> List[asyncio.Future]:
# todo check internal buffer limit
await self._check_stop()
self._check_stop()

if self._settings.auto_seqno:
await self.wait_init()

async with self._lock:
internal_messages = self._prepare_internal_messages_locked(messages)
messages_future = [asyncio.Future() for _ in internal_messages]
internal_messages = self._prepare_internal_messages(messages)
messages_future = [asyncio.Future() for _ in internal_messages]

self._messages.extend(internal_messages)
self._messages_future.extend(messages_future)
self._messages.extend(internal_messages)
self._messages_future.extend(messages_future)

for m in internal_messages:
self._new_messages.put_nowait(m)

return messages_future

def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
def _prepare_internal_messages(self, messages: List[PublicMessage]):
if self._settings.auto_created_at:
now = datetime.datetime.now()
else:
Expand Down Expand Up @@ -263,10 +283,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):

return res

async def _check_stop(self):
async with self._lock:
if self._stop_reason is not None:
raise self._stop_reason
def _check_stop(self):
if self._stop_reason.done():
raise self._stop_reason.result()

async def _connection_loop(self):
retry_settings = RetrySettings() # todo
Expand All @@ -275,23 +294,16 @@ async def _connection_loop(self):
attempt = 0 # todo calc and reset
pending = []

async def on_stop(e):
for t in pending:
self._background_tasks.append(t)
pending.clear()
await self._stop(e)

# noinspection PyBroadException
try:
stream_writer = await WriterAsyncIOStream.create(
self._driver, self._init_message, self._get_token
)
try:
async with self._lock:
self._last_known_seq_no = stream_writer.last_seqno
self._init_info.set_result(
PublicWriterInitInfo(last_seqno=stream_writer.last_seqno)
)
self._last_known_seq_no = stream_writer.last_seqno
self._init_info.set_result(
PublicWriterInitInfo(last_seqno=stream_writer.last_seqno)
)
except asyncio.InvalidStateError:
pass

Expand All @@ -316,13 +328,13 @@ async def on_stop(e):

err_info = check_retriable_error(err, retry_settings, attempt)
if not err_info.is_retriable:
await on_stop(err)
self._stop(err)
return

await asyncio.sleep(err_info.sleep_timeout_seconds)

except Exception as e:
await on_stop(e)
except (asyncio.CancelledError, Exception) as err:
self._stop(err)
return
finally:
if len(pending) > 0:
Expand All @@ -333,11 +345,11 @@ async def on_stop(e):
async def _read_loop(self, writer: "WriterAsyncIOStream"):
while True:
resp = await writer.receive()
async with self._lock:
for ack in resp.acks:
self._handle_receive_ack_need_lock(ack)

def _handle_receive_ack_need_lock(self, ack):
for ack in resp.acks:
self._handle_receive_ack(ack)

def _handle_receive_ack(self, ack):
current_message = self._messages.popleft()
message_future = self._messages_future.popleft()
if current_message.seq_no != ack.seq_no:
Expand All @@ -351,8 +363,7 @@ def _handle_receive_ack_need_lock(self, ack):

async def _send_loop(self, writer: "WriterAsyncIOStream"):
try:
async with self._lock:
messages = list(self._messages)
messages = list(self._messages)

last_seq_no = 0
for m in messages:
Expand All @@ -364,24 +375,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
if m.seq_no > last_seq_no:
writer.write([m])
except Exception as e:
await self._stop(e)
self._stop(e)
finally:
pass

async def _stop(self, reason: Exception):
def _stop(self, reason: Exception):
if reason is None:
raise Exception("writer stop reason can not be None")

async with self._lock:
if self._stop_reason is not None:
return
self._stop_reason = reason
background_tasks = self._background_tasks

for task in background_tasks:
task.cancel()
if self._stop_reason.done():
return

await asyncio.wait(self._background_tasks)
self._stop_reason.set_result(reason)

def _get_token(self) -> str:
raise NotImplementedError()
Expand Down