diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py
index c59deca3..2227a739 100644
--- a/tests/topics/test_topic_reader.py
+++ b/tests/topics/test_topic_reader.py
@@ -96,6 +96,32 @@ async def test_commit_offset_with_session_id_works(self, driver, topic_with_mess
             msg2 = await reader.receive_message()
             assert msg2.seqno == 2
 
+    async def test_commit_offset_retry_on_ydb_errors(self, driver, topic_with_messages, topic_consumer, monkeypatch):
+        async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
+            message = await reader.receive_message()
+
+        call_count = 0
+        original_driver_call = driver.topic_client._driver
+
+        async def mock_driver_call(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+
+            if call_count == 1:
+                raise ydb.Unavailable("Service temporarily unavailable")
+            elif call_count == 2:
+                raise ydb.Cancelled("Operation was cancelled")
+            else:
+                return await original_driver_call(*args, **kwargs)
+
+        monkeypatch.setattr(driver.topic_client, "_driver", mock_driver_call)
+
+        await driver.topic_client.commit_offset(
+            topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
+        )
+
+        assert call_count == 3
+
     async def test_reader_reconnect_after_commit_offset(self, driver, topic_with_messages, topic_consumer):
         async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
             for out in ["123", "456", "789", "0"]:
@@ -257,6 +283,33 @@ def test_commit_offset_with_session_id_works(self, driver_sync, topic_with_messa
             msg2 = reader.receive_message()
             assert msg2.seqno == 2
 
+    def test_commit_offset_retry_on_ydb_errors(self, driver_sync, topic_with_messages, topic_consumer, monkeypatch):
+        with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
+            message = reader.receive_message()
+
+        # Counts calls to the patched driver: two injected failures, then the real call
+        call_count = 0
+        original_driver_call = driver_sync.topic_client._driver
+
+        def mock_driver_call(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+
+            if call_count == 1:
+                raise ydb.Unavailable("Service temporarily unavailable")
+            elif call_count == 2:
+                raise ydb.Cancelled("Operation was cancelled")
+            else:
+                return original_driver_call(*args, **kwargs)
+
+        monkeypatch.setattr(driver_sync.topic_client, "_driver", mock_driver_call)
+
+        driver_sync.topic_client.commit_offset(
+            topic_with_messages, topic_consumer, message.partition_id, message.offset + 1
+        )
+
+        assert call_count == 3
+
     def test_reader_reconnect_after_commit_offset(self, driver_sync, topic_with_messages, topic_consumer):
         with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
             for out in ["123", "456", "789", "0"]:
diff --git a/ydb/retries.py b/ydb/retries.py
index 5331f1b0..75332768 100644
--- a/ydb/retries.py
+++ b/ydb/retries.py
@@ -1,4 +1,6 @@
 import asyncio
+import functools
+import inspect
 import random
 import time
 
@@ -161,3 +163,64 @@ async def retry_operation_async(callee, retry_settings=None, *args, **kwargs):
                 return await next_opt.result
             except BaseException as e:  # pylint: disable=W0703
                 next_opt.set_exception(e)
+
+
+def ydb_retry(
+    max_retries=10,
+    max_session_acquire_timeout=None,
+    on_ydb_error_callback=None,
+    backoff_ceiling=6,
+    backoff_slot_duration=1,
+    get_session_client_timeout=5,
+    fast_backoff_settings=None,
+    slow_backoff_settings=None,
+    idempotent=False,
+    retry_cancelled=False,
+):
+    """
+    Build a decorator that runs the wrapped function through the YDB retry helpers.
+
+    Coroutine functions are driven by retry_operation_async, all others by retry_operation_sync.
+
+    :param max_retries: Maximum number of retries (default: 10)
+    :param max_session_acquire_timeout: Maximum session acquisition timeout (default: None)
+    :param on_ydb_error_callback: Callback for handling YDB errors (default: None)
+    :param backoff_ceiling: Ceiling for backoff algorithm (default: 6)
+    :param backoff_slot_duration: Slot duration for backoff (default: 1)
+    :param get_session_client_timeout: Session client timeout (default: 5)
+    :param fast_backoff_settings: Fast backoff settings (default: None)
+    :param slow_backoff_settings: Slow backoff settings (default: None)
+    :param idempotent: Whether the operation is idempotent (default: False)
+    :param retry_cancelled: Whether to retry cancelled operations (default: False)
+    """
+
+    def decorator(func):
+        retry_settings = RetrySettings(
+            max_retries=max_retries,
+            max_session_acquire_timeout=max_session_acquire_timeout,
+            on_ydb_error_callback=on_ydb_error_callback,
+            backoff_ceiling=backoff_ceiling,
+            backoff_slot_duration=backoff_slot_duration,
+            get_session_client_timeout=get_session_client_timeout,
+            fast_backoff_settings=fast_backoff_settings,
+            slow_backoff_settings=slow_backoff_settings,
+            idempotent=idempotent,
+            retry_cancelled=retry_cancelled,
+        )
+
+        if inspect.iscoroutinefunction(func):
+
+            @functools.wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                return await retry_operation_async(func, retry_settings, *args, **kwargs)
+
+            return async_wrapper
+        else:
+
+            @functools.wraps(func)
+            def sync_wrapper(*args, **kwargs):
+                return retry_operation_sync(func, retry_settings, *args, **kwargs)
+
+            return sync_wrapper
+
+    return decorator
diff --git a/ydb/topic.py b/ydb/topic.py
index 5e86be68..f457b7dc 100644
--- a/ydb/topic.py
+++ b/ydb/topic.py
@@ -98,6 +98,8 @@
     PublicAlterAutoPartitioningSettings as TopicAlterAutoPartitioningSettings,
 )
 
+from .retries import ydb_retry
+
 logger = logging.getLogger(__name__)
 
 
@@ -348,6 +350,7 @@ def tx_writer(
 
         return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self)
 
+    @ydb_retry(retry_cancelled=True, idempotent=True)
     async def commit_offset(
         self, path: str, consumer: str, partition_id: int, offset: int, read_session_id: Optional[str] = None
     ) -> None:
@@ -645,6 +648,7 @@ def tx_writer(
 
         return TopicTxWriter(tx, self._driver, settings, _parent=self)
 
+    @ydb_retry(retry_cancelled=True, idempotent=True)
     def commit_offset(
         self, path: str, consumer: str, partition_id: int, offset: int, read_session_id: Optional[str] = None
     ) -> None: