diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 327cb81e..3817e34d 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import asyncio from typing import List import pytest @@ -96,12 +98,39 @@ async def test_write_multi_message_with_ack( None, ], ) - async def test_write_encoded(self, driver: ydb.Driver, topic_path: str, codec): + async def test_write_encoded(self, driver: ydb.aio.Driver, topic_path: str, codec): async with driver.topic_client.writer(topic_path, codec=codec) as writer: await writer.write("a" * 1000) await writer.write("b" * 1000) await writer.write("c" * 1000) + async def test_create_writer_after_stop(self, driver: ydb.aio.Driver, topic_path: str): + await driver.stop() + with pytest.raises(ydb.Error): + async with driver.topic_client.writer(topic_path) as writer: + await writer.write_with_ack("123") + + async def test_send_message_after_stop(self, driver: ydb.aio.Driver, topic_path: str): + writer = driver.topic_client.writer(topic_path) + await driver.stop() + with pytest.raises(ydb.Error): + await writer.write_with_ack("123") + + async def test_preserve_exception_on_cm_close(self, driver: ydb.aio.Driver, topic_path: str): + class TestException(Exception): + pass + + with pytest.raises(TestException): + async with driver.topic_client.writer(topic_path) as writer: + await writer.wait_init() + await driver.stop() # will raise exception on topic writer __exit__ + + # ensure writer has exception internally + with pytest.raises((ydb.Error, asyncio.CancelledError)): + await writer.write_with_ack("123") + + raise TestException() + class TestTopicWriterSync: def test_send_message(self, driver_sync: ydb.Driver, topic_path): @@ -212,3 +241,30 @@ def test_start_many_sync_writers_in_parallel(self, driver_sync: ydb.Driver, topi for writer in writers: writer.close() + + def test_create_writer_after_stop(self, driver_sync: ydb.Driver, topic_path: str): + driver_sync.stop() + with pytest.raises(ydb.Error): + with driver_sync.topic_client.writer(topic_path) as writer: + writer.write_with_ack("123") + + def test_send_message_after_stop(self, driver_sync: ydb.Driver, topic_path: str): + writer = driver_sync.topic_client.writer(topic_path) + driver_sync.stop() + with pytest.raises(ydb.Error): + writer.write_with_ack("123") + + def test_preserve_exception_on_cm_close(self, driver_sync: ydb.Driver, topic_path: str): + class TestException(Exception): + pass + + with pytest.raises(TestException): + with driver_sync.topic_client.writer(topic_path) as writer: + writer.wait_init() + driver_sync.stop() # will raise exception on topic writer __exit__ + + # ensure writer has exception internally + with pytest.raises(ydb.Error): + writer.write_with_ack("123") + + raise TestException() diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 30ab9fb3..d83187fc 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -65,7 +65,11 @@ async def __aenter__(self) -> "WriterAsyncIO": return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() + try: + await self.close() + except BaseException: + if exc_val is None: + raise def __del__(self): if self._closed or self._loop.is_closed(): @@ -330,7 +334,7 @@ def _prepare_internal_messages(self, messages: List[PublicMessage]) -> List[Inte def _check_stop(self): if self._stop_reason.done(): - raise self._stop_reason.result() + raise self._stop_reason.exception() async def _connection_loop(self): retry_settings = RetrySettings() # todo @@ -543,7 +547,7 @@ def _stop(self, reason: BaseException): if self._stop_reason.done(): return - self._stop_reason.set_result(reason) + self._stop_reason.set_exception(reason) for f in self._messages_future: f.set_exception(reason) diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 43c4fec9..a5193caf 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -56,7 +56,11 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + try: + self.close() + except BaseException: + if exc_val is None: + raise def __del__(self): self.close(flush=False) diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 08bfaacb..6e95dd6f 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -241,6 +241,8 @@ async def __call__( preferred_endpoint=None, fast_fail=False, ): + if self._stopped: + raise issues.Error("Driver was stopped") wait_timeout = settings.timeout if settings else 10 try: connection = await self._store.get(preferred_endpoint, fast_fail=fast_fail, wait_timeout=wait_timeout) diff --git a/ydb/pool.py b/ydb/pool.py index 736344e3..e0bf2f15 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -429,6 +429,9 @@ def __call__( :return: A result of computation """ + if self._stopped: + raise issues.Error("Driver was stopped") + tracing.trace(self.tracer, {"request": request, "stub": stub, "rpc_name": rpc_name}) try: connection = self._store.get(preferred_endpoint)