diff --git a/CHANGELOG.md b/CHANGELOG.md index 61f06737..fd306185 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* Close grpc streams while closing readers/writers * Add control plane operations for topic api: create, drop ## 3.0.1b4 ## diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 6d87fc0b..ac338fbd 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -9,3 +9,4 @@ async def test_read_message( reader = driver.topic_client.topic_reader(topic_consumer, topic_path) assert await reader.receive_batch() is not None + await reader.close() diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 3a2e6c2b..1e56ad05 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -69,6 +69,9 @@ class UnknownGrpcMessageError(issues.Error): pass +_stop_grpc_connection_marker = object() + + class QueueToIteratorAsyncIO: __slots__ = ("_queue",) @@ -79,10 +82,10 @@ def __aiter__(self): return self async def __anext__(self): - try: - return await self._queue.get() - except asyncio.QueueEmpty: + item = await self._queue.get() + if item is _stop_grpc_connection_marker: raise StopAsyncIteration() + return item class AsyncQueueToSyncIteratorAsyncIO: @@ -100,13 +103,10 @@ def __iter__(self): return self def __next__(self): - try: - res = asyncio.run_coroutine_threadsafe( - self._queue.get(), self._loop - ).result() - return res - except asyncio.QueueEmpty: + item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result() + if item is _stop_grpc_connection_marker: raise StopIteration() + return item class SyncIteratorToAsyncIterator: @@ -133,6 +133,10 @@ async def receive(self) -> Any: def write(self, wrap_message: IToProto): ... + @abc.abstractmethod + def close(self): + ... + SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] @@ -142,11 +146,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): from_server_grpc: AsyncIterator convert_server_grpc_to_wrapper: Callable[[Any], Any] _connection_state: str + _stream_call: Optional[ + Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"] + ] def __init__(self, convert_server_grpc_to_wrapper): self.from_client_grpc = asyncio.Queue() self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper self._connection_state = "new" + self._stream_call = None async def start(self, driver: SupportedDriverType, stub, method): if asyncio.iscoroutinefunction(driver.__call__): @@ -155,6 +163,11 @@ async def start(self, driver: SupportedDriverType, stub, method): await self._start_sync_driver(driver, stub, method) self._connection_state = "started" + def close(self): + self.from_client_grpc.put_nowait(_stop_grpc_connection_marker) + if self._stream_call: + self._stream_call.cancel() + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) stream_call = await driver( @@ -162,6 +175,7 @@ async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): stub, method, ) + self._stream_call = stream_call self.from_server_grpc = stream_call.__aiter__() async def _start_sync_driver(self, driver: ydb.Driver, stub, method): @@ -172,6 +186,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): stub, method, ) + self._stream_call = stream_call self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__()) async def receive(self) -> Any: diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py index bea6fea5..9023f759 100644 --- a/ydb/_topic_common/test_helpers.py +++ b/ydb/_topic_common/test_helpers.py @@ -8,20 +8,36 @@ class StreamMock(IGrpcWrapperAsyncIO): from_server: asyncio.Queue from_client: asyncio.Queue + _closed: bool def __init__(self): self.from_server = asyncio.Queue() self.from_client = asyncio.Queue() + self._closed = False async def receive(self) -> typing.Any: + if self._closed: + raise Exception("read from closed StreamMock") + item = await self.from_server.get() + if item is None: + raise StopAsyncIteration() if isinstance(item, Exception): raise item return item def write(self, wrap_message: IToProto): + if self._closed: + raise Exception("write to closed StreamMock") self.from_client.put_nowait(wrap_message) + def close(self): + if self._closed: + return + + self._closed = True + self.from_server.put_nowait(None) + async def wait_condition(f: typing.Callable[[], bool], timeout=1): start = time.monotonic() diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 95bd1008..a3f792de 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -496,6 +496,7 @@ async def close(self): self._closed = True self._set_first_error(TopicReaderStreamClosedError()) self._state_changed.set() + self._stream.close() for task in self._background_tasks: task.cancel() diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 669507d8..c0ef2491 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -296,6 +296,7 @@ async def _connection_loop(self): pending = [] # noinspection PyBroadException + stream_writer = None try: stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._get_token @@ -322,6 +323,7 @@ async def _connection_loop(self): done, pending = await asyncio.wait( [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED ) + stream_writer.close() done.pop().result() except issues.Error as err: # todo log error @@ -338,6 +340,8 @@ async def _connection_loop(self): self._stop(err) return finally: + if stream_writer: + stream_writer.close() if len(pending) > 0: for task in pending: task.cancel() @@ -417,6 +421,9 @@ def __init__( ): self._token_getter = token_getter + def close(self): + self._stream.close() + @staticmethod async def create( driver: SupportedDriverType, diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 6658adbd..1c96097f 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -158,20 +158,37 @@ class StreamWriterMock: from_client: asyncio.Queue from_server: asyncio.Queue + _closed: bool + def __init__(self): self.last_seqno = 0 self.from_server = asyncio.Queue() self.from_client = asyncio.Queue() + self._closed = False def write(self, messages: typing.List[InternalMessage]): + if self._closed: + raise Exception("write to closed StreamWriterMock") + self.from_client.put_nowait(messages) async def receive(self) -> StreamWriteMessage.WriteResponse: + if self._closed: + raise Exception("read from closed StreamWriterMock") + item = await self.from_server.get() if isinstance(item, Exception): raise item return item + def close(self): + if self._closed: + return + + self.from_server.put_nowait( + Exception("waited message while StreamWriterMock closed") + ) + @pytest.fixture(autouse=True) async def stream_writer_double_queue(self, monkeypatch): class DoubleQueueWriters: