diff --git a/README.rst b/README.rst index e4efda9..6621758 100644 --- a/README.rst +++ b/README.rst @@ -16,6 +16,190 @@ argument, create your function accordingly. >>> from methoddispatch import singledispatch, register, SingleDispatch + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatch + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + + >>> @singledispatchmethod + ... def fun(arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) + + To add overloaded implementations to the function, use the + >>> @singledispatch ... def fun(arg, verbose=False): ... if verbose: @@ -191,11 +375,11 @@ shown below:: ... def foo_float(self, bar: float): ... return 'float' -In Python 3.6 and earlier, the ``SingleDispatch`` class uses a +In Python 3.5 and earlier, the ``SingleDispatch`` class uses a meta-class ``SingleDispatchMeta`` to manage the dispatch registries. However in Python 3.6 and later the ``__init_subclass__`` method is used instead. If your class also inherits from an ABC interface you can use -the ``SingleDispatchABCMeta`` metaclass in Python 3.6 and earlier. +the ``SingleDispatchABCMeta`` metaclass in Python 3.5 and earlier. Finally, accessing the method ``foo`` via a class will use the dispatch registry for that class:: diff --git a/methoddispatch/__init__.py b/methoddispatch/__init__.py index 58f1571..1b93a84 100644 --- a/methoddispatch/__init__.py +++ b/methoddispatch/__init__.py @@ -5,94 +5,90 @@ [![Build Status](https://travis-ci.com/seequent/methoddispatch.svg?branch=master)](https://travis-ci.com/seequent/methoddispatch) Python 3.4 added the ``singledispatch`` decorator to the ``functools`` standard library module. -This library extends this functionality to instance methods (and works for functions too). +This library adds this functionality to instance methods. To define a generic method , decorate it with the ``@singledispatch`` decorator. Note that the dispatch happens on the type of the first argument, create your function accordingly. - - >>> from methoddispatch import singledispatch, register, SingleDispatch - - >>> @singledispatch - ... def fun(arg, verbose=False): - ... if verbose: - ... print("Let me just say,", end=" ") - ... print(arg) - To add overloaded implementations to the function, use the ``register()`` attribute of the generic function. It is a decorator, taking a type parameter and decorating a function implementing the operation for that type +The ``register()`` attribute returns the undecorated function which enables decorator stacking, pickling, as well as creating unit tests for each variant independently - >>> @fun.register(int) - ... def _(arg, verbose=False): - ... if verbose: - ... print("Strength in numbers, eh?", end=" ") - ... print(arg) + >>> from methoddispatch import singledispatch, register, SingleDispatch + >>> from decimal import Decimal + >>> class MyClass(SingleDispatch): + ... @singledispatch + ... def fun(self, arg, verbose=False): + ... if verbose: + ... print("Let me just say,", end=" ") + ... print(arg) ... - >>> @fun.register(list) - ... def _(arg, verbose=False): - ... if verbose: - ... print("Enumerate this:") - ... for i, elem in enumerate(arg): - ... print(i, elem) - -To enable registering lambdas and pre-existing functions, the ``register()`` attribute can be used in a functional form:: - - >>> def nothing(arg, verbose=False): - ... print("Nothing.") + ... @fun.register(int) + ... def fun_int(self, arg, verbose=False): + ... if verbose: + ... print("Strength in numbers, eh?", end=" ") + ... print(arg) ... - >>> fun.register(type(None), nothing) - + ... @fun.register(list) + ... def fun_list(self, arg, verbose=False): + ... if verbose: + ... print("Enumerate this:") + ... for i, elem in enumerate(arg): + ... print(i, elem) + ... + ... @fun.register(float) + ... @fun.register(Decimal) + ... def fun_num(obj, arg, verbose=False): + ... if verbose: + ... print("Half of your number:", end=" ") + ... print(arg / 2) -The ``register()`` attribute returns the undecorated function which enables decorator stacking, pickling, as well as creating unit tests for each variant independently +The ``register()`` attribute only works inside a class statement, relying on ``SingleDispatch.__init_subclass__`` +to create the actual dispatch table. This also means that (unlike functools.singledispatch) two methods +with the same name cannot be registered as only the last one will be in the class dictionary. - >>> from decimal import Decimal - >>> @fun.register(float) - ... @fun.register(Decimal) - ... def fun_num(arg, verbose=False): - ... if verbose: - ... print("Half of your number:", end=" ") - ... print(arg / 2) - ... - >>> fun_num is fun - False +Functions not defined in the class can be registered using the ``add_overload`` attribute. + + >>> def nothing(obj, arg, verbose=False): + ... print('Nothing.') + >>> MyClass.fun.add_overload(type(None), nothing) When called, the generic function dispatches on the type of the first argument:: - >>> fun("Hello, world.") + >>> a = MyClass() + >>> a.fun("Hello, world.") Hello, world. - >>> fun("test.", verbose=True) + >>> a.fun("test.", verbose=True) Let me just say, test. - >>> fun(42, verbose=True) + >>> a.fun(42, verbose=True) Strength in numbers, eh? 42 - >>> fun(['spam', 'spam', 'eggs', 'spam'], verbose=True) + >>> a.fun(['spam', 'spam', 'eggs', 'spam'], verbose=True) Enumerate this: 0 spam 1 spam 2 eggs 3 spam - >>> fun(None) + >>> a.fun(None) Nothing. - >>> fun(1.23) + >>> a.fun(1.23) 0.615 Where there is no registered implementation for a specific type, its method resolution order is used to find a more generic implementation. The original function decorated with ``@singledispatch`` is registered for the base ``object`` type, which means it is used if no better implementation is found. To check which implementation will the generic function choose for a given type, use the ``dispatch()`` attribute:: - >>> fun.dispatch(float) - - >>> fun.dispatch(dict) # note: default implementation - + >>> a.fun.dispatch(float) + + >>> a.fun.dispatch(dict) # note: default implementation + To access all registered implementations, use the read-only ``registry`` attribute:: - >>> fun.registry.keys() + >>> a.fun.registry.keys() dict_keys([, , , , , ]) - >>> fun.registry[float] - - >>> fun.registry[object] - - -Decorating class methods requires the class to inherit from ``SingleDispatch`` + >>> a.fun.registry[float] + + >>> a.fun.registry[object] + >>> class BaseClass(SingleDispatch): ... @singledispatch @@ -110,14 +106,15 @@ 'int' Subclasses can extend the type registry of the function on the base class with their own overrides. -Because we do not want to modify the base class ``foo`` registry the ``methoddispatch.register`` decorator must be used instead of ``foo.register``. The module level ``register`` function takes either the method name or the method itself as the first parameter and the dispatch type as the second. +Because we do not want to modify the base class ``foo`` registry the ``foo.overload`` decorator must be used +instead of ``foo.register``. >>> class SubClass(BaseClass): - ... @register('foo', float) + ... @BaseClass.foo.register(float) ... def foo_float(self, bar): ... return 'float' ... - ... @register(BaseClass.foo, str) + ... @BaseClass.foo.register(str) ... def foo_str(self, bar): ... return 'str' ... @@ -146,6 +143,19 @@ However, providing the register decorator with the same type will also work. Decorating a method override with a different type (not a good idea) will register the different type and leave the base-class handler in place for the orginal type. +Method overrides can be specified on individual instances if necessary + + >>> def foo_set(obj, bar): + ... return 'set' + >>> b = BaseClass() + >>> b.foo.register(set, foo_set) + + >>> b.foo(set()) + 'set' + >>> b2 = BaseClass() + >>> b2.foo(set()) + 'default' + In Python 3.6 and later, for functions annotated with types, the decorator will infer the type of the first argument automatically as shown below >>> class BaseClassAnno(SingleDispatch): @@ -158,7 +168,7 @@ ... return 'int' ... >>> class SubClassAnno(BaseClassAnno): - ... @register('foo') + ... @BaseClassAnno.foo.register ... def foo_float(self, bar: float): ... return 'float' @@ -185,7 +195,7 @@ from .methoddispatch3 import * del methoddispatch3 -__version__ = '2.0.1' +__version__ = '3.0.0' __author__ = 'Seequent' __license__ = 'BSD' -__copyright__ = 'Copyright 2018 Seequent' +__copyright__ = 'Copyright 2019 Seequent' diff --git a/methoddispatch/methoddispatch2.py b/methoddispatch/methoddispatch2.py index d136ad4..de0ab73 100644 --- a/methoddispatch/methoddispatch2.py +++ b/methoddispatch/methoddispatch2.py @@ -5,6 +5,7 @@ from functools import update_wrapper from UserDict import UserDict from weakref import WeakKeyDictionary +import warnings class MappingProxyType(UserDict): @@ -21,11 +22,6 @@ def get_cache_token(): 'SingleDispatch', 'SingleDispatchABC'] -################################################################################ -### singledispatch() - single-dispatch generic function decorator -################################################################################ - - def _c3_merge(sequences): """Merges MROs in *sequences* to a single MRO using the C3 algorithm. @@ -210,23 +206,29 @@ def dispatch(self, cls): self.dispatch_cache[cls] = impl return impl - def register(self, cls, func=None): - """register(cls, func) -> func + def add_overload(self, cls, func=None): + """add_overload(cls, func) -> func (private) Registers a new implementation for the given *cls* on a *generic_func*. + """ if func is None: - return lambda f: self.register(cls, f) + return lambda f: self.add_overload(cls, f) self._registry[cls] = func if self.cache_token is None and hasattr(cls, '__abstractmethods__'): self.cache_token = get_cache_token() self.dispatch_cache.clear() return func - def __get__(self, instance, cls=None): - if cls is not None and not isinstance(cls, SingleDispatchMeta): - raise ValueError('singledispatch can only be used on methods of SingleDispatchMeta types') - wrapper = sd_method(self, instance) + def __get__(self, instance, owner=None): + if owner is not None and not isinstance(owner, SingleDispatchMeta): + raise ValueError('singledispatch can only be used on methods of SingleDispatch subclasses') + if instance is not None: + wrapper = BoundSDMethod(self, instance) + elif owner is not None: + wrapper = UnboundSDMethod(self) + else: + return self update_wrapper(wrapper, self.func) return wrapper @@ -244,12 +246,69 @@ def copy(self): def get_registered_types(self): return [type_ for type_ in self._registry.keys() if type_ is not object] + def register(self, cls, func=None): + """ Decorator for methods to register an overload on generic method. + :param cls: is the type to register or may be omitted or None to use the annotated parameter type. + """ + if func is None: + return lambda f: self.register(cls, f) + + overloads = getattr(func, '_overloads', []) + overloads.append((self.__name__, cls)) + func._overloads = overloads + return func + + +class BoundSDMethod(object): + """ A bounf singledispatch method """ -class sd_method(object): - """ A singledispatch method """ def __init__(self, s_d, instance): self._instance = instance self._s_d = s_d + self._instance_sd = instance.__dict__.get('__' + s_d.__name__, None) + + def copy(self): + if self._instance_sd is not None: + return self._instance_sd.copy() + else: + return self._s_d.copy() + + def dispatch(self, cls): + if self._instance_sd is not None: + return self._instance_sd.dispatch(cls) + else: + return self._s_d.dispatch(cls) + + @property + def registry(self): + if self._instance_sd is not None: + return self._instance_sd.registry + else: + return self._s_d.registry + + def __call__(self, *args, **kwargs): + return self.dispatch(args[0].__class__)(self._instance, *args, **kwargs) + + def register(self, cls, func=None): + if self._instance_sd is None: + self._instance.__dict__['__' + self.__name__] = self._instance_sd = self._s_d.copy() + return self._instance_sd.add_overload(cls, func) + + def get_registered_types(self): + if self._instance_sd is not None: + return self._instance_sd.get_registered_types() + else: + return self._s_d.get_registered_types() + + +class UnboundSDMethod: + """ An unbound singledispatch method """ + + def __init__(self, s_d): + self._s_d = s_d + + def copy(self): + return self._s_d.copy() def dispatch(self, cls): return self._s_d.dispatch(cls) @@ -259,10 +318,16 @@ def registry(self): return self._s_d.registry def __call__(self, *args, **kwargs): - if self._instance is None: - return self.dispatch(args[1].__class__)(*args, **kwargs) - else: - return self.dispatch(args[0].__class__)(self._instance, *args, **kwargs) + return self.dispatch(args[1].__class__)(*args, **kwargs) + + def register(self, cls, func=None): + return self._s_d.register(cls, func) + + def get_registered_types(self): + return self._s_d.get_registered_types() + + def add_overload(self, cls, func=None): + self._s_d.add_overload(cls, func) def _fixup_class_attributes(cls): @@ -275,7 +340,7 @@ def _fixup_class_attributes(cls): if isinstance(value, singledispatch) and name not in patched: if name in attributes: raise RuntimeError('Cannot override generic function. ' - 'Try @register("{}", object) instead.'.format(name)) + 'Try @{name}.register(object) instead.'.format(name=name)) generic = value.copy() setattr(cls, name, generic) patched.add(name) @@ -283,15 +348,15 @@ def _fixup_class_attributes(cls): for name, value in attributes.items(): if not callable(value) or isinstance(value, singledispatch): continue - if hasattr(value, 'overloads'): - for generic_name, cls in value.overloads: + if hasattr(value, '_overloads'): + for generic_name, cls in value._overloads: generic = attributes[generic_name] - generic.register(cls, value) + generic.add_overload(cls, value) else: # register over-ridden methods for generic in generics: for cls, f in generic.registry.items(): if name == f.__name__: - generic.register(cls, value) + generic.add_overload(cls, value) break @@ -319,14 +384,15 @@ class SingleDispatchABC(object): def register(name, cls): - """ Decorator for methods on a sub-class to register an overload on a base-class generic method + """ Decorator for methods on a sub-class to register an overload on a base-class generic method. :param name: is the name of the generic method on the base class, or the unbound method itself :param cls: is the type to register """ + warnings.warn('Use @BaseClass.method.register() instead of register', DeprecationWarning, stacklevel=2) name = getattr(name, '__name__', name) # __name__ exists on sd_method courtesy of update_wrapper def wrapper(func): - overloads = getattr(func, 'overloads', []) + overloads = getattr(func, '_overloads', []) overloads.append((name, cls)) - func.overloads = overloads + func._overloads = overloads return func return wrapper diff --git a/methoddispatch/methoddispatch3.py b/methoddispatch/methoddispatch3.py index bb82d51..ea9ed56 100644 --- a/methoddispatch/methoddispatch3.py +++ b/methoddispatch/methoddispatch3.py @@ -4,6 +4,7 @@ from functools import update_wrapper, _find_impl from types import MappingProxyType from weakref import WeakKeyDictionary +import warnings __all__ = ['singledispatch', 'register', 'SingleDispatchMeta', 'SingleDispatchABCMeta', 'SingleDispatch', 'SingleDispatchABC'] @@ -53,23 +54,29 @@ def dispatch(self, cls): self.dispatch_cache[cls] = impl return impl - def register(self, cls, func=None): - """register(cls, func) -> func + def add_overload(self, cls, func=None): + """add_overload(cls, func) -> func (private) Registers a new implementation for the given *cls* on a *generic_func*. + """ if func is None: - return lambda f: self.register(cls, f) + return lambda f: self.add_overload(cls, f) self._registry[cls] = func if self.cache_token is None and hasattr(cls, '__abstractmethods__'): self.cache_token = get_cache_token() self.dispatch_cache.clear() return func - def __get__(self, instance, cls=None): - if cls is not None and not isinstance(cls, SingleDispatchMeta): - raise ValueError('singledispatch can only be used on methods of SingleDispatchMeta types') - wrapper = sd_method(self, instance) + def __get__(self, instance, owner=None): + if owner is not None and not isinstance(owner, SingleDispatchMeta): + raise ValueError('singledispatch can only be used on methods of SingleDispatch subclasses') + if instance is not None: + wrapper = BoundSDMethod(self, instance) + elif owner is not None: + wrapper = UnboundSDMethod(self) + else: + return self update_wrapper(wrapper, self.func) return wrapper @@ -87,12 +94,69 @@ def copy(self): def get_registered_types(self): return [type_ for type_ in self._registry.keys() if type_ is not object] + def register(self, cls, func=None): + """ Decorator for methods to register an overload on generic method. + :param cls: is the type to register or may be omitted or None to use the annotated parameter type. + """ + if func is None: + return lambda f: self.register(cls, f) + + overloads = getattr(func, '_overloads', []) + overloads.append((self.__name__, cls)) + func._overloads = overloads + return func + + +class BoundSDMethod(object): + """ A bound singledispatch method """ -class sd_method(object): - """ A singledispatch method """ def __init__(self, s_d, instance): self._instance = instance self._s_d = s_d + self._instance_sd = instance.__dict__.get('__' + s_d.__name__, None) + + def copy(self): + if self._instance_sd is not None: + return self._instance_sd.copy() + else: + return self._s_d.copy() + + def dispatch(self, cls): + if self._instance_sd is not None: + return self._instance_sd.dispatch(cls) + else: + return self._s_d.dispatch(cls) + + @property + def registry(self): + if self._instance_sd is not None: + return self._instance_sd.registry + else: + return self._s_d.registry + + def __call__(self, *args, **kwargs): + return self.dispatch(args[0].__class__)(self._instance, *args, **kwargs) + + def register(self, cls, func=None): + if self._instance_sd is None: + self._instance.__dict__['__' + self.__name__] = self._instance_sd = self._s_d.copy() + return self._instance_sd.add_overload(cls, func) + + def get_registered_types(self): + if self._instance_sd is not None: + return self._instance_sd.get_registered_types() + else: + return self._s_d.get_registered_types() + + +class UnboundSDMethod: + """ An unbound singledispatch method """ + + def __init__(self, s_d): + self._s_d = s_d + + def copy(self): + return self._s_d.copy() def dispatch(self, cls): return self._s_d.dispatch(cls) @@ -102,10 +166,16 @@ def registry(self): return self._s_d.registry def __call__(self, *args, **kwargs): - if self._instance is None: - return self.dispatch(args[1].__class__)(*args, **kwargs) - else: - return self.dispatch(args[0].__class__)(self._instance, *args, **kwargs) + return self.dispatch(args[1].__class__)(*args, **kwargs) + + def register(self, cls, func=None): + return self._s_d.register(cls, func) + + def get_registered_types(self): + return self._s_d.get_registered_types() + + def add_overload(self, cls, func=None): + self._s_d.add_overload(cls, func) def _fixup_class_attributes(cls): @@ -118,7 +188,7 @@ def _fixup_class_attributes(cls): if isinstance(value, singledispatch) and name not in patched: if name in attributes: raise RuntimeError('Cannot override generic function. ' - 'Try @register("{}", object) instead.'.format(name)) + 'Try @{name}.register(object) instead.'.format(name=name)) generic = value.copy() setattr(cls, name, generic) patched.add(name) @@ -126,15 +196,15 @@ def _fixup_class_attributes(cls): for name, value in attributes.items(): if not callable(value) or isinstance(value, singledispatch): continue - if hasattr(value, 'overloads'): - for generic_name, cls in value.overloads: + if hasattr(value, '_overloads'): + for generic_name, cls in value._overloads: generic = attributes[generic_name] - generic.register(cls, value) + generic.add_overload(cls, value) else: # register over-ridden methods for generic in generics: for cls, f in generic.registry.items(): if name == f.__name__: - generic.register(cls, value) + generic.add_overload(cls, value) break @@ -160,15 +230,17 @@ class SingleDispatch(metaclass=SingleDispatchMeta): class SingleDispatchABC(metaclass=SingleDispatchABCMeta): pass + def register(name, cls): """ Decorator for methods on a sub-class to register an overload on a base-class generic method. :param name: is the name of the generic method on the base class, or the unbound method itself :param cls: is the type to register """ + warnings.warn('Use @BaseClass.method.register() instead of register', DeprecationWarning, stacklevel=2) name = getattr(name, '__name__', name) # __name__ exists on sd_method courtesy of update_wrapper def wrapper(func): - overloads = getattr(func, 'overloads', []) + overloads = getattr(func, '_overloads', []) overloads.append((name, cls)) - func.overloads = overloads + func._overloads = overloads return func return wrapper diff --git a/methoddispatch/methoddispatch36.py b/methoddispatch/methoddispatch36.py index 06be769..b31a388 100644 --- a/methoddispatch/methoddispatch36.py +++ b/methoddispatch/methoddispatch36.py @@ -4,15 +4,17 @@ from functools import update_wrapper, _find_impl from types import MappingProxyType from weakref import WeakKeyDictionary +import warnings + +__all__ = ['singledispatch', 'register', 'SingleDispatch'] -__all__ = ['singledispatch', 'register', 'SingleDispatch', 'SingleDispatchABC'] ################################################################################ ### singledispatch() - single-dispatch generic function decorator ################################################################################ -class singledispatch(object): +class singledispatch: """Single-dispatch generic function decorator. Transforms a function into a generic function, which can have different @@ -52,40 +54,32 @@ def dispatch(self, cls): self.dispatch_cache[cls] = impl return impl - def register(self, cls, func=None): - """register(cls, func) -> func + def add_overload(self, cls, func=None): + """add_overload(cls, func) -> func (private) Registers a new implementation for the given *cls* on a *generic_func*. """ if func is None: if isinstance(cls, type): - return lambda f: self.register(cls, f) - ann = getattr(cls, '__annotations__', {}) - if not ann: - raise TypeError( - f"Invalid first argument to `register()`: {cls!r}. " - f"Use either `@register(some_class)` or plain `@register` " - f"on an annotated function." - ) + return lambda f: self.add_overload(cls, f) func = cls - - # only import typing if annotation parsing is necessary - from typing import get_type_hints - argname, cls = next(iter(get_type_hints(func).items())) - assert isinstance(cls, type), ( - f"Invalid annotation for {argname!r}. {cls!r} is not a class." - ) + cls = _get_class_from_annotation(cls) self._registry[cls] = func if self.cache_token is None and hasattr(cls, '__abstractmethods__'): self.cache_token = get_cache_token() self.dispatch_cache.clear() return func - def __get__(self, instance, cls=None): - if cls is not None and not issubclass(cls, SingleDispatch): + def __get__(self, instance, owner=None): + if owner is not None and not issubclass(owner, SingleDispatch): raise ValueError('singledispatch can only be used on methods of SingleDispatch subclasses') - wrapper = sd_method(self, instance) + if instance is not None: + wrapper = BoundSDMethod(self, instance) + elif owner is not None: + wrapper = UnboundSDMethod(self) + else: + return self update_wrapper(wrapper, self.func) return wrapper @@ -103,12 +97,91 @@ def copy(self): def get_registered_types(self): return [type_ for type_ in self._registry.keys() if type_ is not object] + def register(self, cls, func=None): + """ Decorator for methods to register an overload on generic method. + :param cls: is the type to register or may be omitted or None to use the annotated parameter type. + """ + if func is None: + if isinstance(cls, type): + return lambda f: self.register(cls, f) + func = cls + cls = _get_class_from_annotation(cls) + + overloads = getattr(func, '_overloads', []) + overloads.append((self.__name__, cls)) + func._overloads = overloads + return func + + +def _get_class_from_annotation(func): + # only import inspect if annotation parsing is necessary + import inspect + argspec = inspect.getfullargspec(func) + assert len(argspec.args) > 1, f'{func!r} must have at least 2 parameters.' + argname = argspec.args[1] + if argname not in argspec.annotations: + raise TypeError( + f"Invalid first argument to `register()`: {func!r}. " + f"Use either `@register(some_class)` or plain `@register` " + f"on an annotated function." + ) + cls = argspec.annotations.get(argname) + assert isinstance(cls, type), ( + f"Invalid annotation for {argname!r}. {cls!r} is not a class." + ) + return cls + + +class BoundSDMethod: + """ A bound singledispatch method """ -class sd_method(object): - """ A singledispatch method """ def __init__(self, s_d, instance): self._instance = instance self._s_d = s_d + self._instance_sd = instance.__dict__.get('__' + s_d.__name__, None) + + def copy(self): + if self._instance_sd is not None: + return self._instance_sd.copy() + else: + return self._s_d.copy() + + def dispatch(self, cls): + if self._instance_sd is not None: + return self._instance_sd.dispatch(cls) + else: + return self._s_d.dispatch(cls) + + @property + def registry(self): + if self._instance_sd is not None: + return self._instance_sd.registry + else: + return self._s_d.registry + + def __call__(self, *args, **kwargs): + return self.dispatch(args[0].__class__)(self._instance, *args, **kwargs) + + def register(self, cls, func=None): + if self._instance_sd is None: + self._instance.__dict__['__' + self.__name__] = self._instance_sd = self._s_d.copy() + return self._instance_sd.add_overload(cls, func) + + def get_registered_types(self): + if self._instance_sd is not None: + return self._instance_sd.get_registered_types() + else: + return self._s_d.get_registered_types() + + +class UnboundSDMethod: + """ An unbound singledispatch method """ + + def __init__(self, s_d): + self._s_d = s_d + + def copy(self): + return self._s_d.copy() def dispatch(self, cls): return self._s_d.dispatch(cls) @@ -118,10 +191,16 @@ def registry(self): return self._s_d.registry def __call__(self, *args, **kwargs): - if self._instance is None: - return self.dispatch(args[1].__class__)(*args, **kwargs) - else: - return self.dispatch(args[0].__class__)(self._instance, *args, **kwargs) + return self.dispatch(args[1].__class__)(*args, **kwargs) + + def register(self, cls, func=None): + return self._s_d.register(cls, func) + + def get_registered_types(self): + return self._s_d.get_registered_types() + + def add_overload(self, cls, func=None): + self._s_d.add_overload(cls, func) def _fixup_class_attributes(cls): @@ -129,12 +208,12 @@ def _fixup_class_attributes(cls): attributes = cls.__dict__ patched = set() for base in cls.mro()[1:]: - if issubclass(base, SingleDispatch): + if issubclass(base, SingleDispatch) and base is not SingleDispatch: for name, value in base.__dict__.items(): if isinstance(value, singledispatch) and name not in patched: if name in attributes: raise RuntimeError('Cannot override generic function. ' - 'Try @register("{}", object) instead.'.format(name)) + 'Try @{name}.register(object) instead.'.format(name=name)) generic = value.copy() setattr(cls, name, generic) patched.add(name) @@ -142,22 +221,22 @@ def _fixup_class_attributes(cls): for name, value in attributes.items(): if not callable(value) or isinstance(value, singledispatch): continue - if hasattr(value, 'overloads'): - for generic_name, cls in value.overloads: + if hasattr(value, '_overloads'): + for generic_name, cls in value._overloads: generic = attributes[generic_name] if cls is None: - generic.register(value) + generic.add_overload(value) else: - generic.register(cls, value) + generic.add_overload(cls, value) else: # register over-ridden methods for generic in generics: for cls, f in generic.registry.items(): if name == f.__name__: - generic.register(cls, value) + generic.add_overload(cls, value) break -class SingleDispatch(object): +class SingleDispatch: """ Base or mixin class to enable single dispatch on methods. """ @@ -166,18 +245,16 @@ def __init_subclass__(cls, **kwargs): _fixup_class_attributes(cls) -SingleDispatchABC = SingleDispatch # for backwards compatibility - - def register(name, cls=None): """ Decorator for methods on a sub-class to register an overload on a base-class generic method. :param name: is the name of the generic method on the base class, or the unbound method itself :param cls: is the type to register or may be omitted or None to use the annotated parameter type. """ + warnings.warn('Use @BaseClass.method.register() instead of register', DeprecationWarning, stacklevel=2) name = getattr(name, '__name__', name) # __name__ exists on sd_method courtesy of update_wrapper def wrapper(func): - overloads = getattr(func, 'overloads', []) + overloads = getattr(func, '_overloads', []) overloads.append((name, cls)) - func.overloads = overloads + func._overloads = overloads return func return wrapper diff --git a/test.py b/test.py index 9d18397..16f8085 100644 --- a/test.py +++ b/test.py @@ -2,40 +2,66 @@ import unittest import methoddispatch -from methoddispatch import singledispatch, register, SingleDispatch, SingleDispatchABC +from methoddispatch import singledispatch, SingleDispatch +try: + from methoddispatch import SingleDispatchABC +except ImportError: + SingleDispatchABC = SingleDispatch + import abc import doctest import six import sys +def instance_foo(self, bar): + return 'instance' + + class BaseClass(SingleDispatch): @singledispatch def foo(self, bar): return 'default' - @foo.register(int) + @foo.add_overload(int) def foo_int(self, bar): return 'int' + @foo.register(set) + def foo_set(self, bar): + return 'set' + + @singledispatch + def bar(self, bar): + return 'default' + + @bar.register(int) + def bar_int(self, bar): + return 'int' + class SubClass(BaseClass): - @register('foo', float) + @BaseClass.foo.register(float) def foo_float(self, bar): return 'float' def foo_int(self, bar): return 'sub int' - @register(BaseClass.foo, str) + @BaseClass.foo.register(str) def foo_str(self, bar): return 'str' + class SubSubClass(SubClass): - @register('foo', list) + @SubClass.foo.register(list) def foo_list(self, bar): return 'list' + @methoddispatch.register('foo', tuple) + def foo_tuple(self, bar): + return 'tuple' + @six.add_metaclass(abc.ABCMeta) class IFoo(object): @@ -44,7 +70,7 @@ def foo(self, bar): pass -class MyClass(IFoo, SingleDispatchABC): +class MyClass(SingleDispatchABC, IFoo): @singledispatch def foo(self, bar): return 'my default' @@ -53,15 +79,9 @@ def foo(self, bar): def foo_int(self, bar): return 'my int' - -@singledispatch -def func(a): - return 'default' - - -@func.register(bool) -def func_bool(a): - return not a + @foo.register(list) + def foo_list(self, bar): + return 'my list' class TestMethodDispatch(unittest.TestCase): @@ -69,7 +89,10 @@ def test_base_class(self): b = BaseClass() self.assertEqual(b.foo('text'), 'default') self.assertEqual(b.foo(1), 'int') + self.assertEqual(b.foo(set()), 'set') self.assertEqual(b.foo(1.0), 'default') + self.assertEqual(b.bar(1.0), 'default') + self.assertEqual(b.bar(1), 'int') def test_sub_class(self): s = SubClass() @@ -92,6 +115,14 @@ def test_independence(self): self.assertEqual(b.foo(1.0), 'default') self.assertEqual(s.foo(1.0), 'float') + def test_instance_register(self): + b = BaseClass() + b2 = BaseClass() + b.foo.register(float, instance_foo) + self.assertEqual(BaseClass.foo(b, 1.0), 'default') + self.assertEqual(b.foo(1.0), 'instance') + self.assertEqual(b2.foo(1.0), 'default') + def test_attempted_override(self): with self.assertRaises(RuntimeError): class SubClass2(BaseClass): @@ -100,13 +131,9 @@ def foo(self, bar): def test_abc_interface_support(self): m = MyClass() - self.assertEqual(m.foo('text'), 'my default') - self.assertEqual(m.foo(1), 'my int') - - def test_pure_funcs(self): - self.assertEqual('default', func(self)) - self.assertEqual(False, func(True)) - self.assertEqual(True, func(False)) + self.assertEqual('my default', m.foo('text')) + self.assertEqual('my int', m.foo(1)) + self.assertEqual('my list', m.foo([])) def test_class_access(self): s = SubClass() @@ -118,7 +145,7 @@ def test_class_extra_attributes(self): self.assertTrue(hasattr(SubClass.foo, 'dispatch')) self.assertTrue(hasattr(SubClass.foo, 'registry')) self.assertIs(SubClass.foo.dispatch(float), SubClass.__dict__['foo_float']) - self.assertEqual(set(SubClass.foo.registry.keys()), set([float, object, int, str])) + self.assertEqual(set(SubClass.foo.registry.keys()), set([float, object, set, int, str])) def test_instance_extra_attributes(self): """ Check that dispatch and registry attributes are accessible """ @@ -126,14 +153,14 @@ def test_instance_extra_attributes(self): self.assertTrue(hasattr(s.foo, 'dispatch')) self.assertTrue(hasattr(s.foo, 'registry')) self.assertIs(s.foo.dispatch(float), SubClass.__dict__['foo_float']) - self.assertEqual(set(s.foo.registry.keys()), set([float, object, int, str])) + self.assertEqual(set(s.foo.registry.keys()), set([float, object, set, int, str])) @unittest.skipIf(six.PY2, 'docs are in python3 syntax') def test_docs(self): num_failures, num_tests = doctest.testmod(methoddispatch, name='methoddispatch') # we expect 6 failures as a result like is not deterministic - self.assertLessEqual(num_failures, 7) - self.assertGreater(num_tests, 30) + self.assertLessEqual(num_failures, 6) + self.assertGreaterEqual(num_tests, 40) @unittest.skipIf(sys.version_info < (3, 6), 'python < 3.6') def test_annotations(self): @@ -143,12 +170,18 @@ def test_annotations(self): annotation_tests = """ def test_annotations(self): class AnnClass(BaseClass): - @register('foo') + @BaseClass.foo.register def foo_int(self, bar: int): - return 'ann int' + return 'an int' c = AnnClass() - self.assertEqual(c.foo(1), 'ann int') + self.assertEqual(c.foo(1), 'an int') + + def foo_float(obj: AnnClass, bar: float): + return 'float' + c.foo.register(foo_float) + self.assertEqual(c.foo(1.23), 'float') + test_annotations(self) """