Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Close grpc streams while closing readers/writers
* Add control plane operations for topic api: create, drop

## 3.0.1b4 ##
Expand Down
1 change: 1 addition & 0 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ async def test_read_message(
reader = driver.topic_client.topic_reader(topic_consumer, topic_path)

assert await reader.receive_batch() is not None
await reader.close()
33 changes: 24 additions & 9 deletions ydb/_grpc/grpcwrapper/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class UnknownGrpcMessageError(issues.Error):
pass


_stop_grpc_connection_marker = object()


class QueueToIteratorAsyncIO:
__slots__ = ("_queue",)

Expand All @@ -79,10 +82,10 @@ def __aiter__(self):
return self

async def __anext__(self):
try:
return await self._queue.get()
except asyncio.QueueEmpty:
item = await self._queue.get()
if item is _stop_grpc_connection_marker:
raise StopAsyncIteration()
return item


class AsyncQueueToSyncIteratorAsyncIO:
Expand All @@ -100,13 +103,10 @@ def __iter__(self):
return self

def __next__(self):
try:
res = asyncio.run_coroutine_threadsafe(
self._queue.get(), self._loop
).result()
return res
except asyncio.QueueEmpty:
item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()
if item is _stop_grpc_connection_marker:
raise StopIteration()
return item


class SyncIteratorToAsyncIterator:
Expand All @@ -133,6 +133,10 @@ async def receive(self) -> Any:
def write(self, wrap_message: IToProto):
...

@abc.abstractmethod
def close(self):
...


SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver]

Expand All @@ -142,11 +146,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
from_server_grpc: AsyncIterator
convert_server_grpc_to_wrapper: Callable[[Any], Any]
_connection_state: str
_stream_call: Optional[
Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"]
]

def __init__(self, convert_server_grpc_to_wrapper):
self.from_client_grpc = asyncio.Queue()
self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
self._connection_state = "new"
self._stream_call = None

async def start(self, driver: SupportedDriverType, stub, method):
if asyncio.iscoroutinefunction(driver.__call__):
Expand All @@ -155,13 +163,19 @@ async def start(self, driver: SupportedDriverType, stub, method):
await self._start_sync_driver(driver, stub, method)
self._connection_state = "started"

def close(self):
self.from_client_grpc.put_nowait(_stop_grpc_connection_marker)
if self._stream_call:
self._stream_call.cancel()

async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
stream_call = await driver(
requests_iterator,
stub,
method,
)
self._stream_call = stream_call
self.from_server_grpc = stream_call.__aiter__()

async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
Expand All @@ -172,6 +186,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
stub,
method,
)
self._stream_call = stream_call
self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__())

async def receive(self) -> Any:
Expand Down
16 changes: 16 additions & 0 deletions ydb/_topic_common/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,36 @@
class StreamMock(IGrpcWrapperAsyncIO):
from_server: asyncio.Queue
from_client: asyncio.Queue
_closed: bool

def __init__(self):
self.from_server = asyncio.Queue()
self.from_client = asyncio.Queue()
self._closed = False

async def receive(self) -> typing.Any:
if self._closed:
raise Exception("read from closed StreamMock")

item = await self.from_server.get()
if item is None:
raise StopAsyncIteration()
if isinstance(item, Exception):
raise item
return item

def write(self, wrap_message: IToProto):
if self._closed:
raise Exception("write to closed StreamMock")
self.from_client.put_nowait(wrap_message)

def close(self):
if self._closed:
return

self._closed = True
self.from_server.put_nowait(None)


async def wait_condition(f: typing.Callable[[], bool], timeout=1):
start = time.monotonic()
Expand Down
1 change: 1 addition & 0 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ async def close(self):
self._closed = True
self._set_first_error(TopicReaderStreamClosedError())
self._state_changed.set()
self._stream.close()

for task in self._background_tasks:
task.cancel()
Expand Down
7 changes: 7 additions & 0 deletions ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ async def _connection_loop(self):
pending = []

# noinspection PyBroadException
stream_writer = None
try:
stream_writer = await WriterAsyncIOStream.create(
self._driver, self._init_message, self._get_token
Expand All @@ -322,6 +323,7 @@ async def _connection_loop(self):
done, pending = await asyncio.wait(
[send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED
)
stream_writer.close()
done.pop().result()
except issues.Error as err:
# todo log error
Expand All @@ -338,6 +340,8 @@ async def _connection_loop(self):
self._stop(err)
return
finally:
if stream_writer:
stream_writer.close()
if len(pending) > 0:
for task in pending:
task.cancel()
Expand Down Expand Up @@ -417,6 +421,9 @@ def __init__(
):
self._token_getter = token_getter

def close(self):
self._stream.close()

@staticmethod
async def create(
driver: SupportedDriverType,
Expand Down
17 changes: 17 additions & 0 deletions ydb/_topic_writer/topic_writer_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,37 @@ class StreamWriterMock:
from_client: asyncio.Queue
from_server: asyncio.Queue

_closed: bool

def __init__(self):
self.last_seqno = 0
self.from_server = asyncio.Queue()
self.from_client = asyncio.Queue()
self._closed = False

def write(self, messages: typing.List[InternalMessage]):
if self._closed:
raise Exception("write to closed StreamWriterMock")

self.from_client.put_nowait(messages)

async def receive(self) -> StreamWriteMessage.WriteResponse:
if self._closed:
raise Exception("read from closed StreamWriterMock")

item = await self.from_server.get()
if isinstance(item, Exception):
raise item
return item

def close(self):
if self._closed:
return

self.from_server.put_nowait(
Exception("waited message while StreamWriterMock closed")
)

@pytest.fixture(autouse=True)
async def stream_writer_double_queue(self, monkeypatch):
class DoubleQueueWriters:
Expand Down