Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce',
'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod',
'cached_property', 'Placeholder']
'cached_property', 'Placeholder', 'retry']

from abc import get_cache_token
from collections import namedtuple
Expand All @@ -21,6 +21,7 @@
from reprlib import recursive_repr
from types import GenericAlias, MethodType, MappingProxyType, UnionType
from _thread import RLock
from time import sleep

################################################################################
### update_wrapper() and wraps() decorator
Expand Down Expand Up @@ -1121,3 +1122,49 @@ def __get__(self, instance, owner=None):
return val

__class_getitem__ = classmethod(GenericAlias)


################################################################################
### retry() - simple retry decorator
################################################################################

def retry(_kwargs=None, *, retry_attempts=3, interval_seconds=.1, backoff_type='linear'):
"""
This function is intended to be used as a decorator and will retry
the function that it decorates if an excpetion is raised in that
function. Several aspects of the retries can be configured with
keyword arguments. Also, no keyword arguments can be used to retry
with the default values.

NOTE: if using backoff_type='exponential', ensure that
interval_seconds > 1, otherise the subsequent retries will be
shorter.
"""

def _retry(user_function):
@wraps(user_function)
def _retry_wrapper_user_function(*args, **kwargs):
for attempt_number in range(retry_attempts+1):
try:
return_value = user_function(*args, **kwargs)
break
except Exception as e:
if attempt_number < retry_attempts:
if backoff_type == 'exponential':
# If user inputs interval_seconds < 1 with exponential backoff, retries will get shorter
sleep(interval_seconds * (attempt_number + 1)**2)
elif backoff_type == 'linear':
sleep(interval_seconds)
else:
# Retry attempts reached, raise the last exception
raise e
return return_value
return _retry_wrapper_user_function

if backoff_type != 'linear' and backoff_type != 'exponential':
raise TypeError("Keyword argument backoff_type must be 'exponential' or 'linear'.")

if _kwargs is None:
return _retry
else:
return _retry(_kwargs)
64 changes: 64 additions & 0 deletions Lib/test/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3381,5 +3381,69 @@ def prop(self):
self.assertEqual(t.prop, 1)


class TestRetry(unittest.TestCase):

def test_function_fail(self):

@functools.retry
def fail_function1():
raise ValueError

with self.assertRaises(ValueError):
fail_function1()

@functools.retry(interval_seconds=.1, retry_attempts=1, backoff_type='exponential')
def fail_function2():
raise ValueError

with self.assertRaises(ValueError):
fail_function2()

def test_function_success(self):

@functools.retry
def success_function(a, b, c='test_value'):
return (a, b, c)

value = success_function('a', 'b')

self.assertEqual(value, ('a', 'b', 'test_value'))

class TestObject:
def __init__(self, call_count):
self.call_count = call_count
test_object = TestObject(0)

@functools.retry(interval_seconds=.01, retry_attempts=3, backoff_type='exponential')
def success_function_after_3_failures(test_object):
test_object.call_count += 1
if test_object.call_count > 3:
return True
raise ValueError('Some error message!')

value = success_function_after_3_failures(test_object)
self.assertTrue(value)

def test_backoff_type(self):
@functools.retry(backoff_type='exponential')
def user_function1():
return True
value1 = user_function1()
self.assertTrue(value1)

@functools.retry(backoff_type='linear')
def user_function2():
return True
value2 = user_function2()
self.assertTrue(value2)

with self.assertRaises(TypeError):
@functools.retry(backoff_type='incorrect_value')
def user_function3():
return True

user_function3()


if __name__ == '__main__':
unittest.main()
Loading