In [1]:
#default_exp dispatch

In [2]:
#export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *

In [3]:
from nbdev.showdoc import *
from fastcore.test import *

# Type dispatch

> Basic single and dual parameter dispatch

## Helpers

In [11]:
# `typing.get_type_hints` returns a dict mapping each annotated parameter name to its annotation
def foo(a: int): return None
typing.get_type_hints(foo)

{'a': int}

In [12]:
# The tuple of object types that `type_hints` (defined below) checks against; note that this is a private `typing` attribute
typing._allowed_types

(function,
 builtin_function_or_method,
 method,
 module,
 wrapper_descriptor,
 method-wrapper,
 method_descriptor)

In [7]:
# Get the type annotations of a func/method when it is one of the object types in `typing._allowed_types`;
# otherwise return an empty dict
#exports
def type_hints(f):
    "Like `typing.get_type_hints`, except that an `f` which is not one of the allowed types gives `{}`"
    # NOTE(review): `typing._allowed_types` is a private name of the `typing` module — confirm it exists on every supported Python version
    if not isinstance(f, typing._allowed_types): return {}
    return typing.get_type_hints(f)

In [13]:
# `foo` is a plain function, so its annotations come back unchanged
type_hints(foo)

{'a': int}

In [15]:
#tendo test
# Plain classes, `typing` generics and several annotated parameters are all reported as written
def foo(a: int): return None
test_eq(type_hints(foo), {'a': int})
def foo(a: float): return None
test_eq(type_hints(foo), {'a': float})
def foo(a: typing.Tuple[float, int]): return None
test_eq(type_hints(foo), {'a': typing.Tuple[float, int]})
def foo(a: typing.Tuple[float, int], b: list): return None
test_eq(type_hints(foo), {'a': typing.Tuple[float, int], 'b': list})

In [10]:
# A return annotation is reported under the key 'return', next to the parameter annotations
def foo(a:int) -> float: return None
typing.get_type_hints(foo)

{'a': int, 'return': float}

In [16]:
# func to get the return key in the func's arg type dict as shown above
#export
def anno_ret(func):
    "Return annotation of `func`; `None` when `func` is falsy or has no return annotation"
    # `type_hints` always gives back a dict, so `.get` also covers the "no annotations at all" case
    return type_hints(func).get('return') if func else None

In [17]:
#hide
def f(x) -> float: return x
test_eq(anno_ret(f), float)
def f(x) -> typing.Tuple[float,float]: return x
test_eq(anno_ret(f), typing.Tuple[float,float])
def f(x) -> None: return x
test_eq(anno_ret(f), NoneType) # a `-> None` annotation is reported as `NoneType`
def f(x): return x
test_eq(anno_ret(f), None) # with no return annotation at all, the result is `None`
test_eq(anno_ret(None), None) # a falsy `func` short-circuits to `None`

In [52]:
# `issubclass` is true for a class and itself, and for `int` against the abstract `numbers.Integral`
issubclass(int, int), issubclass(int, numbers.Integral)

(True, True)

In [20]:
#export
# Sort key for types, built from a pairwise comparison: equal types compare as 0, a type sorts after
# any type it is a subclass of, and before everything else. An ascending sort therefore runs from the
# uppermost superclass down to the most specific subclass.
# `functools.cmp_to_key` background: https://www.geeksforgeeks.org/functools-module-in-python/
def _cmp_types(a, b):
    "Old-style comparison of two types: 0 if equal, 1 if `a` is a subclass of `b`, else -1"
    if a == b: return 0
    return 1 if issubclass(a, b) else -1

cmp_instance = functools.cmp_to_key(_cmp_types)

In [54]:
# Sorting with `cmp_instance` puts the most general type first and the most specific one last
td = [int, numbers.Number, numbers.Integral, object]
sorted(td, key=cmp_instance)

[object, numbers.Number, numbers.Integral, int]

In [23]:
td = {int:1, numbers.Number:2, numbers.Integral:3} # iterating a dict yields its keys, so `sorted` sees the same types a list would give
test_eq(sorted(td, key=cmp_instance), [numbers.Number, numbers.Integral, int])

In [24]:
# Extract the type annotations of the first two annotated args of a func or method, as a list containing two items
#export
def _p2_anno(f):
    "Get the 1st 2 annotations of `f`, defaulting to `object`"
    # Keep the parameter annotations in the order `type_hints` yields them; the return annotation is left out.
    # A parameter without an annotation does not appear in the hints, so it is skipped rather than counted.
    param_types = [typ for name, typ in type_hints(f).items() if name != 'return']
    # Pad with `object` (the root of every class hierarchy) so two entries can always be returned
    missing = 2 - len(param_types)
    if missing > 0: param_types.extend([object] * missing)
    return param_types[:2]

In [43]:
# Shown again for reference: anything that is not an instance of one of these types makes `type_hints` return `{}`
typing._allowed_types

(function,
 builtin_function_or_method,
 method,
 module,
 wrapper_descriptor,
 method-wrapper,
 method_descriptor)

In [49]:
class M:
    a = 9
f = attrgetter('a')  # a callable object whose type is not listed in `typing._allowed_types`, so `type_hints` gives `{}` for it
print(type_hints(f))
f(M)

{}


9

In [50]:
def _f(a): pass
test_eq(_p2_anno(_f), (object,object))  # with no annotations at all, both slots default to `object`
def _f(a, b): pass
test_eq(_p2_anno(_f), (object,object))
def _f(a:None, b)->str: pass
test_eq(_p2_anno(_f), (NoneType,object)) # an annotation of `None` is reported as `NoneType`
def _f(a:str, b)->float: pass
test_eq(_p2_anno(_f), (str,object)) # the return annotation is never part of the result
def _f(a:None, b:str)->float: pass
test_eq(_p2_anno(_f), (NoneType,str))
def _f(a:int, b:int)->float: pass
test_eq(_p2_anno(_f), (int,int))
# the unannotated `self` is skipped, so the result holds the annotations of `a` and `b`
def _f(self, a:int, b:int): pass
test_eq(_p2_anno(_f), (int,int))
def _f(a:int, b:str)->float: pass
test_eq(_p2_anno(_f), (int,str))
test_eq(_p2_anno(attrgetter('foo')), (object,object))  # `type_hints` gives `{}` for an `attrgetter`, so both slots default to `object`

## TypeDispatch -

The following class is the basis that allows us to do type dispatch with type annotations. It contains a dictionary type -> functions and ensures that the proper function is called when passed an object (depending on its type).

In [59]:
#export
class _TypeDict:
    "Mapping of type -> function that is kept sorted and answers lookups by `issubclass`"
    def __init__(self):
        # `d` maps each registered type to its function; `cache` memoises `all_matches` per queried key
        self.d = {}
        self.cache = {}

    def _reset(self):
        "Re-order `d` with `cmp_instance` (subclasses ahead of their superclasses) and empty the cache"
        ordered = sorted(self.d, key=cmp_instance, reverse=True)
        self.d = {typ: self.d[typ] for typ in ordered}
        self.cache = {}

    def add(self, t, f):
        "Add type `t` and function `f`"
        # `t` may stand for several types at once: anything that is not already a tuple is turned into one
        types = t if isinstance(t, tuple) else tuple(L(t))
        # every one of those types maps to the same `f`; an existing entry for a type is replaced
        for typ in types:
            self.d[typ] = f
        # the new entries have to be sorted into place, and cached answers may no longer be valid
        self._reset()

    def all_matches(self, k):
        "Functions of every registered type that equals `k` or is a super-class of `k`, in the order of `d`"
        if k in self.cache: return self.cache[k]
        # `issubclass` is only attempted when the query `k` is itself a class
        is_cls = isinstance(k, type)
        found = [fn for typ, fn in self.d.items() if k == typ or (is_cls and issubclass(k, typ))]
        self.cache[k] = found
        return found

    def __getitem__(self, k):
        "Function of the first registered type that equals `k` or is a super-class of it; `None` without a match"
        matches = self.all_matches(k)
        if not matches: return None
        return matches[0]

    def __repr__(self):
        return repr(self.d)

    def first(self):
        "First of the functions stored in `d`"
        return first(self.d.values())

In [60]:
# `int` counts as a subclass of the abstract `numbers.Integral`, which is what the `issubclass` lookup relies on
issubclass(int, numbers.Integral)

True

It needs to work a lot like a dict because we will be adding things to it.

Recall that encodes, decodes, setups of Transforms are objects of type TypeDispatch. So we want to dispatch different code depending on type. (See
Rotate example below).

Dispatch refers to how a programming language decides which piece of code to run when you call something. Different languages do this in different ways.

We want something for TypeDispatch that looks like a function so has a \_\_call\_\_
but when you call it with some argument, we are not just going to call a function BUT we are going to look at the type of the argument and we are going to find the appropriate function or method to call based on the type of the argument passed in and the methods created so far. 

So, inside our TypeDispatch is a dict called funcs where the key is the type; e.g., in the Rotate example below the keys are TensorImage and Bbox, and the values are the functions to call for each key. 

There is an add method in TypeDispatch, which is what _TfmDict calls when it does self[k].add(v). The add method finds the type annotations using _p2_anno; a parameter without an annotation is treated as object (the highest level of the type hierarchy). It then pops the function into our funcs dictionary.


Later, when we call \_\_call\_\_, it looks up the types of the arguments you are calling this function with (by mapping `type` over them, see below) and then looks them up in this object. If it does not find a match, it returns args[0], which means it does nothing and just returns x (so x is args[0]). So if you call Rotate with an int (instead of a TensorImage or BoundingBox), nothing happens: the input is returned unchanged, because the function is not defined for that type. If we do find a match, we call it with everything you passed along (*args, **kwargs). It can be turned into a method if needed.

How does it look up type? See \_\_getitem\_\_ code 

We keep a cache in _TypeDict, which is a dict mapping from types to functions. We need a special cache dict because of the way TypeDispatch works (explained below). The lookup goes through self.funcs.all_matches(k[0]): if k is not yet in the cache, all_matches computes the matches and adds them to the cache. 

TypeDispatch looks not just at the exact type (TensorImage, BoundingBox, etc.) but also at subclass relationships: e.g., if we also add Rotate for Tensor, TypeDispatch
will grab the most specific (lowest subclass) version it can, which for a TensorImage is the TensorImage one. Also, if you have defined encodes on Tensor and call it on TensorImage, which is a subclass of Tensor, then it WILL be invoked because of the way the cache is built (matches are found with issubclass). 



In [None]:
# `self.funcs` is a `_TypeDict` mapping each first-argument type to an inner `_TypeDict`: {type : `_TypeDict`}

In [145]:
#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, funcs=(), bases=()):
        # `funcs` is the outer `_TypeDict`: {first_arg_type: inner `_TypeDict` {second_arg_type: function}}.
        # `bases` are fallback dispatchers that `__getitem__` consults when nothing matches here
        # (`None` entries are presumably dropped by the `is_not(None)` filter).
        self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
        for o in L(funcs): self.add(o)
        # Set by `__get__` when the dispatcher is reached as a class attribute; `None` until then
        self.inst = None

    def add(self, f):
        "Register `f` under the annotations of its first two annotated parameters"
        a0,a1 = _p2_anno(f)
        # A tuple annotation on the first parameter stands for several alternative types. Each of them gets
        # (or re-uses) its OWN inner `_TypeDict`: looking up the tuple as a whole would never find an existing
        # entry, so functions already registered for one of the types would be thrown away, and all the
        # types would end up sharing a single inner dict.
        for typ in (a0 if isinstance(a0,tuple) else tuple(L(a0))):
            t = self.funcs.d.get(typ)
            if t is None:
                # first function for this first-arg type: create its inner dict and sort it into the outer one
                t = _TypeDict()
                self.funcs.add(typ, t)
            # inner level: {second_arg_type: f}; a tuple `a1` is spread over its types by `_TypeDict.add`
            t.add(a1, f)

    def first(self):
        "First function of the first inner `_TypeDict`"
        return self.funcs.first().first()

    def returns(self, x):
        "Return annotation of the function that `type(x)` dispatches to"
        return anno_ret(self[type(x)])

    def returns_none(self, x):
        "`NoneType` if the function that `type(x)` dispatches to is annotated `-> None`, otherwise `None`"
        r = anno_ret(self[type(x)])
        return r if r == NoneType else None

    def _attname(self,k): return getattr(k,'__name__',str(k))
    def __repr__(self):
        # One line per (first type, second type) pair, followed by the reprs of the bases. The inner dicts are
        # read straight from `funcs.d`: `self.funcs[k]` would give the first *matching* entry, and the sort key
        # is not a total order, so that entry is not guaranteed to be the one stored for `k` itself.
        r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", v.__class__.__name__)}'
             for k,inner in self.funcs.d.items() for l,v in inner.d.items()]
        r = r + [o.__repr__() for o in self.bases]
        return '\n'.join(r)

    def __call__(self, *args, **kwargs):
        "Call the function matching the types of the first two positional args; without a match return `args[0]`"
        ts = L(args).map(type)[:2]
        f = self[tuple(ts)]
        if not f: return args[0]
        # Reached through an instance: bind the function so that it receives that instance as first argument
        if self.inst is not None: f = MethodType(f, self.inst)
        return f(*args, **kwargs)

    def __get__(self, inst, owner):
        "Descriptor hook: remember the instance the dispatcher was reached through and return the dispatcher"
        self.inst = inst
        return self

    def __getitem__(self, k):
        "First registered function whose types match `k` (one type or a pair of types), else `None`"
        k = L(k)
        # Always work with two types; a missing second type is looked up as `object`
        while len(k)<2: k.append(object)
        # Inner dicts of every registered first-arg type that equals `k[0]` or is a super-class of it
        r = self.funcs.all_matches(k[0])
        for t in r:
            # within one inner dict the second type is matched the same way
            o = t[k[1]]
            if o is not None: return o
        # Nothing registered here matched: ask the bases, in order
        for base in self.bases:
            res = base[k]
            if res is not None: return res
        return None

In [157]:
# Both relations hold through abstract base classes, which is why the `issubclass` lookup can use them
issubclass(np.int32, numbers.Integral), issubclass(str, typing.Collection)

(True, True)

In [158]:
# `TypeDispatch.__call__` dispatches on `type(arg)`; for a numpy scalar that is the numpy type, not `int`
type(np.int32(1))

numpy.int32

In [146]:
def f_col(x:typing.Collection): return x
def f_nin(x:numbers.Integral)->int:  return x+1
def f_ni2(x:int): return x
def f_bll(x:(bool,list)): return x
def f_num(x:numbers.Number): return x
t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None]) # each func is registered under its annotations; the trailing `None` ends up under (object,object), see the repr below

t.add(f_ni2) # adding the same function twice is harmless: the existing entry for the same types is simply overwritten
test_eq(t[int], f_ni2) 
test_eq(t[np.int32], f_nin) # np.int32 is a subclass of numbers.Integral, so it is dispatched to `f_nin`
test_eq(t[str], None) # so far nothing is registered for str or a superclass of it (apart from the `None` placeholder)
# test_eq(t[float], f_num)  # disabled in the original; float is a subclass of numbers.Number, so this presumably resolves to f_num — TODO confirm and re-enable
test_eq(t[bool], f_bll) # a tuple annotation is spread over its member types, so each of them dispatches on its own
test_eq(t[list], f_bll)
t.add(f_col) # new funcs can be added on the fly, just as with a normal dict
test_eq(t[str], f_col) # str is a subclass of typing.Collection, so it is now dispatched to `f_col`
test_eq(t[np.int32], f_nin)
# Let's test the __call__
o = np.int32(1)
test_eq(t(o), 2)
test_eq(t.returns(o), int)
assert t.first() is not None
t

(list,object) -> f_bll
(Collection,object) -> f_col
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(object,object) -> NoneType

If `bases` is set to a collection of TypeDispatch objects, then they are searched for matching functions if no match is found in this object.

In [152]:
# Indexing with a registered type gives back the function stored for it
t[numbers.Number]

<function __main__.f_num>

In [147]:
def f_str(x:str): return x+'1'

# `t` (built in an earlier cell) is passed as a base: whatever `t2` cannot answer itself is looked up there
t2 = TypeDispatch(f_str, bases=t)
test_eq(t2[int], f_ni2)
test_eq(t2[np.int32], f_nin)
test_eq(t2[float], f_num)
test_eq(t2[bool], f_bll)
test_eq(t2[str], f_str) # `t2`'s own registration is found before the bases are consulted
test_eq(t2('a'), 'a1')
test_eq(t2[np.int32], f_nin)
test_eq(t2(o), 2) # `o` is the np.int32 value defined in an earlier cell
test_eq(t2.returns(o), int)

In [None]:
# `self` carries no annotation, so these are registered (and dispatched) on the type of `x`
def m_nin(self, x:(str,numbers.Integral)): return str(x)+'1'
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x

t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t
a = A()
test_eq(a.f(1), '11')
test_eq(a.f(1.), 1.)
test_is(a.f.inst, a) # reaching `f` through `a` goes via `TypeDispatch.__get__`, which records `a` as `inst`
a.f(False) # bool dispatches to `m_bll`, which sets `foo` on the instance it is bound to
test_eq(a.foo, 'a')
test_eq(a.f(()), ()) # nothing is registered for tuple, so the argument is returned unchanged

In [None]:
def m_tup(self, x:tuple): return x+(1,)
t2 = TypeDispatch(m_tup, t) # the second positional argument is `bases`: `t` from the previous cell is the fallback
class A2: f = t2
a2 = A2()
test_eq(a2.f(1), '11')
test_eq(a2.f(1.), 1.)
test_is(a2.f.inst, a2)
a2.f(False) # found in the base `t`, but bound to `a2` because `t2.inst` is used for binding
test_eq(a2.foo, 'a')
test_eq(a2.f(()), (1,)) # tuple is now handled by `m_tup`

In [None]:
# Dispatch on the first two arguments: an unannotated `y` is registered as `object`
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

test_eq(t[int], f1) # a missing second type is looked up as `object`; `f2` needs a float there, so the lookup falls through to `f1`
test_eq(t[int,int], f1)
test_eq(t[int,float], f2)
test_eq(t[float,float], None) # float matches neither int nor numbers.Integral
test_eq(t[np.int32,float], f1) # np.int32 is not an `int`, so only the numbers.Integral entry matches
test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)
test_eq(t('a'), 'a') # no match: the first argument is returned unchanged

## typedispatch Decorator

In [None]:
#export
class DispatchReg:
    "A global registry for `TypeDispatch` objects keyed by function name"
    def __init__(self):
        # a fresh `TypeDispatch` is created the first time a name is looked up
        self.d = defaultdict(TypeDispatch)

    def __call__(self, f):
        "Add `f` to the `TypeDispatch` stored under its `__qualname__` and return that dispatcher"
        dispatcher = self.d[f'{f.__qualname__}']
        dispatcher.add(f)
        return dispatcher

# Module-level registry instance, used as the `@typedispatch` decorator
typedispatch = DispatchReg()

In [None]:
# All three definitions share one `__qualname__`, so they are collected in a single `TypeDispatch`;
# the decorator returns that dispatcher, which is what the name `f_td_test` ends up bound to
@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y

test_eq(f_td_test(3,2.0), 5)
test_eq(f_td_test(3,2), 4)
test_eq(f_td_test('a','b'), 'ab')