Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
* Fix exception for convert sync to async iterator
* Fixed start many sync writers/readers in parallel

## 3.3.0 ##
* Added support to set many topics and topic reader settings for read in one reader
* Added ydb.TopicWriterInitInfo, ydb.TopicWriteResult as public types
Expand Down
16 changes: 16 additions & 0 deletions tests/topics/test_topic_writer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations
from typing import List

import pytest

import ydb.aio
Expand Down Expand Up @@ -196,3 +199,16 @@ def test_write_encoded(self, driver_sync: ydb.Driver, topic_path: str, codec):
writer.write("a" * 1000)
writer.write("b" * 1000)
writer.write("c" * 1000)

def test_start_many_sync_writers_in_parallel(self, driver_sync: ydb.Driver, topic_path):
target_count = 100
writers = [] # type: List[ydb.TopicWriter]
for i in range(target_count):
writer = driver_sync.topic_client.writer(topic_path)
writers.append(writer)

for i, writer in enumerate(writers):
writer.write(str(i))

for writer in writers:
writer.close()
38 changes: 24 additions & 14 deletions ydb/_grpc/grpcwrapper/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import asyncio
import concurrent.futures
import contextvars
import datetime
import functools
Expand Down Expand Up @@ -111,19 +112,20 @@ def __next__(self):
return item


class SyncIteratorToAsyncIterator:
def __init__(self, sync_iterator: Iterator):
class SyncToAsyncIterator:
def __init__(self, sync_iterator: Iterator, executor: concurrent.futures.Executor):
self._sync_iterator = sync_iterator
self._executor = executor

def __aiter__(self):
return self

async def __anext__(self):
try:
res = await to_thread(self._sync_iterator.__next__)
res = await to_thread(self._sync_iterator.__next__, executor=self._executor)
return res
except StopAsyncIteration:
raise StopIteration()
except StopIteration:
raise StopAsyncIteration()


class IGrpcWrapperAsyncIO(abc.ABC):
Expand All @@ -149,12 +151,17 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
convert_server_grpc_to_wrapper: Callable[[Any], Any]
_connection_state: str
_stream_call: Optional[Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"]]
_wait_executor: Optional[concurrent.futures.ThreadPoolExecutor]

def __init__(self, convert_server_grpc_to_wrapper):
self.from_client_grpc = asyncio.Queue()
self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
self._connection_state = "new"
self._stream_call = None
self._wait_executor = None

def __del__(self):
self._clean_executor(wait=False)

async def start(self, driver: SupportedDriverType, stub, method):
if asyncio.iscoroutinefunction(driver.__call__):
Expand All @@ -168,6 +175,12 @@ def close(self):
if self._stream_call:
self._stream_call.cancel()

self._clean_executor(wait=True)

def _clean_executor(self, wait: bool):
if self._wait_executor:
self._wait_executor.shutdown(wait)

async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
stream_call = await driver(
Expand All @@ -180,14 +193,11 @@ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):

async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc)
stream_call = await to_thread(
driver,
requests_iterator,
stub,
method,
)
self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

stream_call = await to_thread(driver, requests_iterator, stub, method, executor=self._wait_executor)
self._stream_call = stream_call
self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__())
self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor)

async def receive(self) -> Any:
# todo handle grpc exceptions and convert it to internal exceptions
Expand Down Expand Up @@ -255,7 +265,7 @@ def callback_from_asyncio(callback: Union[Callable, Coroutine]) -> [asyncio.Futu
return loop.run_in_executor(None, callback)


async def to_thread(func, /, *args, **kwargs):
async def to_thread(func, *args, executor: Optional[concurrent.futures.Executor], **kwargs):
"""Asynchronously run function *func* in a separate thread.

Any *args and **kwargs supplied for this function are directly passed
Expand All @@ -271,7 +281,7 @@ async def to_thread(func, /, *args, **kwargs):
loop = asyncio.get_running_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)
return await loop.run_in_executor(executor, func_call)


def proto_duration_from_timedelta(t: Optional[datetime.timedelta]) -> Optional[ProtoDuration]:
Expand Down