From 0832e278ec5f91b54eb9d01d2dc6eff1025f010c Mon Sep 17 00:00:00 2001 From: Jude Date: Sun, 17 Nov 2024 09:04:49 +0000 Subject: [PATCH] Add functools.retry --- Lib/functools.py | 49 ++++++++++++++++++++++++++++- Lib/test/test_functools.py | 64 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/Lib/functools.py b/Lib/functools.py index eff6540c7f606e..030ad601a9ec82 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -12,7 +12,7 @@ __all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES', 'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce', 'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod', - 'cached_property', 'Placeholder'] + 'cached_property', 'Placeholder', 'retry'] from abc import get_cache_token from collections import namedtuple @@ -21,6 +21,7 @@ from reprlib import recursive_repr from types import GenericAlias, MethodType, MappingProxyType, UnionType from _thread import RLock +from time import sleep ################################################################################ ### update_wrapper() and wraps() decorator @@ -1121,3 +1122,49 @@ def __get__(self, instance, owner=None): return val __class_getitem__ = classmethod(GenericAlias) + + +################################################################################ +### retry() - simple retry decorator +################################################################################ + +def retry(_kwargs=None, *, retry_attempts=3, interval_seconds=.1, backoff_type='linear'): + """ + This function is intended to be used as a decorator and will retry + the function that it decorates if an excpetion is raised in that + function. Several aspects of the retries can be configured with + keyword arguments. Also, no keyword arguments can be used to retry + with the default values. + + NOTE: if using backoff_type='exponential', ensure that + interval_seconds > 1, otherise the subsequent retries will be + shorter. + """ + + def _retry(user_function): + @wraps(user_function) + def _retry_wrapper_user_function(*args, **kwargs): + for attempt_number in range(retry_attempts+1): + try: + return_value = user_function(*args, **kwargs) + break + except Exception as e: + if attempt_number < retry_attempts: + if backoff_type == 'exponential': + # If user inputs interval_seconds < 1 with exponential backoff, retries will get shorter + sleep(interval_seconds * (attempt_number + 1)**2) + elif backoff_type == 'linear': + sleep(interval_seconds) + else: + # Retry attempts reached, raise the last exception + raise e + return return_value + return _retry_wrapper_user_function + + if backoff_type != 'linear' and backoff_type != 'exponential': + raise TypeError("Keyword argument backoff_type must be 'exponential' or 'linear'.") + + if _kwargs is None: + return _retry + else: + return _retry(_kwargs) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 6d60f6941c4c5d..0059073ee8ac51 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -3381,5 +3381,69 @@ def prop(self): self.assertEqual(t.prop, 1) +class TestRetry(unittest.TestCase): + + def test_function_fail(self): + + @functools.retry + def fail_function1(): + raise ValueError + + with self.assertRaises(ValueError): + fail_function1() + + @functools.retry(interval_seconds=.1, retry_attempts=1, backoff_type='exponential') + def fail_function2(): + raise ValueError + + with self.assertRaises(ValueError): + fail_function2() + + def test_function_success(self): + + @functools.retry + def success_function(a, b, c='test_value'): + return (a, b, c) + + value = success_function('a', 'b') + + self.assertEqual(value, ('a', 'b', 'test_value')) + + class TestObject: + def __init__(self, call_count): + self.call_count = call_count + test_object = TestObject(0) + + @functools.retry(interval_seconds=.01, retry_attempts=3, backoff_type='exponential') + def success_function_after_3_failures(test_object): + test_object.call_count += 1 + if test_object.call_count > 3: + return True + raise ValueError('Some error message!') + + value = success_function_after_3_failures(test_object) + self.assertTrue(value) + + def test_backoff_type(self): + @functools.retry(backoff_type='exponential') + def user_function1(): + return True + value1 = user_function1() + self.assertTrue(value1) + + @functools.retry(backoff_type='linear') + def user_function2(): + return True + value2 = user_function2() + self.assertTrue(value2) + + with self.assertRaises(TypeError): + @functools.retry(backoff_type='incorrect_value') + def user_function3(): + return True + + user_function3() + + if __name__ == '__main__': unittest.main()