In [None]:
from fastai2.data.all import *



In [None]:
def _is_tuple(o): 
    r =  isinstance(o, tuple)  and not hasattr(o, '_fields')
    print(r)
    return r

from fastcore.transform import _TfmMeta

class Transform(metaclass=_TfmMeta):
    "Delegates (`__call__`,`decode`,`setup`) to (`encodes`,`decodes`,`setups`) if `split_idx` matches"
    # Class-level defaults; `__init__` may override `split_idx` and `order` per instance.
    split_idx,init_enc,order,train_setup = None,None,0,None
    def __init__(self, enc=None, dec=None, split_idx=None, order=None):
        # Keep the class-level `split_idx` unless one is passed explicitly.
        self.split_idx = ifnone(split_idx, self.split_idx)
        if order is not None: self.order=order
        self.init_enc = enc or dec
        # Neither `enc` nor `dec` given: leave `encodes`/`decodes`/`setups` as they already are.
        # NOTE(review): presumably `_TfmMeta` supplies them for subclasses -- confirm in fastcore.
        if not self.init_enc: return

        # Built from plain functions: start from fresh, empty dispatchers.
        self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()
        if enc:
            self.encodes.add(enc)
            # An `order` attribute on `enc` takes precedence over the value set above.
            self.order = getattr(enc,'order',self.order)
            # Record the first type hint of `enc` (if it has any) as `input_types`.
            if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())
            self._name = _get_name(enc)
        if dec: self.decodes.add(dec)

    # Uses `_name` when `__init__` stored one from `enc`, otherwise `_get_name(self)`.
    @property
    def name(self): return getattr(self, '_name', _get_name(self))
    # `__call__` and `decode` both go through `_call`, selecting the dispatcher by attribute name.
    def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
    def decode  (self, x, **kwargs): return self._call('decodes', x, **kwargs)
    def __repr__(self): return f'{self.name}: {self.encodes} {self.decodes}'

    def setup(self, items=None, train_setup=False):
        # A non-None `self.train_setup` overrides the `train_setup` argument.
        train_setup = train_setup if self.train_setup is None else self.train_setup
        # When `train_setup` is truthy, pass `items.train` if that attribute exists, else `items` itself.
        return self.setups(getattr(items, 'train', items) if train_setup else items)

    def _call(self, fn, x, split_idx=None, **kwargs):
        # Return `x` untouched when this transform is bound to a different split;
        # a transform whose `split_idx` is None is applied regardless of the requested split.
        if split_idx!=self.split_idx and self.split_idx is not None: return x
        return self._do_call(getattr(self, fn), x, **kwargs)

    def _do_call(self, f, x, **kwargs):
        # Non-tuple input: apply `f` once (or pass `x` through when `f` is None) and hand the
        # result to `retain_type` together with the original `x`.
        # NOTE(review): the meaning of `f.returns_none(x)` is defined by fastcore's dispatcher -- confirm.
        if not _is_tuple(x):
            return x if f is None else retain_type(f(x, **kwargs), x, f.returns_none(x))
        # Tuple input: recurse element-wise, then pass the rebuilt tuple through `retain_type`.
        res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
        return retain_type(res, x)
    
class ItemTransform(Transform):
    "A transform that receives a whole tuple as one item instead of being applied per element"
    def __call__(self, x, **kwargs):
        # A tuple is handed to `Transform.__call__` as a list so it is not split element-wise;
        # the result is then passed through `retain_type` against the original `x`.
        if _is_tuple(x):
            encoded = super().__call__(list(x), **kwargs)
            return retain_type(encoded, x)
        return super().__call__(x, **kwargs)

    def decode(self, x, **kwargs):
        # Mirror of `__call__` for the decoding direction.
        if _is_tuple(x):
            decoded = super().decode(list(x), **kwargs)
            return retain_type(decoded, x)
        return super().decode(x, **kwargs)


In [None]:
class OtherTuple(Tuple):
    "Empty `Tuple` subclass used as a distinct type for the dispatch experiments in this notebook"
class Tfm(ItemTransform):
    "Probe transform: `encodes` is annotated for `OtherTuple`, `decodes` for `list`"
    def encodes(self, o:OtherTuple):
        # The print makes it visible in the cell output whether this method was reached.
        print('tmf.encodes triggered')
        return [*o, 'm']

    def decodes(self, o:list):
        return Tuple(o)
    

In [None]:
# Run the OtherTuple-annotated transform through `ItemTransform.__call__`.
# Recorded output: `_is_tuple` prints True then False, and the result is [0, 1] with no
# 'tmf.encodes triggered' line -- i.e. the typed `encodes` did not run on this path.
t = Tfm()
t(OtherTuple([0,1]))

True
False


[0, 1]

In [None]:
# Call the `encodes` dispatcher directly, bypassing `ItemTransform.__call__`.
# Recorded output shows 'tmf.encodes triggered' and [0, 1, 'm'].
t.encodes(OtherTuple([0,1]))
# Alternative probe with a plain tuple (left disabled):
# t.encodes(tuple([0,1]))

tmf.encodes triggered


[0, 1, 'm']

In [None]:
class Tfm2(ItemTransform):
    "Probe transform whose `encodes` carries no type annotation; `decodes` is annotated for `list`"
    def encodes(self, o):
        # Trace line so the cell output shows that this method ran.
        print('tmf.encodes triggered')
        items = list(o)
        return items

    def decodes(self, o:list):
        rebuilt = Tuple(o)
        return rebuilt
    

In [None]:
# Same call path with the un-annotated `encodes` and a plain tuple;
# recorded output shows 'tmf.encodes triggered' and [0, 1].
t = Tfm2()
t(tuple([0,1]))

tmf.encodes triggered


[0, 1]

In [None]:
# Repeat with an OtherTuple instance; recorded output again shows 'tmf.encodes triggered' and [0, 1].
t(OtherTuple(0,1))

tmf.encodes triggered


[0, 1]

In [None]:
def _is_tuple(o): 
    return (isinstance(o, tuple) or ) and not hasattr(o, '_fields')


In [None]:
class OItemTransform(Transform):
    "A transform that always takes tuples, or children of tuples, as whole items."
    def __call__(self, x, **kwargs):
        parent_call = super().__call__
        # Non-tuples go straight to `Transform.__call__`.
        if not _is_tuple(x):
            return parent_call(x, **kwargs)
        # Tuples are passed on as a list, then `retain_type` is applied against the original `x`.
        as_list = list(x)
        return retain_type(parent_call(as_list, **kwargs), x)

    def decode(self, x, **kwargs):
        parent_decode = super().decode
        if not _is_tuple(x):
            return parent_decode(x, **kwargs)
        as_list = list(x)
        return retain_type(parent_decode(as_list, **kwargs), x)


In [None]:
class OtherTuple(L):
    "Empty `L` subclass reusing the `OtherTuple` name, so the probe runs against a non-tuple type"
class Tfm(ItemTransform):
    "Probe transform: `encodes` is annotated for `OtherTuple`, `decodes` for `list`"
    def encodes(self, o:OtherTuple):
        # Copy the items of `o` into a plain list.
        return [*o]

    def decodes(self, o:list):
        rebuilt = Tuple(o)
        return rebuilt
    

In [None]:
# Re-run the probe with OtherTuple now deriving from `L`; recorded output is [0, 1].
t = Tfm()
t(OtherTuple(0,1))

[0, 1]

In [None]:
# Direct dispatcher call with the `L`-based OtherTuple; recorded output is [0, 1].
t.encodes(OtherTuple(0,1))

[0, 1]