In [None]:
#default_exp core.dispatch

In [None]:
#export
from local.core.imports import *
from local.core.foundation import *
from local.core.utils import *
from local.test import *

In [None]:
from local.notebook.showdoc import *

# Type dispatch

> Basic single and dual parameter dispatch

## Helpers

In [None]:
#exports
def type_hints(f):
    "Type hints of `f` via `typing.get_type_hints`; `{}` when `f` is not one of `typing._allowed_types`"
    # Objects outside `typing._allowed_types` are never passed to `get_type_hints`;
    # they are simply reported as carrying no hints.
    if not isinstance(f, typing._allowed_types): return {}
    return typing.get_type_hints(f)

In [None]:
#export
def anno_ret(func):
    "Return annotation of `func`; `None` if `func` is falsy, has no hints, or has no `return` hint"
    # A falsy `func` is never passed to `type_hints`
    hints = type_hints(func) if func else None
    return hints.get('return') if hints else None

In [None]:
#hide
# A plain return annotation is reported as-is
def f(x) -> float: return x
test_eq(anno_ret(f), float)
# Generic annotations are reported as-is too
def f(x) -> typing.Tuple[float,float]: return x
test_eq(anno_ret(f), typing.Tuple[float,float])
# `-> None` is reported as `NoneType`, which distinguishes it from "no annotation" below
def f(x) -> None: return x
test_eq(anno_ret(f), NoneType)
# No return annotation gives `None`
def f(x): return x
test_eq(anno_ret(f), None)
# `None` in place of a function also gives `None`
test_eq(anno_ret(None), None)

In [None]:
#export
def _cmp_types(a, b):
    "Three-way comparison of types: 0 when equal, 1 when `a` is a subclass of `b`, otherwise -1"
    if a == b: return 0
    return 1 if issubclass(a, b) else -1

# Sort key built from the comparison above
cmp_instance = functools.cmp_to_key(_cmp_types)

In [None]:
# Sorting the keys with `cmp_instance` puts each superclass before its subclasses
td = {int:1, numbers.Number:2, numbers.Integral:3}
test_eq(sorted(td, key=cmp_instance), [numbers.Number, numbers.Integral, int])

In [None]:
#export
def _p2_anno(f):
    "First two non-`return` type hints of `f` as a list, padded with `object` up to length 2"
    found = []
    for name,tp in type_hints(f).items():
        # The return annotation is not a parameter annotation
        if name == 'return': continue
        found.append(tp)
    # Pad so that callers can always unpack exactly two values
    found += [object] * max(0, 2-len(found))
    return found[:2]

In [None]:
# Without annotations both slots default to `object`
def _f(a): pass
test_eq(_p2_anno(_f), (object,object))
def _f(a, b): pass
test_eq(_p2_anno(_f), (object,object))
# A `None` annotation is reported as `NoneType`; the return annotation is ignored
def _f(a:None, b)->str: pass
test_eq(_p2_anno(_f), (NoneType,object))
def _f(a:str, b)->float: pass
test_eq(_p2_anno(_f), (str,object))
def _f(a:None, b:str)->float: pass
test_eq(_p2_anno(_f), (NoneType,str))
def _f(a:int, b:int)->float: pass
test_eq(_p2_anno(_f), (int,int))
# Only annotated parameters are collected, so an unannotated `self` takes no slot
def _f(self, a:int, b:int): pass
test_eq(_p2_anno(_f), (int,int))
def _f(a:int, b:str)->float: pass
test_eq(_p2_anno(_f), (int,str))
# A callable that yields no hints gets the defaults
test_eq(_p2_anno(attrgetter('foo')), (object,object))

## TypeDispatch -

The following class is the basis that allows us to do type dispatch with type annotations. It contains a dictionary mapping types to functions and ensures that the proper function is called when passed an object (depending on its type).

In [None]:
#export
class _TypeDict:
    "Type-keyed table, re-sorted with `cmp_instance` on every `add`, with cached subclass-aware lookups"
    def __init__(self):
        self.d = {}      # type -> registered value, in lookup order
        self.cache = {}  # looked-up key -> list of matching values

    def _reset(self):
        "Rebuild `d` in reversed `cmp_instance` order and drop all cached lookups"
        # NOTE(review): `cmp_instance` returns -1 for unrelated types in both directions,
        # so the resulting order may depend on insertion order -- confirm this is acceptable.
        ordered = sorted(self.d, key=cmp_instance, reverse=True)
        self.d = {tp: self.d[tp] for tp in ordered}
        self.cache = {}

    def add(self, t, f):
        "Store `f` under type `t`, or under every type in `t` when it is a tuple"
        for tp in (t if isinstance(t, tuple) else (t,)): self.d[tp] = f
        self._reset()

    def all_matches(self, k):
        "All values whose type equals `k` or is a super-class of `k`, in the order of `d`"
        if k in self.cache: return self.cache[k]
        # `issubclass` is only attempted when `k` is itself a class
        is_cls = isinstance(k, type)
        found = [v for tp,v in self.d.items() if k==tp or (is_cls and issubclass(k,tp))]
        self.cache[k] = found
        return found

    def __getitem__(self, k):
        "First value from `all_matches(k)`, or `None` when nothing matches"
        matches = self.all_matches(k)
        return matches[0] if matches else None

    def __repr__(self): return repr(self.d)
    def first(self):
        "Apply the module-level `first` helper to the stored values"
        return first(self.d.values())

In [None]:
#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, *funcs):
        # Two-level table: first-annotation type -> `_TypeDict` of second-annotation type -> function
        self.funcs = _TypeDict()
        for o in funcs: self.add(o)
        # Instance the dispatcher was last accessed through (see `__get__`)
        self.inst = None

    def add(self, f):
        "Register `f` under the types of its first two annotations (as given by `_p2_anno`)"
        a0,a1 = _p2_anno(f)
        # The first annotation may be a tuple of types. Each member needs its own
        # second-level table: looking the tuple itself up always misses, and installing
        # one fresh shared table would overwrite functions already registered for a
        # member type and leak later registrations between the members.
        added = False
        for t0 in (a0 if isinstance(a0, tuple) else (a0,)):
            t = self.funcs.d.get(t0)
            if t is None:
                t = self.funcs.d[t0] = _TypeDict()
                added = True
            t.add(a1, f)
        # New first-level keys were inserted: re-sort once and drop cached lookups
        if added: self.funcs._reset()

    def first(self):
        "First function of the first second-level table"
        return self.funcs.first().first()
    def returns(self, x):
        "Return annotation of the function matching `type(x)`"
        return anno_ret(self[type(x)])
    def returns_none(self, x):
        "`NoneType` if the function matching `type(x)` is annotated `-> None`, otherwise `None`"
        r = anno_ret(self[type(x)])
        return r if r == NoneType else None

    def _attname(self,k): return getattr(k,'__name__',str(k))
    def __repr__(self):
        # `_attname` is used for the function as well, so callables without a `__name__`
        # do not make `repr` raise
        r = [f'({self._attname(k)},{self._attname(l)}) -> {self._attname(v)}'
             for k in self.funcs.d for l,v in self.funcs[k].d.items()]
        return '\n'.join(r)

    def __call__(self, *args, **kwargs):
        "Call the function matching the types of the first two `args`; with no match return `args[0]`"
        ts = L(args).map(type)[:2]
        f = self[tuple(ts)]
        if not f: return args[0]
        # When last accessed through an instance, call `f` as a method bound to it
        if self.inst is not None: f = types.MethodType(f, self.inst)
        return f(*args, **kwargs)

    def __get__(self, inst, owner):
        # Descriptor hook: the accessing instance is stored on this (shared) dispatcher
        # and the dispatcher itself is returned
        self.inst = inst
        return self

    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        k = L(k if isinstance(k, tuple) else (k,))
        # A single type is looked up as `(type, object)`
        while len(k)<2: k.append(object)
        r = self.funcs.all_matches(k[0])
        if len(r)==0: return None
        # Try each first-level match in order until one has a function for the second type
        for t in r:
            o = t[k[1]]
            if o is not None: return o
        return None

In [None]:
# Handlers keyed on the annotation of their first parameter
def f_col(x:typing.Collection): return x
def f_nin(x:numbers.Integral)->int:  return x+1
def f_ni2(x:int): return x
# A tuple annotation registers the function for each listed type
def f_bll(x:(bool,list)): return x
def f_num(x:numbers.Number): return x
t = TypeDispatch(f_nin,f_ni2,f_num,f_bll)

t.add(f_ni2) #Should work even if we add the same function twice.
test_eq(t[int], f_ni2)
# np.int32 resolves to the `numbers.Integral` handler, not the `int` one
test_eq(t[np.int32], f_nin)
# Nothing registered matches `str` yet
test_eq(t[str], None)
test_eq(t[float], f_num)
test_eq(t[bool], f_bll)
test_eq(t[list], f_bll)
# After registering the `typing.Collection` handler, `str` finds a match
t.add(f_col)
test_eq(t[str], f_col)
test_eq(t[np.int32], f_nin)
# Calling the dispatcher runs the matched function; `returns` reports its return annotation
o = np.int32(1)
test_eq(t(o), 2)
test_eq(t.returns(o), int)
assert t.first() is not None
t

(typing.Collection,object) -> f_col
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(list,object) -> f_col

In [None]:
# Method-style handlers: `self` is unannotated, so dispatch is keyed on the annotation of `x`
def m_nin(self, x:(str,numbers.Integral)): return str(x)+'1'
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x

t = TypeDispatch(m_nin,m_num,m_bll)
# Used as a class attribute, the dispatcher records the instance it is accessed through
class A: f = t
a = A()
test_eq(a.f(1), '11')
test_eq(a.f(1.), 1.)
test_is(a.f.inst, a)
# The `bool` handler receives `a` as `self` and sets an attribute on it
a.f(False)
test_eq(a.foo, 'a')

In [None]:
# Dispatch on the first two parameters; an unannotated parameter counts as `object`
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch(f1,f2)

# A single type is looked up with `object` as the second type
test_eq(t[int], f1)
test_eq(t[int,int], f1)
test_eq(t[int,float], f2)
test_eq(t[float,float], None)
test_eq(t[np.int32,float], f1)
test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)
# With no matching function the first argument is returned unchanged
test_eq(t('a'), 'a')
t

(int,float) -> f2
(Integral,object) -> f1

## typedispatch Decorator

In [None]:
#export
class DispatchReg:
    "Registry holding one `TypeDispatch` per function `__qualname__`; used as a decorator"
    def __init__(self):
        # An unseen name gets a fresh, empty `TypeDispatch` on first access
        self.d = defaultdict(TypeDispatch)

    def __call__(self, f):
        "Add `f` to the dispatcher stored under its qualified name and return that dispatcher"
        nm = f'{f.__qualname__}'
        dispatcher = self.d[nm]
        dispatcher.add(f)
        return dispatcher

# Module-level `DispatchReg` instance, applied to functions as the `@typedispatch` decorator
typedispatch = DispatchReg()

In [None]:
# Each decorated definition with the same name is added to the same dispatcher,
# so `f_td_test` ends up bound to that dispatcher rather than to the last function
@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y

# (int, float) picks the most specific version, (int, int) the `Integral` one,
# and (str, str) the unannotated fallback
test_eq(f_td_test(3,2.0), 5)
test_eq(f_td_test(3,2), 4)
test_eq(f_td_test('a','b'), 'ab')

## Export -

In [None]:
#hide
# Run the project's notebook export; `all_fs=True` presumably processes every notebook
# (the output below lists each converted file) -- confirm against `notebook2script`
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_utils.ipynb.
Converted 01b_dispatch.ipynb.
Converted 01c_torch_core.ipynb.
Converted 02_script.ipynb.
Converted 03_dataloader.ipynb.
Converted 04_transform.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 10_data_block.ipynb.
Converted 11_layers.ipynb.
Converted 11a_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 14a_callback_data.ipynb.
Converted 15_callback_hook.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_tutorial_imagenette.ipynb.
Converted 22_vision_learner.ipynb.
Converted 23_tutorial_transfer_learning.ipynb.
Converted 30_text_core.ipynb.
Converted 31_tex