diff --git a/snippets/deprecate.py b/snippets/deprecate.py new file mode 100644 index 0000000..3aecc60 --- /dev/null +++ b/snippets/deprecate.py @@ -0,0 +1,181 @@ +""" +A utility class for deprecating code. +""" + +from copy import copy +import functools +import types +import warnings + + +class Deprecator: + """ + Decorator class to mark functions and methods as deprecated with a uniform + warning message at the time the function is called. The message has the + form + + {function_name} is deprecated: {message}. It is not guaranteed to be in + service after {version}. + + unless `pending=True` was given. Then the message will be + + {function_name} will be deprecated in a future version: {message}. + + If message and version are not initialized or given during the decorating + call the respective parts are left out from the message. + + >>> deprecate = Deprecator() + >>> @deprecate + ... def foo(a, b): + ... pass + >>> foo(1, 2) + DeprecationWarning: __main__.foo is deprecated + + >>> @deprecate("use bar() instead") + ... def foo(a, b): + ... pass + >>> foo(1, 2) + DeprecationWarning: __main__.foo is deprecated: use bar instead + + >>> @deprecate("use bar() instead", version="0.4.0") + ... def foo(a, b): + ... pass + >>> foo(1, 2) + DeprecationWarning: __main__.foo is deprecated: use bar instead. It is not + guaranteed to be in service in vers. 0.4.0 + + >>> deprecate = Deprecator(message="I say no!", version="0.5.0") + >>> @deprecate + ... def foo(a, b): + ... pass + >>> foo(1, 2) + DeprecationWarning: __main__.foo is deprecated: I say no! It is not + guaranteed to be in service in vers. 0.5.0 + + Alternatively the decorator can also be called with `arguments` set to a dictionary + mapping names of keyword arguments to deprecation messages. In this case the + warning will only be emitted when the decorated function is called with arguments + in that dictionary. + + >>> deprecate = Deprecator() + >>> @deprecate(arguments={"bar": "use baz instead."}) + ... def foo(bar=None, baz=None): + ... pass + >>> foo(baz=True) + >>> foo(bar=True) + DeprecationWarning: __main__.foo(bar=True) is deprecated: use baz instead. + + As a short-cut, it is also possible to pass the values in the arguments dict + directly as keyword arguments to the decorator. + + >>> @deprecate(bar="use baz instead.") + ... def foo(bar=None, baz=None): + ... pass + >>> foo(baz=True) + >>> foo(bar=True) + DeprecationWarning: __main__.foo(bar=True) is deprecated: use baz instead. + """ + + def __init__(self, message=None, version=None, pending=False): + """ + Initialize default values for deprecation message and version. + + Args: + message (str): default deprecation message + version (str): default version after which the function might be removed + pending (bool): only warn about future deprecation, warning category will + be PendingDeprecationWarning instead of DeprecationWarning + """ + self.message = message + self.version = version + self.category = PendingDeprecationWarning if pending else DeprecationWarning + + def __copy__(self): + cp = type(self)(message=self.message, version=self.version) + cp.category = self.category + return cp + + def __call__(self, message=None, version=None, arguments=None, **kwargs): + depr = copy(self) + if isinstance(message, types.FunctionType): + return depr.__deprecate_function(message) + else: + depr.message = message + depr.version = version + depr.arguments = arguments if arguments is not None else {} + depr.arguments.update(kwargs) + return depr.wrap + + def _build_message(self): + if self.category == PendingDeprecationWarning: + message_format = "{} will be deprecated" + else: + message_format = "{} is deprecated" + + if self.message is not None: + message_format += ": {}.".format(self.message) + else: + message_format += "." + + if self.version is not None: + message_format += ( + " It is not guaranteed to be in service in vers. {}".format( + self.version + ) + ) + + return message_format + + def __deprecate_function(self, function): + message = self._build_message().format( + "{}.{}".format(function.__module__, function.__name__) + ) + + @functools.wraps(function) + def decorated(*args, **kwargs): + warnings.warn(message, category=self.category, stacklevel=2) + return function(*args, **kwargs) + + return decorated + + def __deprecate_argument(self, function): + message_format = self._build_message() + + @functools.wraps(function) + def decorated(*args, **kwargs): + for kw in kwargs: + if kw in self.arguments: + warnings.warn( + message_format.format( + "{}.{}({}={})".format( + function.__module__, function.__name__, kw, kwargs[kw] + ) + ), + category=self.category, + stacklevel=2, + ) + return function(*args, **kwargs) + + return decorated + + def wrap(self, function): + """ + Wrap the given function to emit a DeprecationWarning at call time. The warning + message is constructed from the given message and version. If + :attr:`.arguments` is set then the warning is only emitted, when the decorated + function is called with keyword arguments found in that dictionary. + + Args: + function (callable): function to mark as deprecated + + Return: + function: raises DeprecationWarning when given function is called + """ + if not self.arguments: + return self.__deprecate_function(function) + else: + return self.__deprecate_argument(function) + + +deprecate = Deprecator() +deprecate_soon = Deprecator(pending=True) diff --git a/snippets/import_alarm.py b/snippets/import_alarm.py new file mode 100644 index 0000000..826902e --- /dev/null +++ b/snippets/import_alarm.py @@ -0,0 +1,97 @@ +""" +Graceful failure for missing optional dependencies. +""" + +import functools +import warnings + + +class ImportAlarmError(ImportError): + """To be raised in addition to warning under test conditions""" + + +class ImportAlarm: + """ + This class allows you to fail gracefully when some object has optional dependencies + and the user does not have those dependencies installed. + + Example: + + >>> try: + ... from mystery_package import Enigma, Puzzle, Conundrum + ... import_alarm = ImportAlarm() + >>> except ImportError: + >>> import_alarm = ImportAlarm( + ... "MysteryJob relies on mystery_package, but this was unavailable. Please ensure your python environment " + ... "has access to mystery_package, e.g. with `conda install -c conda-forge mystery_package`" + ... ) + ... + >>> class MysteryJob: + ... @import_alarm + ... def __init__(self, project, job_name) + ... super().__init__() + ... self.riddles = [Enigma(), Puzzle(), Conundrum()] + + This class is also a context manager that can be used as a short-cut, like this: + + >>> with ImportAlarm( + ... "MysteryJob relies on mystery_package, but this was unavailable." + ... ) as import_alarm: + ... import mystery_package + + If you do not use `import_alarm` as a decorator, but only to get a consistent + warning message, call :meth:`.warn_if_failed()` after the with statement. + + >>> import_alarm.warn_if_failed() + """ + + def __init__(self, message=None, _fail_on_warning: bool = False): + """ + Initialize message value. + + Args: + message (str): What to say alongside your ImportError when the decorated + function is called. (Default is None, which says nothing and raises no + error.) + """ + self.message = message + # Catching warnings in tests can be janky, so instead open a flag for failing + # instead. + self._fail_on_warning = _fail_on_warning + + def __call__(self, func): + return self.wrapper(func) + + def wrapper(self, function): + @functools.wraps(function) + def decorator(*args, **kwargs): + self.warn_if_failed() + return function(*args, **kwargs) + + return decorator + + def warn_if_failed(self): + """ + Print warning message if import has failed. In case you are not using + :class:`ImportAlarm` as a decorator you can call this method manually to + trigger the warning. + """ + if self.message is not None: + warnings.warn(self.message, category=ImportWarning) + if self._fail_on_warning: + raise ImportAlarmError(self.message) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None and exc_value is None and traceback is None: + # import successful, so silence our warning + self.message = None + return + if issubclass(exc_type, ImportError): + # import broken; retain message, but suppress error + return True + else: + # unrelated error during import, re-raise + return False diff --git a/snippets/retry.py b/snippets/retry.py new file mode 100644 index 0000000..92a4af2 --- /dev/null +++ b/snippets/retry.py @@ -0,0 +1,67 @@ +""" + +""" + +from itertools import count +import time +from typing import Callable, Optional, Type, TypeVar, Tuple, Union +import warnings + +T = TypeVar("T") + + +def retry( + func: Callable[[], T], + error: Union[Type[Exception], Tuple[Type[Exception], ...]], + msg: str, + at_most: Optional[int] = None, + delay: float = 1.0, + delay_factor: float = 1.0, + log: bool | object = True, +) -> T: + """ + Try to call `func` until it no longer raises `error`. + + Any other exception besides `error` is still raised. + + Args: + func (callable): function to call, should take no arguments + error (Exception or tuple thereof): any exceptions to be caught + msg (str): messing to be written to the log if `error` occurs. + at_most (int, optional): retry at most this many times, None means retry + forever + delay (float): time to wait between retries in seconds + delay_factor (float): multiply `delay` between retries by this factor + logger (bool|object): Whether to pass a message to `warnings.warn` on each + retry. (Default is True.) Optionally, an object with a :meth:`warn` method + can be passed and the message will be sent there instead + (e.g. `snippets.logger.logger`). + + Raises: + `error`: if `at_most` is exceeded the last error is re-raised + Exception: any exception raised by `func` that does not match `error` + + Returns: + object: whatever is returned by `func` + """ + if at_most is None: + tries = count() + else: + tries = range(at_most) + for i in tries: + try: + return func() + except error as e: + warning = f"{msg} Trying again in {delay}s. Tried {i + 1} times so far..." + if isinstance(log, bool): + if log: + warnings.warn(warning) + else: + log.warn(warning) + time.sleep(delay) + delay *= delay_factor + # e drops out of the namespace after the except clause ends, so + # assign it here to a dummy variable so that we can re-raise it + # in case the error persists + err = e + raise err from None diff --git a/tests/unit/test_deprecate.py b/tests/unit/test_deprecate.py new file mode 100644 index 0000000..69b0d02 --- /dev/null +++ b/tests/unit/test_deprecate.py @@ -0,0 +1,130 @@ +import unittest +import warnings + +from snippets.deprecate import deprecate, deprecate_soon + + +class TestDeprecator(unittest.TestCase): + def test_deprecate(self): + """Function decorated with `deprecate` should raise a warning.""" + + @deprecate + def foo(a): + return 2 * a + + @deprecate("use baz instead", version="0.2.0") + def bar(a): + return 4 * a + + with warnings.catch_warnings(record=True) as w: + self.assertEqual( + foo(1), 2, "Decorated function does not return original " "return value" + ) + self.assertTrue(len(w) > 0, "No warning raised!") + self.assertEqual( + w[0].category, + DeprecationWarning, + "Raised warning is not a DeprecationWarning", + ) + + with warnings.catch_warnings(record=True) as w: + self.assertEqual( + bar(1), 4, "Decorated function does not return original " "return value" + ) + + expected_message = ( + "use baz instead. It is not guaranteed to be in " "service in vers. 0.2.0" + ) + self.assertTrue( + w[0].message.args[0].endswith(expected_message), + "Warning message does not reflect decorator arguments.", + ) + + @deprecate_soon + def baz(a): + return 3 * a + + with warnings.catch_warnings(record=True) as w: + self.assertEqual( + baz(1), 3, "Decorated function does not return original " "return value" + ) + self.assertEqual( + w[0].category, + PendingDeprecationWarning, + "Raised warning is not a PendingDeprecationWarning", + ) + + def test_deprecate_args(self): + """DeprecationWarning should only be raised when the given arguments occur.""" + + @deprecate(arguments={"bar": "use foo instead"}) + def foo(a, foo=None, bar=None): + return 2 * a + + with warnings.catch_warnings(record=True) as w: + self.assertEqual( + foo(1, bar=True), + 2, + "Decorated function does not return original " "return value", + ) + self.assertTrue(len(w) > 0, "No warning raised!") + + with warnings.catch_warnings(record=True) as w: + self.assertEqual( + foo(1, foo=True), + 2, + "Decorated function does not return original " "return value", + ) + self.assertEqual( + len(w), 0, "Warning raised, but deprecated argument was not given." + ) + + def test_deprecate_kwargs(self): + """ + DeprecationWarning should only be raised when the given arguments occur, also + when given via kwargs. + """ + + @deprecate(bar="use baz instead") + def foo(a, bar=None, baz=None): + return 2 * a + + with warnings.catch_warnings(record=True) as w: + self.assertEqual( + foo(1, bar=True), + 2, + "Decorated function does not return original " "return value", + ) + self.assertTrue(len(w) > 0, "No warning raised!") + + with warnings.catch_warnings(record=True) as w: + self.assertEqual( + foo(1, baz=True), + 2, + "Decorated function does not return original " "return value", + ) + self.assertEqual( + len(w), 0, "Warning raised, but deprecated argument was not given." + ) + + def test_instances(self): + """ + Subsequent calls to a Deprecator instance must not interfere with each other. + """ + + @deprecate(bar="use baz instead") + def foo(bar=None, baz=None): + pass + + @deprecate(baz="use bar instead") + def food(bar=None, baz=None): + pass + + with warnings.catch_warnings(record=True) as w: + foo(bar=True) + food(baz=True) + self.assertEqual(len(w), 2, "Not all warnings preserved.") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_import_alarm.py b/tests/unit/test_import_alarm.py new file mode 100644 index 0000000..57d1007 --- /dev/null +++ b/tests/unit/test_import_alarm.py @@ -0,0 +1,82 @@ +import unittest +from snippets.import_alarm import ImportAlarm, ImportAlarmError + + +class TestImportAlarm(unittest.TestCase): + + def test_instance(self): + no_alarm = ImportAlarm(_fail_on_warning=True) + + @no_alarm + def add_one(x): + return x + 1 + + yes_alarm = ImportAlarm( + "Here is a message", + _fail_on_warning=True + ) + + @yes_alarm + def subtract_one(x): + return x + 1 + + try: + self.assertEqual( + 1, + add_one(0), + msg="Wrapped function should return the same return value." + ) + except ImportAlarmError: + self.fail("Without a message, the import alarm should not raise a warning (an " + "exception in this case, because of the private flag)") + with self.assertRaises( + ImportAlarmError, + msg="With a message, the import alarm should raise a warning. (an " + "exception in this case, because of the private flag)" + ): + subtract_one(0) + + def test_context(self): + with ImportAlarm( + "Working import", + _fail_on_warning=True + ) as alarm_working: + # Suppose all the imports here pass fine + pass + + with ImportAlarm( + "Broken import", + _fail_on_warning=True + ) as alarm_broken: + raise ImportError("Suppose a package imported here is not available") + + @alarm_working + def add_two(x): + return x + 2 + + @alarm_broken + def add_three(x): + return x + 3 + + self.assertEqual( + 2, + add_two(0), + msg="Without a message, no warning (exception here) should be raised" + ) + + with self.assertRaises( + ImportAlarmError, + msg="With a message, a warning (exception here) should be raised" + ): + add_three(0) + + def test_scope(self): + with self.assertRaises( + ZeroDivisionError, + msg="Context manager should not silence unrelated exceptions" + ), ImportAlarm("Unrelated"): + print(1 / 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py new file mode 100644 index 0000000..6c68569 --- /dev/null +++ b/tests/unit/test_retry.py @@ -0,0 +1,60 @@ +import unittest + +from snippets.retry import retry + + +class TestRetry(unittest.TestCase): + def test_return_value(self): + """retry should return the exact value that the function returns.""" + + def func(): + return 42 + + self.assertEqual( + func(), + retry(func, error=ValueError, msg=""), + "retry returned a different value!", + ) + + def test_unrelated_exception(self): + """retry should not catch exception that are not explicitely passed.""" + + def func(): + raise ValueError() + + with self.assertRaises( + ValueError, msg="retry caught an exception it was not supposed to!" + ): + retry(func, error=TypeError, msg="") + + def test_exception(self): + """retry should catch explicitly passed exceptions.""" + + class Func: + """Small helper to simulate a stateful function.""" + + def __init__(self): + self.n = 0 + + def __call__(self): + self.n += 1 + if self.n < 4: + raise ValueError(self.n) + else: + return self.n + + func = Func() + try: + retry(func, error=ValueError, msg="", delay=1e-6) + except ValueError: + self.fail("retry did not catch exception!") + + func = Func() + with self.assertRaises( + ValueError, msg="retry did re-raise exception after insufficient tries!" + ): + retry(func, error=ValueError, msg="", at_most=2, delay=1e-6) + + +if __name__ == "__main__": + unittest.main()