In [None]:
# export
from fastai.datasets import URLs, untar_data
from pathlib import Path
import torch, re, PIL, os, mimetypes, csv, operator, pickle
import matplotlib.pyplot as plt
from collections import OrderedDict
from typing import *
import pandas as pd, numpy as np
from enum import Enum
from torch import tensor,Tensor
from numpy import array

## Data block API from config class

### Core helper functions

In [None]:
# export
def noop(x, *args, **kwargs):
    "Identity function: return `x` untouched, ignoring any extra arguments."
    return x

def range_of(x):
    "Return the list of indices `[0, ..., len(x)-1]` of `x`."
    return [*range(len(x))]

# Expose the number of dimensions as an `ndim` property (delegates to `Tensor.dim()`).
torch.Tensor.ndim = property(lambda t: t.dim())

def test(a,b,cmp,cname=None,tst_name=''):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{tst_name},{cname}:\n{a}\n{b}"

def test_eq(a,b,tst_name=''):
    "Check `a` equals `b`; arrays and non-scalar tensors are compared elementwise after a length check."
    is_array_like = isinstance(a, np.ndarray) or (isinstance(a, Tensor) and a.ndim)
    if not is_array_like:
        test(a,b,operator.eq,'==',tst_name)
        return
    assert len(a) == len(b), f"{tst_name}, lengths mismatch:\n{a}\n{b}"
    test(a,b,lambda x,y: (x == y).all(),'==',tst_name)

def test_ne(a,b):
    "Check that `a` differs from `b` according to `!=`."
    test(a,b,operator.ne,'!=')

def compose(*funcs):
    """Chain `funcs` into a single one-argument function applied left to right.

    `compose(f, g)(x)` evaluates `g(f(x))`. With no functions the result is `noop`.
    """
    # `reduce` is not a builtin in Python 3 and the visible imports never bring it into scope.
    from functools import reduce
    return reduce(lambda f,g: lambda x: f(g(x)), reversed(funcs), noop)

In [None]:
# export
def noop(x, *args, **kwargs):
    "Return `x` as is; any additional positional or keyword arguments are ignored."
    return x

def range_of(x):
    "List every index of `x`, from 0 up to `len(x)-1`."
    n = len(x)
    return list(range(n))

# Make `ndim` available on tensors by delegating to `Tensor.dim()`.
torch.Tensor.ndim = property(lambda t: t.dim())

def test(a,b,cmp,cname=None,tst_name='', tf=True):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b)==tf,f"{tst_name},{cname}:\n{a}\n{b}"

def test_eq(a,b,tst_name='', tf=True):
    "Check that comparing `a` with `b` gives `tf`, picking the comparison from the type of `a`."
    # numpy arrays and non-scalar tensors need whole-object comparisons; everything else uses `==`.
    if isinstance(a, np.ndarray):            op = np.array_equal
    elif isinstance(a, Tensor) and a.ndim:   op = torch.equal
    else:                                    op = operator.eq
    return test(a,b,op,'==',tst_name=tst_name, tf=tf)

def test_ne(a,b,tst_name=''):
    "Check that `a` and `b` are not equal, using the same comparison rules as `test_eq`."
    test_eq(a,b,tst_name=tst_name,tf=False)

def apply_all(*funcs):
    """Chain `funcs` into a single one-argument function applied left to right.

    `apply_all(f, g)(x)` evaluates `g(f(x))`. With no functions the result is `noop`.
    NOTE(review): a later cell defines `apply_all(x, funcs, ...)`, which replaces this
    definition once that cell runs.
    """
    # `reduce` is not a builtin in Python 3 and the visible imports never bring it into scope.
    from functools import reduce
    return reduce(lambda f,g: lambda x: f(g(x)), reversed(funcs), noop)

In [None]:
# test
# Scalars are compared with plain equality.
test_eq(noop(1),1)
test_ne(2,1)
# Arrays and tensors: same content is equal, different lengths are not.
test_eq(array([1,2]), array([1,2]))
test_ne(array([1,2]), array([1]))
test_eq(tensor([1,2]), tensor([1,2]))
test_ne(tensor([1,2]), tensor([1]))

In [None]:
# export
def listify(o):
    """Make `o` a list.

    - `None` becomes `[]`
    - a list is returned unchanged (same object)
    - a string or a non-iterable is wrapped in a one-element list
    - an iterable without a usable length (e.g. a rank 0 tensor) is wrapped as well
    - any other iterable is converted with `list(o)`
    """
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if not isinstance(o, Iterable): return [o]
    #Rank 0 tensors in PyTorch are Iterable but don't have a length.
    # Catch `Exception` rather than everything so Ctrl-C / SystemExit are not swallowed.
    try: len(o)
    except Exception: return [o]
    return list(o)

In [None]:
# test
# None, lists, strings and ranges
test_eq(listify(None),[])
test_eq(listify([1,2,3]),[1,2,3])
test_ne(listify([1,2,3]),[1,2,])
test_eq(listify('abc'),['abc'])
test_eq(listify(range(0,3)),[0,1,2])
# A rank 0 tensor is wrapped, a rank 1 tensor is unrolled into its elements
test_eq(listify(tensor(0)),[tensor(0)])
test_eq(listify([tensor(0),tensor(1)]),[tensor(0),tensor(1)])
test_eq(listify(tensor([0.,1.1])),[0,1.1])

In [None]:
# export
def order_sorted(funcs, order_key='_order'):
    "Return `funcs` as a list sorted by their `order_key` attribute (0 when it is missing)."
    def _rank(o): return getattr(o, order_key, 0)
    return sorted(listify(funcs), key=_rank)

def apply_all(x, funcs, *args, order_key='_order', **kwargs):
    "Run `x` through every function of `funcs` (sorted by `order_key`), forwarding `args` and `kwargs` to each."
    pipeline = order_sorted(funcs, order_key=order_key)
    result = x
    for fn in pipeline:
        result = fn(result, *args, **kwargs)
    return result

In [None]:
# test
# basic behavior: without any `_order`, functions run in list order -> 2 * (2**2)
def _test_f1(x, a=2): return x**a
def _test_f2(x, a=2): return a*x
test_eq(apply_all(2, [_test_f1, _test_f2]), 8)
# order: `_test_f1` now sorts after `_test_f2` (default order 0) -> (2*2) ** 2
# The assertions below depend on this attribute staying set.
_test_f1._order = 1
test_eq(apply_all(2, [_test_f1, _test_f2]), 16)
#args: the extra positional argument is forwarded as `a` -> (3*2) ** 3
test_eq(apply_all(2, [_test_f1, _test_f2], 3), 216)
#kwargs: same result when `a` is forwarded by keyword
test_eq(apply_all(2, [_test_f1, _test_f2], a=3), 216)

In [None]:
# export
def uniqueify(x, sort=False):
    "Return the unique elements of `x` in order of first appearance, or sorted when `sort` is set."
    # `OrderedDict.fromkeys` drops duplicates while remembering insertion order.
    firsts = list(OrderedDict.fromkeys(x))
    return sorted(firsts) if sort else firsts

In [None]:
# test
# Compared as a set because only the sorted variant is checked for order.
test_eq(set(uniqueify([1,1,0,5,0,3])), {0,1,3,5})
test_eq(uniqueify([1,1,0,5,0,3], sort=True), [0,1,3,5])

In [None]:
# export
def setify(o):
    "Turn `o` into a set: a set is returned unchanged, anything else goes through `listify` first."
    if isinstance(o,set): return o
    return set(listify(o))

In [None]:
# test
# Same conventions as `listify`: None is empty, a string is a single element.
test_eq(setify(None), set())
test_eq(setify('abc'), {'abc'})
# Duplicates collapse; other iterables are unrolled; sets pass through.
test_eq(setify([1,2,2]), {1,2})
test_eq(setify(range(0,3)), {0,1,2})
test_eq(setify({1,2}), {1,2})

In [None]:
# export
def onehot(x, c):
    "Build a float tensor of `c` zeros and set the position(s) indexed by `x` to 1."
    encoded = torch.zeros(c)
    encoded[x] = 1.
    return encoded

In [None]:
# test
# `x` may be a single index, a list of indices or a tensor of indices...
test_eq(onehot(1,5), tensor([0.,1.,0.,0.,0.]))
test_eq(onehot([1,3],5), tensor([0.,1.,0.,1.,0.]))
test_eq(onehot(tensor([1,3]),5), tensor([0.,1.,0.,1.,0.]))
# ...a boolean mask, or empty
test_eq(onehot([True,False,True,True,False],5), tensor([1.,0.,1.,1.,0.]))
test_eq(onehot([],5), tensor([0.,0.,0.,0.,0.]))

In [None]:
# export
def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res

def get_files(path, extensions=None, recurse=False, include=None):
    """Get all the files in `path` with optional `extensions`.

    - `extensions`: one extension or a collection of them (with the leading dot), matched case-insensitively
    - `recurse`: also search sub-folders; folders starting with '.' are skipped
    - `include`: when recursing, name (or names) of the top-level sub-folders of `path` to search
    """
    path = Path(path)
    extensions = {e.lower() for e in setify(extensions)}
    if not recurse:
        # `with` closes the directory handle as soon as the listing is done.
        with os.scandir(path) as entries:
            names = [o.name for o in entries if o.is_file()]
        return _get_files(path, names, extensions)
    # A bare string would be matched by substring below (`'rain' in 'train'`), so normalise it to a list.
    if include is not None: include = listify(include)
    res = []
    for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
        # Editing `d` in place tells `os.walk` which sub-folders it should descend into.
        if include is not None and i==0: d[:] = [o for o in d if o in include]
        else:                            d[:] = [o for o in d if not o.startswith('.')]
        res += _get_files(p, f, extensions)
    return res

In [None]:
# test
# Fetch the MNIST_TINY sample through `untar_data`; the counts below are specific to that dataset.
path = untar_data(URLs.MNIST_TINY)
# Extension filtering inside a single folder
test_eq(len(get_files(path/'train'/'3')), 346)
test_eq(len(get_files(path/'train'/'3', extensions='.png')), 346)
test_eq(len(get_files(path/'train'/'3', extensions='.jpg')), 0)
# Without `recurse`, files in sub-folders are not found
test_eq(len(get_files(path/'train', extensions='.png')), 0)
test_eq(len(get_files(path/'train', extensions='.png', recurse=True)), 709)
# `include` restricts which top-level folders are explored
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train'])), 709)
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train', 'test'])), 729)

In [None]:
# export
def grab_idx(batch, i):
    "Return the `i`-th sample in `batch`, recursing into lists/tuples (which come back as lists)."
    if not isinstance(batch, (list,tuple)):
        # Leaf: pick the sample, then detach it and move it to the CPU.
        return batch[i].detach().cpu()
    return [grab_idx(part,i) for part in batch]

In [None]:
# test
# Single tensor batch, then a list of tensors (one result per tensor)
test_eq(grab_idx(tensor([1,2]), 1), 2)
test_eq(grab_idx([tensor([1,2]), tensor([3,4])], 1), [2,4])

In [None]:
# export
def read_column(df, col_name, prefix='', suffix='', delim=None):
    "Read `col_name` in `df` as strings, optionally adding `prefix`/`suffix` and splitting each value on `delim`."
    raw = df[col_name].values.astype(str)
    with_prefix = np.char.add(prefix, raw)
    values = np.char.add(with_prefix, suffix)
    if delim is None: return values
    # Each value is parsed as one csv line; the rows are then stacked into an array.
    rows = list(csv.reader(values, delimiter=delim))
    return np.array(rows)

In [None]:
# test
# Small frame: column `a` holds single values, column `b` holds space-separated values
df = pd.DataFrame({'a': ['cat', 'dog', 'car'], 'b': ['a b', 'c d', 'a e']})
test_eq(read_column(df, 'a'), np.array(['cat', 'dog', 'car']))
test_eq(read_column(df, 'a', prefix='o'), np.array(['ocat', 'odog', 'ocar']))
test_eq(read_column(df, 'a', suffix='.png'), np.array(['cat.png', 'dog.png', 'car.png']))
# With `delim`, each value becomes a row of its parts
test_eq(read_column(df, 'b', delim=' '), np.array([['a','b'], ['c','d'], ['a','e']]))

### Transform

Behind the scenes, there are no more open/get methods or processors that transform our raw items into the tensors we feed the model: there are only transforms.

One transform will open an image, one transform will convert it to RGB, one transform will resize it, and one transform (actually two) will convert it to a tensor. For the labels, one transform will map a class name to an index, or a list of labels to a one-hot encoded vector. That transformation is done in `__call__`.

This means that some transforms need a preliminary step to get ready: for instance, the transform that deals with labels needs to identify the different classes in the training set. We do this kind of thing in the `setup` method.

Also, for display purposes, some transforms need to be reversed: for instance, we want to display the class name and not the index of the class. This is done in the `undo` method. Note that a lot of transforms won't need to implement that method: you don't want to undo opening the image or a data augmentation transform. However, you do want to undo the one-hot encoding of labels, or the transform that puts the channel dimension first, because in both cases you need this reversed for display.

In [None]:
# export
class Transform():
    "Base transform: does nothing by default; subclasses override `__call__`, `undo` and/or `setup`."
    _order=0

    def __call__(self, o):
        "Transform `o` (identity here)."
        return o

    def undo(self, o):
        "Reverse the transform on `o` for display (identity here)."
        return o

    def setup(self, items):
        "Preliminary step run on `items` before the transform is used (nothing to do here)."
        pass

Then, when we want to describe a class of items, we specify the default transforms that go with it: for instance, an `Image` will have a transform that opens it by default. A `Category` will have a transform that encodes classes to indices. These are given in the `default_tfms` attribute (you give one or a list of `Transform` classes). `init_tfms` will initialize them.

Why are there both `default_tfms` and `default_tfms_xy`? That's because transforms can be applied at two different levels: the item level (an item being an input or a target) or the tuple (input,target) level. Why the separation, rather than just doing the latter? (After all, a transform that applies to the input could be written as `return (tfm(input), label)`.) That would break the flexibility of the data block API, in which each `Item` class can be used as input or target. For instance, an `Image` can be used as x or y, and in each case it will have its default transform that opens it.

Then why have `xy` transforms? Sometimes, you need to know things about the `x` to apply your transform to the `y` (or the opposite) so some transforms can only be written at that level. Note that if you specify a `default_tfms_xy` you break the data block API in the sense that your class of Item will only be able to be used as input or target. For instance, Bbox can only be used as target, as their transform `ScaleBbox` is a `xy` transform: it needs to know `x` (specifically its dimensions) to be able to be applied to `y`. 

Those are very specific cases (but we need to handle them). In general, the transforms applied at the `xy` level will be the data augmentation ones.

In [None]:
# export
class Item():
    "Base class describing a type of data through the transforms that go with it by default."
    default_tfms = None     # transform class(es) applied to the item alone
    default_tfms_xy = None  # transform class(es) applied to the (input,target) tuple

    def init_tfms(self, xy=False):
        "Instantiate the default transforms: the tuple-level ones if `xy`, else the item-level ones."
        tfm_classes = self.default_tfms_xy if xy else self.default_tfms
        return [cls() for cls in listify(tfm_classes)]

### Data block API core

`ItemList` is a general class to contain all the items (an item being an input or a target here). It's intended to be a final class. You initialize it with some `items`, `tfms`, `item_type` and some `tfm_kwargs`. When you access one (or several) elements, `tfms` are applied to them with the `kwargs` passed along. Their `_order` attribute determines the order in which they are applied (note that we sort them at init because we need them in order for undo). Supported indexing includes an int, a collection of ints or a boolean mask.

To quickly get the `undo` method on all transforms applied to an object `o`, there is the `deproc` method. Note that the transforms `undo` methods are applied in the reverse order of the one transforms were applied, so that it can be chained properly.

Lastly the `show` method will display the item in index `i` by indexing into it (which calls all transforms) then calling `deproc` on the result (which will reverse the transforms that need to be reversed) and calling the `show` method of `item_type`.

In [None]:
# export
class ItemList():
    """A collection of `items` whose elements go through `tfms` when they are accessed.

    `tfms` are sorted by `_order` once at init and get a chance to `setup` on the raw items.
    `tfm_kwargs` are forwarded to every transform (and to every `undo` in `deproc`).
    Indexing accepts whatever `items` accepts natively, a collection of indices, or a boolean mask.
    """
    def __init__(self, items, tfms=None, **tfm_kwargs):
        self.items,self.tfms,self.tfm_kwargs = items,order_sorted(tfms),tfm_kwargs
        for tfm in self.tfms: getattr(tfm, 'setup', noop)(self.items)
    def _get(self, i): return apply_all(i, self.tfms, **self.tfm_kwargs)
    def __getitem__(self, idx):
        # Only the raw lookup is guarded: a TypeError raised inside a transform must surface
        # as is instead of being mistaken for "`idx` is a collection".
        try: item = self.items[idx]
        except TypeError: pass
        else: return self._get(item)
        if len(idx)==0: return []  # an empty selection has no first element to inspect
        if isinstance(idx[0],(bool,np.bool_)):
            assert len(idx)==len(self) # bool mask
            return [self._get(o) for m,o in zip(idx,self.items) if m]
        return [self._get(self.items[i]) for i in idx]
    def __len__(self): return len(self.items)
    def __iter__(self): return iter(self.items)  # iterates over the raw, untransformed items
    def __setitem__(self, i, o): self.items[i] = o
    def __delitem__(self, i): del(self.items[i])
    def __repr__(self):
        res = f'{self.__class__.__name__} ({len(self)} items)\n{self.items[:10]}'
        # Swap the closing bracket for an ellipsis when only the first 10 items are shown.
        if len(self)>10: res = res[:-1]+ '...]'
        return res

    def deproc(self, o):
        "Reverse transforms on `o`: the `undo` of each transform is applied, last transform first."
        return apply_all(o, [getattr(t, 'undo', noop) for t in reversed(self.tfms)], **self.tfm_kwargs)

In [None]:
# test
# Plain functions used as transforms throughout the tests below
def add(x, a=1): return x+a
def multiply(x, a=2): return x*a
def square(x): return x**2

#tfms can be basic functions
il = ItemList([0,1,2,3], tfms=[add, multiply, square])
#Test indexing: single int, several ints, boolean mask
test_eq(il[1], ((1+1) * 2) ** 2)
test_eq(il[1,2,3], [(((x+1) * 2) ** 2) for x in [1,2,3]])
test_eq(il[True,False,False,True], [(((x+1) * 2) ** 2) for x in [0,3]])

#Test _order
# These attributes stay on the functions, so every ItemList built below applies
# square, then multiply, then add, whatever the order of the list.
square._order = 0
multiply._order = 1
add._order = 2
il = ItemList([0,1,2,3], tfms=[add, multiply, square])
test_eq(il[2], ((2**2) * 2) + 1)

#Test kwargs: `a=3` is forwarded to both transforms
il = ItemList([0,1,2,3], tfms=[add, multiply], a=3)
test_eq(il[2], (2 * 3) + 3)

#Test undo: `square` has no `undo`, so it is left as is by `deproc`
def add_undo(x, a=1): return x-a
def multiply_undo(x, a=2): return x/a
add.undo = add_undo
multiply.undo = multiply_undo
il = ItemList([0,1,2,3], tfms=[add, multiply, square])
test_eq(il.deproc(9), (9-1)/2)
il = ItemList([0,1,2,3], tfms=[add, multiply], a=3)
test_eq(il.deproc(9), (9-3)/3)

To contain the tuples (x,y) we use the following class. It's initialized with two `ItemList` `x` and `y`, with a set of `tfms` and some `tfm_kwargs`. When accessing an index, it returns the result of the transforms applied to the corresponding (x,y)  with the `tfm_kwargs` (which are again sorted by their `_order` key in the init). 

Like in an `ItemList`, `deproc` will call `undo` on all transforms in reverse order, so first the transforms of this object, then the transforms on the `x` and `y` ItemList level. 

The `show` method will display the `x` and `y` at a given index on a given `ax`.

In [None]:
# export
class LabeledData():
    "Pairs of (`x[i]`, `y[i]`) with tuple-level `tfms` applied on top, `tfm_kwargs` being passed to each of them."
    def __init__(self, x, y, tfms=None, **tfm_kwargs):
        self.x,self.y = x,y
        self.tfms,self.tfm_kwargs = order_sorted(tfms),tfm_kwargs
        # Tuple-level transforms are set up on the dataset itself, not on the raw items.
        for tfm in self.tfms:
            setup = getattr(tfm, 'setup', noop)
            setup(self)

    def __repr__(self):
        return f'{self.__class__.__name__}\nx: {self.x}\ny: {self.y}\n'

    def __getitem__(self,idx):
        sample = (self.x[idx],self.y[idx])
        return apply_all(sample, self.tfms, **self.tfm_kwargs)

    def __len__(self):
        return len(self.x)

    def deproc(self, o):
        "Reverse transforms on `o`: first the tuple-level ones (last one first), then those of `x` and `y`."
        undos = [getattr(t, 'undo', noop) for t in reversed(self.tfms)]
        x,y = apply_all(o, undos, **self.tfm_kwargs)
        return (self.x.deproc(x),self.y.deproc(y))

In [None]:
# test
# Tuple-level transform: subtracts `a` from the input and adds it to the target
def add_sub(o, a=2): return (o[0]-a, o[1]+a)
# `add`, `multiply` and `square` still carry the `_order`/`undo` attributes set in the ItemList tests
x = ItemList([0,1,2,3], tfms=[multiply, square])
y = ItemList([0,1,2,3], tfms=[add, multiply])
ld = LabeledData(x, y, tfms=add_sub)
test_eq(ld[2], ((2**2) * 2 - 2, (2*2 + 1) + 2))

# kwargs are only passed to the tuple-level transform
ld = LabeledData(x, y, tfms=add_sub, a=3)
test_eq(ld[2], ((2**2) * 2 - 3, (2*2 + 1) + 3))

# Without an `undo` on `add_sub`, only the item-level transforms are reversed
def undo_add_sub(o, a=2): return (o[0]+a, o[1]-a)
ld = LabeledData(x, y, tfms=add_sub)
test_eq(ld.deproc((8,6)), (8/2, (6-1)/2))

# With an `undo`, the tuple-level transform is reversed first
add_sub.undo = undo_add_sub
ld = LabeledData(x, y, tfms=add_sub)
test_eq(ld.deproc((8,6)), ((8+2)/2, ((6-2) - 1)/2))

### First Items

How to get images

In [None]:
class ImageOpener(Transform):
    "Transform that opens `o` with `PIL.Image.open`."
    def __call__(self, o):
        return PIL.Image.open(o)

In [None]:
# export
class Image(Item):
    "Image item: opened with `ImageOpener` by default; stores the `cmap` and `alpha` it is given."
    default_tfms = ImageOpener
    def __init__(self, cmap=None, alpha=1.):
        self.cmap = cmap
        self.alpha = alpha

How to get categories

In [None]:
class CategoryEncoder(Transform):
    "Encodes a categorical variable to the index of its class in `vocab`."
    def __init__(self): self.vocab=None

    def setup(self, items):
        "Build `vocab` (sorted unique values of `items`) unless one is already set, then the reverse mapping `otoi`."
        if self.vocab is None: self.vocab = uniqueify(items, sort=True)
        # Always derive `otoi`: returning early on a pre-assigned `vocab` left it undefined,
        # which made `__call__` fail with an AttributeError.
        self.otoi  = {v:k for k,v in enumerate(self.vocab)}

    def __call__(self, o): return self.otoi[o]
    def undo(self, i):    return self.vocab[i]

In [None]:
# export
class Category(Item):
    "Single-label item: encoded with `CategoryEncoder` by default."
    default_tfms = CategoryEncoder

In [None]:
tfm = CategoryEncoder()
#Even if 'dog' is the first class, vocab is sorted for reproducibility
tfm.setup(['dog', 'cat', 'cat', 'dog', 'cat', 'dog'])
test_eq(tfm.vocab,['cat', 'dog'])
# Encoding maps a class to its position in `vocab`; `undo` maps it back
test_eq([tfm(o) for o in ['dog', 'cat', 'cat']], [1,0,0])
test_eq(tfm('cat'),0)
test_eq(tfm.undo(1),'dog')

### Helper functions

To get image files

In [None]:
# export
# Every file extension that the `mimetypes` table maps to an `image/...` type.
image_extensions = {ext for ext,mime in mimetypes.types_map.items() if mime.startswith('image/')}

In [None]:
# export
def get_image_files(path, include=None):
    "Recursively collect the files of `path` whose extension is in `image_extensions`, `include` being passed on to `get_files`."
    files = get_files(path, extensions=image_extensions, recurse=True, include=include)
    return files

In [None]:
# test
# The counts below are specific to the MNIST_TINY sample.
path = untar_data(URLs.MNIST_TINY)
test_eq(len(get_image_files(path)),1428)
test_eq(len(get_image_files(path/'train')),709)
# `include` given as a single name, then as a list of names
test_eq(len(get_image_files(path, include='train')),709)
test_eq(len(get_image_files(path, include=['train','valid'])),1408)

To split

In [None]:
# export
def random_splitter(items, valid_pct=0.2, seed=None):
    "Randomly split the indices of `items`: returns (train indices, valid indices), the latter being `valid_pct` of the total."
    n = len(items)
    if seed is not None: torch.manual_seed(seed)
    perm = torch.randperm(n)
    n_valid = int(valid_pct * n)
    # The first `n_valid` shuffled indices go to validation, the rest to training.
    valid_idx,train_idx = perm[:n_valid],perm[n_valid:]
    return train_idx,valid_idx

In [None]:
#test
# With a fixed seed the split is reproducible; 20% of 6 items rounds down to 1 validation item.
trn,val = random_splitter([0,1,2,3,4,5], seed=42)
test_eq(trn, tensor([3, 2, 4, 1, 5]))
test_eq(val, tensor([0]))

In [None]:
# export
def _grandparent_mask(items, name):
    "Boolean mask over `items`: True where the grandparent folder (two levels above the file) is called `name`."
    def _grandparent(o):
        if isinstance(o, Path): return o.parent.parent.name
        # In a split path, [-1] is the file name and [-2] its parent folder: the grandparent is [-3].
        parts = o.split(os.path.sep)
        return parts[-3] if len(parts) >= 3 else ''
    return [_grandparent(o) == name for o in items]

def grandparent_splitter(items, train_name='train', valid_name='valid'):
    """Split `items` from the grand parent folder names (`train_name` and `valid_name`).

    `items` can be `Path` objects or strings using the OS path separator.
    Returns two boolean masks (train, valid) with one entry per item.
    """
    return _grandparent_mask(items, train_name),_grandparent_mask(items, valid_name)

In [None]:
# Root folder of the MNIST_TINY sample, used to build the file names in the tests below.
path = untar_data(URLs.MNIST_TINY)

In [None]:
#test
#With filenames built from `path` using the `/` operator (path objects rather than plain strings)
path = untar_data(URLs.MNIST_TINY)
items = [path/'train'/'3'/'9932.png', path/'valid'/'7'/'7189.png', 
         path/'valid'/'7'/'7320.png', path/'train'/'7'/'9833.png',  
         path/'train'/'3'/'7666.png', path/'valid'/'3'/'925.png',
         path/'train'/'7'/'724.png', path/'valid'/'3'/'93055.png']
# One boolean mask per split, aligned with `items`
trn,val = grandparent_splitter(items)
test_eq(trn,[True,False,False,True,True,False,True,False])
test_eq(val,[False,True,True,False,False,True,False,True])

To label

In [None]:
# export
def parent_labeller(items):
    """Label `items` with the parent folder name.

    `items` can be `Path` objects or strings using the OS path separator.
    """
    def _parent(o):
        if isinstance(o, Path): return o.parent.name
        # In a split path, [-1] is the file name itself: the parent folder is [-2].
        parts = o.split(os.path.sep)
        return parts[-2] if len(parts) >= 2 else ''
    return [_parent(o) for o in items]

In [None]:
# test
# `items` is the list of MNIST_TINY file names built for the `grandparent_splitter` test.
test_eq(parent_labeller(items),['3','7','7','7','3','3','7','3'])

In [None]:
# export
def func_labeller(items, func):
    "Return the list of `func(o)` for every `o` in `items`."
    return list(map(func, items))

In [None]:
# test
# Label = parent folder name (a digit) plus one
test_eq(func_labeller(items, lambda x: int(x.parent.name)+1),[4,8,8,8,4,4,8,4])

In [None]:
# export
def re_labeller(items, pat):
    "Label each item with the first group of regex `pat` searched in `str(item)`; fails if `pat` is not found."
    pat = re.compile(pat)
    labels = []
    for o in items:
        res = pat.search(str(o))
        assert res,f'Failed to find "{pat}" in "{o}"'
        labels.append(res.group(1))
    return labels

In [None]:
# test
# Captures the folder name just before a numeric file name.
# NOTE(review): the pattern hard-codes '/' as the separator, so it only matches POSIX-style paths.
pat = re.compile(r'/([^/]+)/\d+.png$')
test_eq(re_labeller(items, pat),['3','7','7','7','3','3','7','3'])

### First data:

Integration test

In [None]:
# Gather the raw pieces by hand: file names, a random train/valid split, and labels parsed from the file names.
path = untar_data(URLs.PETS)
items = get_image_files(path/'images')
splits = random_splitter(items)
# The label is the part of the file name before the trailing `_<number>.jpg`
labels = re_labeller(items, pat = r'/([^/]+)_\d+.jpg$')

To split the data in train valid, we use a basic ItemList for fancy indexing.

In [None]:
# No transforms here: the ItemList is only used for its indexing by a collection of indices.
items = ItemList(items)
labels = ItemList(labels)

For the xs we need to use the image opener transform

In [None]:
# One ItemList per split (train first, then valid), each opening its files with `ImageOpener`.
x_train,x_valid = map(lambda s: ItemList(items[s], tfms=ImageOpener()), splits)

For the ys we need to use the category encoder transform

In [None]:
# Each split gets its own `CategoryEncoder`, so each one builds its vocab from its own labels.
y_train,y_valid = map(lambda s: ItemList(labels[s], tfms=CategoryEncoder()), splits)

And now we can construct our training and validation LabeledData objects.

In [None]:
# No tuple-level transforms at this stage.
train = LabeledData(x_train, y_train)
valid = LabeledData(x_valid, y_valid)

In [None]:
# First training sample: the opened image and its encoded category.
t,i = train[0]

In [None]:
t,i

In [None]:
# `deproc` reverses the transforms that define an `undo`: the category index goes back to its class name.
img,cls = train.deproc(train[0])

In [None]:
# Show the image with its class name as the title (the trailing `;` hides the repr of the last call).
_,ax = plt.subplots(1,1)
ax.imshow(img)
ax.axis('off')
ax.set_title(cls);

### Handling display

In [None]:
def show_img(x, ax=None):
    "Draw image `x` on `ax` (a new axis is created when `ax` is None) and hide the axis decorations."
    if ax is None:
        ax = plt.subplots(1,1)[1]
    ax.imshow(x)
    ax.axis('off')
    
def show_cat(x, ax=None):
    "Show category `x`: as the title of `ax` when one is given, printed otherwise."
    if ax is not None:
        ax.set_title(x)
        return
    print(x)

In [None]:
show_cat(cls)

In [None]:
def display(ld, idx, show_x, show_y, **kwargs):
    "Get sample `idx` of `ld`, reverse its transforms, then call `show_x` on the input and `show_y` on the target with `kwargs`."
    decoded = ld.deproc(ld[idx])
    x,y = decoded
    show_x(x, **kwargs)
    show_y(y, **kwargs)

In [None]:
# The same axis is passed to both show functions: the image is drawn on it and the category becomes its title.
_,ax = plt.subplots(1,1)
display(train, 0, show_img, show_cat, ax=ax)

### Refactor

An Item class has default transforms (no need to pass ImageOpener() for Image) and can have a show method (that we can always replace).

In [None]:
# export
class Image(Item):
    "Image item: opened with `ImageOpener` by default and able to display itself."
    default_tfms = ImageOpener

    def show(self, x, ax=None, cmap=None, alpha=1):
        "Draw `x` on `ax` (created if None) with `cmap` and `alpha`; return the axis used in a dict so a target can be drawn on it."
        if ax is None:
            ax = plt.subplots(1,1)[1]
        ax.imshow(x, cmap=cmap, alpha=alpha)
        ax.axis('off')
        return dict(ax=ax)

class Category(Item):
    "Single-label item: encoded with `CategoryEncoder` by default and shown as the title of an axis."
    default_tfms = CategoryEncoder
    def show(self, x, ax):
        ax.set_title(x)

In [None]:
def display(ld, idx, types, show_x=None, show_y=None, **kwargs):
    "Show sample `idx` of `ld` after reversing its transforms, using the `show` methods of `types` unless `show_x`/`show_y` are given."
    x,y = ld.deproc(ld[idx])
    show_x = types[0].show if show_x is None else show_x
    show_y = types[1].show if show_y is None else show_y
    support = show_x(x, **kwargs)
    # Whatever `show_x` returns (e.g. the axis it drew on) is handed to `show_y`, overriding `kwargs`.
    extra = {} if support is None else support
    show_y(y, **{**kwargs, **extra})

In [None]:
# The show functions now come from the item types; `Image.show` creates the axis and passes it on to `Category.show`.
display(train, 0, (Image(),Category()))

### DataBlock

Main class to represent any kind of data. User provides the type of inputs/targets in `type_cls`. Then they implement the four functions to gather the data:
- `get_source` returns the source of the data after potentially downloading it.
- `get_items` takes the source and returns the list of all items.
- `split` take the items and returns two (or more) list of indices or boolean masks that explain how to split the data in train and valid (potentially valids) set.
- `label` take the items and returns a list of targets.

Then, during the initialization, the `source` is fetched by calling `get_source`, which then allows us to collect the items (with `get_items`), the different splits (with `split`) and the labels (with `label`). Default transforms for `x` and `y` are collected (they can be overridden by a custom `tfms_x` or `tfms_y` passed to the init) and the different ItemLists for each split x/y are created.

Then the default transforms for `xy` are collected by looking at the types of x and y (they can be overridden by a custom `tfms_xy` passed to the init), the `tfms` passed are added, and the xs/ys are grouped in the various datasets (which are all `LabeledData`). The `tfm_kwargs` are passed at this level only (so they'll only be passed along to xy transforms).

In [None]:
# export
class DataBlock():
    """Main class to represent a dataset. Subclass it, set `type_cls` and implement the 4 methods below to your need.

    At init, the source, items, splits and labels are gathered with those 4 methods, then one
    `LabeledData` per split is stored in `self.datasets`.
    - `tfms_x`/`tfms_y`: item-level transforms replacing the defaults of the input/target types
    - `tfms_xy`: tuple-level transform(s) replacing the defaults of both types
    - `tfms`: tuple-level transform(s) added to `tfms_xy`
    - `tfm_kwargs`: passed to the tuple-level transforms only
    """
    type_cls = (Item,Item) #Type of input,target
    def get_source(self):
        "Return the source of your data (path, dataframe...), optionally download it."
        raise NotImplementedError
    def get_items(self, source):
        "Use `source` to return the list of all items."
        raise NotImplementedError
    def split(self, items):
        "Explain how to split the `items`. Return two (or more) lists of indices/boolean masks."
        raise NotImplementedError
    def label(self, items):
        "Explain how to label your `items`. Return a list of labels."
        raise NotImplementedError

    def __init__(self, tfms=None, tfms_x=None, tfms_y=None, tfms_xy=None, **tfm_kwargs):
        self.source = self.get_source()
        items = ItemList(self.get_items(self.source)) #Just for fancy indexing
        split_idx = self.split(items)
        labels = ItemList(self.label(items))          #Just for fancy indexing
        self.types = self.type_cls[0](),self.type_cls[1]()
        if tfms_x is None: tfms_x = self.types[0].init_tfms()
        if tfms_y is None: tfms_y = self.types[1].init_tfms()
        # The same transform objects are shared by every split, and splits are built in order
        # (training first), so a transform that keeps its first `setup` state is fitted on the training split.
        xs = map(lambda o: ItemList(items[o],  tfms=tfms_x), split_idx)
        ys = map(lambda o: ItemList(labels[o], tfms=tfms_y), split_idx)
        if tfms_xy is None: tfms_xy = self.types[0].init_tfms(xy=True) + self.types[1].init_tfms(xy=True)
        # Accept a single transform as well as a list, like `tfms`: a bare object cannot be added to a list.
        tfms_xy = listify(tfms_xy)
        self.datasets = [LabeledData(x, y, tfms=listify(tfms) + tfms_xy, **tfm_kwargs)
                         for (x,y) in zip(xs, ys)]

    @property
    def train(self): return self.datasets[0]
    @property
    def valid(self): return self.datasets[1]

In [None]:
class PetsData(DataBlock):
    "Pets dataset: images labelled with the part of their file name before the trailing number, split randomly."
    type_cls = (Image, Category)

    def get_source(self):
        return untar_data(URLs.PETS)

    def get_items(self, source):
        return get_image_files(source/"images")

    def split(self, items):
        return random_splitter(items)

    def label(self, items):
        return re_labeller(items, pat = r'/([^/]+)_\d+.jpg$')

In [None]:
# Everything (source, items, split, labels, default transforms) is gathered at init.
data = PetsData()

In [None]:
# `data.types` holds the instantiated item types whose `show` methods are used for display.
display(data.train, 1, data.types)

## Transforms

In [None]:
# export
# Kind of target a transform has to deal with (`No`: the target is left untouched).
TfmY = Enum('TfmY', ['No', 'Mask', 'Image', 'Point', 'Bbox'])

In [None]:
# export
class ImageTransform(Transform):
    "Base class for tuple-level image transforms: `apply` handles the input, the kind of target (`tfm_y`) picks the method for the target."
    _order=0
    # Name of the method applied to (resp. undone on) the target for each kind of target.
    _tfm_y_func={TfmY.Image: 'apply_img',   TfmY.Mask: 'apply_mask', TfmY.No: 'noop',
                 TfmY.Point: 'apply_point', TfmY.Bbox: 'apply_bbox'}
    _undo_y_func={TfmY.Image: 'unapply_img',   TfmY.Mask: 'unapply_mask', TfmY.No: 'noop',
                  TfmY.Point: 'unapply_point', TfmY.Bbox: 'unapply_bbox'}

    # By default, image targets are treated like the input and masks like images;
    # points are left unchanged and bounding boxes are treated like points.
    def apply(self, x):
        return x
    def apply_img(self, y):
        return self.apply(y)
    def apply_mask(self, y):
        return self.apply_img(y)
    def apply_point(self, y):
        return y
    def apply_bbox(self, y):
        return self.apply_point(y)

    def randomize(self):
        "Hook called once per sample, before the input and the target are transformed."
        pass

    def __call__(self, o, tfm_y=TfmY.No):
        "Transform the tuple `o`=(x,y): `apply` on x, the method matching `tfm_y` on y."
        x,y = o
        self.x = x #Saves the x in case it's needed in the apply for y (x.size for apply_point for instance)
        self.randomize() #Ensures we have the same state for x and y
        new_x = self.apply(x)
        # Falls back to the module-level `noop` when the instance has no method of that name.
        y_func = getattr(self, self._tfm_y_func[tfm_y], noop)
        return new_x,y_func(y)

    # Same defaults for the reverse direction.
    def unapply(self, x):
        return x
    def unapply_img(self, y):
        return self.unapply(y)
    def unapply_mask(self, y):
        return self.unapply_img(y)
    def unapply_point(self, y):
        return y
    def unapply_bbox(self, y):
        return self.unapply_point(y)

    def undo(self, o, tfm_y=TfmY.No):
        "Reverse the transform on the tuple `o`=(x,y): `unapply` on x, the method matching `tfm_y` on y."
        x,y = o
        old_x = self.unapply(x)
        y_func = getattr(self, self._undo_y_func[tfm_y], noop)
        return old_x,y_func(y)

In [None]:
#export
class DecodeImg(ImageTransform):
    """Convert regular image to RGB, masks to L mode.

    - `mode_x`: mode the input is converted to
    - `mode_y`: mode the target is converted to; when None, image targets use `mode_x` and masks use 'L'
    """
    def __init__(self, mode_x='RGB', mode_y=None):
        self.mode_x,self.mode_y = mode_x,mode_y

    def apply(self, x):       return x.convert(self.mode_x)
    # `ImageTransform` dispatches image targets to `apply_img`; under the name `apply_image`
    # this override was never called and `mode_y` was ignored for image targets.
    def apply_img(self, y):   return y.convert(self.mode_x if self.mode_y is None else self.mode_y)
    apply_image = apply_img   # previous name, kept for existing callers
    def apply_mask(self, y):  return y.convert('L' if self.mode_y is None else self.mode_y)

In [None]:
#export
class ResizeFixed(ImageTransform):
    """Resize image to `size` using `mode_x` (and `mode_y` on targets).

    - `size`: an int (square output) or a 2-element sequence whose order is swapped before being given to PIL
    - `mode_y`: resampling used on targets; when None, image targets use `mode_x` and masks use nearest neighbour
    """
    _order=10
    def __init__(self, size, mode_x=PIL.Image.BILINEAR, mode_y=None):
        if isinstance(size,int): size=(size,size)
        size = (size[1],size[0]) #PIL takes size in the otherway round
        self.size,self.mode_x,self.mode_y = size,mode_x,mode_y

    def apply(self, x):       return x.resize(self.size, self.mode_x)
    # `ImageTransform` dispatches image targets to `apply_img`; under the name `apply_image`
    # this override was never called and `mode_y` was ignored for image targets.
    def apply_img(self, y):   return y.resize(self.size, self.mode_x if self.mode_y is None else self.mode_y)
    apply_image = apply_img   # previous name, kept for existing callers
    def apply_mask(self, y):  return y.resize(self.size, PIL.Image.NEAREST if self.mode_y is None else self.mode_y)
In [None]:
#export
class ToByteTensor(ImageTransform):
    "Transform our items to byte tensors with the channel dimension first."
    _order=20

    def apply(self, x):
        "Read the raw bytes of `x` into a byte tensor viewed as (h, w, channels), then move channels first."
        storage = torch.ByteStorage.from_buffer(x.tobytes())
        flat = torch.ByteTensor(storage)
        w,h = x.size
        return flat.view(h,w,-1).permute(2,0,1)

    def unapply(self, x):
        "Drop the channel dimension when it has size 1, otherwise move it back to the last position."
        if x.shape[0] == 1: return x[0]
        return x.permute(1,2,0)

In [None]:
#export
class ToFloatTensor(ImageTransform):
    "Transform our items to float tensors divided by `div_x` (masks become long tensors, divided by `div_y` if given)."
    _order=20
    def __init__(self, div_x=255., div_y=None):
        self.div_x = div_x
        self.div_y = div_y

    def apply(self, x):
        as_float = x.float()
        return as_float.div_(self.div_x)

    def apply_mask(self, x):
        as_long = x.long()
        if self.div_y is None: return as_long
        return as_long.div_(self.div_y)

In [None]:
# Tuple-level pipeline: decode, resize to 128x128, then convert to a byte and finally a float tensor.
# The last two share the same `_order`, so their relative order is the one of this list.
tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor(), ToFloatTensor()]

Integration test

In [None]:
# `tfms` are applied at the (x,y) level, on top of the default item-level transforms.
data = PetsData(tfms=tfms)

In [None]:
# `display` reverses the transforms first, so the tensor is shown as an image again.
display(data.train, 1, data.types)

### Dataloader and DataBunch

In [None]:
# export
from torch.utils.data.dataloader import DataLoader
def get_dl(ds, bs, shuffle=False, drop_last=False, **kwargs):
    "Build a `DataLoader` over `ds` with batch size `bs`; `kwargs` are passed to the `DataLoader` constructor."
    dl_kwargs = dict(batch_size=bs, shuffle=shuffle, drop_last=drop_last, **kwargs)
    return DataLoader(ds, **dl_kwargs)

In [None]:
# export
class DataBunch():
    "Basic wrapper around several `DataLoader`: the first one is the training one, the second the validation one."
    def __init__(self, *dls, types=None):
        self.dls,self.types = dls,types

    def show_batch(self, dl_idx=0, items=9, show_xy=None, show_x=None, show_y=None):
        """Show up to `items` samples taken from one batch of the `dl_idx`-th `DataLoader`.

        Each sample is extracted with `grab_idx`, passed through the `deproc` of the dataset, then
        shown the way `display` does: `show_x` on the input and `show_y` on the target (both default
        to the `show` methods of `self.types`), `show_y` receiving whatever `show_x` returned.
        NOTE(review): when `show_xy` is given it is called as `show_xy(x, y)` instead; this calling
        convention is an assumption, since no item type in this notebook defines `show_xy`.
        """
        dl = self.dls[dl_idx]
        if show_xy is None:
            if show_x is None: show_x = self.types[0].show
            if show_y is None: show_y = self.types[1].show
        xb, yb = next(iter(dl))
        for i in range(min(items, len(xb))):
            x,y = dl.dataset.deproc(grab_idx((xb,yb), i))
            if show_xy is not None: show_xy(x, y)
            else:
                support = show_x(x)
                show_y(y, **({} if support is None else support))

    @property
    def train_dl(self): return self.dls[0]
    @property
    def valid_dl(self): return self.dls[1]
    @property
    def train_ds(self): return self.train_dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dataset

In [None]:
# export
def _db_databunch(self, bs=64, **kwargs):
    "Create a `DataBunch` with one `DataLoader` of batch size `bs` per dataset, `kwargs` being passed to `get_dl`."
    # Only the first dataset (the training set) is shuffled and drops its last incomplete batch.
    # `i` has to come from `enumerate`: it was not defined inside the comprehension before.
    dls = [get_dl(ds, bs, shuffle=(i==0), drop_last=(i==0), **kwargs) for i,ds in enumerate(self.datasets)]
    return DataBunch(*dls, types=self.types)

# Attach the function to `DataBlock` as its `databunch` method.
DataBlock.databunch = _db_databunch

## Try different data

### MNIST

In [None]:
class MnistData(DataBlock):
    "MNIST: images labelled by their parent folder, split by their grandparent folder ('training'/'testing')."
    # `DataBlock` only reads `type_cls`; with just `x_cls`/`y_cls` set, both types silently
    # stayed the base `Item` and no default transform was applied.
    type_cls = (Image, Category)
    x_cls,y_cls = type_cls  # previous attribute names, kept for existing callers

    def get_source(self):        return untar_data(URLs.MNIST)
    def get_items(self, source): return get_image_files(source)
    def split(self, items):      return grandparent_splitter(items, train_name='training', valid_name='testing')
    def label(self, items):      return parent_labeller(items)

In [None]:
# No decoding or resizing here: the images are only converted to byte, then float tensors.
data = MnistData(tfms=[ToByteTensor(), ToFloatTensor()])

In [None]:
# `LabeledData` defines no `show` method: go through the `display` helper, as for the pets data.
display(data.train, 1, data.types)

cmap is specified in the `item_type` for inputs.

In [None]:
# NOTE(review): `ItemList` as defined in this notebook never sets an `item_type` attribute, so this
# line is expected to raise an AttributeError; confirm how `cmap` should reach `Image.show`.
data.train.x.item_type.cmap='gray'

In [None]:
# `LabeledData` defines no `show` method: go through the `display` helper, as for the pets data.
display(data.train, 1, data.types)

### Planet

In [None]:
# Root folder of the planet sample; it contains the `labels.csv` file read below.
path = untar_data(URLs.PLANET_SAMPLE)

In [None]:
path.ls()

In [None]:
# Table of labels; the classes below read its `image_name` and `tags` columns.
df = pd.read_csv(path/'labels.csv')

In [None]:
df.head()

In [None]:
class MultiCategoryEncoder(Transform):
    """Encodes a collection of categories to a one-hot vector over `vocab`.

    - `do_encode`: when False, items are assumed to be one-hot encoded already and are returned as is
      (`classes` is then required, as the vocab cannot be found from the items)
    - `classes`: the vocab to use; when None, it is built in `setup` from the items
    """
    def __init__(self, do_encode=True, classes=None):
        assert do_encode or classes is not None, "If you use one_hot encoded items, please provide classes."
        self.vocab,self.do_encode = classes,do_encode
        # With `classes` given, `setup` has nothing to discover: build the lookup right away
        # (it was never created in that case, so encoding failed with an AttributeError).
        self.otoi = None if classes is None else {v:k for k,v in enumerate(classes)}

    def setup(self, items):
        "Build `vocab` as the sorted union of all the categories in `items` (unless already set), then `otoi`."
        if self.vocab is None:
            vocab = set()
            for c in items: vocab = vocab.union(set(c))
            self.vocab = sorted(vocab)
        self.otoi  = {v:k for k,v in enumerate(self.vocab)}

    def __call__(self, item):
        if not self.do_encode: return item
        # Categories missing from the vocab are ignored.
        return onehot([self.otoi[o] for o in item if o in self.otoi], len(self.vocab))

    def undo(self, o): return [self.vocab[i] for i,v in enumerate(o) if v==1.]

In [None]:
# export
class MultiCategory(Item):
    "Item type for multi-label targets; declares `MultiCategoryEncoder` as its `default_tfms`."
    default_tfms = MultiCategoryEncoder

    def show(self, x, ax):
        "Write the labels in `x`, joined by ';', as the title of `ax`."
        title = ';'.join(x)
        ax.set_title(title)

In [None]:
# test
tfm = MultiCategoryEncoder()
#Even if 'c' is the first class, vocab is sorted for reproducibility
tfm.setup([['c','a'], ['a','b'], ['b']])
test_eq(tfm.vocab,['a','b','c'])

#Encoding gives a one-hot tensor over the sorted vocab: 'a' and 'b' are set, 'c' is not
test_eq(tfm(['b','a']),tensor([1.,1.,0.]))
#undo maps the positions equal to 1 back to their class names
test_eq(tfm.undo(tensor([1.,0.,1.])),['a','c'])

In [None]:
class PlanetData(DataBlock):
    "Planet sample: `Image` inputs with `MultiCategory` targets read from `labels.csv`."
    x_cls = Image
    y_cls = MultiCategory

    def get_source(self):
        "Store the dataset location in `self.path` and return the content of `labels.csv` as a `DataFrame`."
        self.path = untar_data(URLs.PLANET_SAMPLE)
        # Read from `self.path`, not from the notebook-level `path` global, so the class also works
        # on a fresh kernel where that global was never defined.
        return pd.read_csv(self.path/'labels.csv')
    def get_items(self, source): return read_column(source, 'image_name', prefix=f'{self.path}/train/', suffix='.jpg')
    def split(self, items):      return random_splitter(items)
    # NOTE(review): `items` is unused here; the tags are read in `self.source` row order, which
    # presumably matches the order produced by `get_items` - confirm.
    def label(self, items):      return read_column(self.source, 'tags', delim=' ')

In [None]:
data = PlanetData(tfms=tfms)

In [None]:
data.train.show(0)

In [None]:
# Vocab held by the first transform of the training targets, and the matching name -> index lookup.
# Both are read as notebook-level globals by `PlanetData1.label` below.
classes = data.train.y.tfms[0].vocab
otoi = {s:i for i,s in enumerate(classes)}

In [None]:
class PlanetData1(DataBlock):
    "Variant of the Planet data block whose `label` step already returns one-hot tensors."
    x_cls = Image
    y_cls = MultiCategory

    def get_source(self):
        "Store the dataset location in `self.path` and return the content of `labels.csv` as a `DataFrame`."
        self.path = untar_data(URLs.PLANET_SAMPLE)
        # Read from `self.path`, not from the notebook-level `path` global, so the class also works
        # on a fresh kernel where that global was never defined.
        return pd.read_csv(self.path/'labels.csv')
    def get_items(self, source): return read_column(source, 'image_name', prefix=f'{self.path}/train/', suffix='.jpg')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        "One-hot tensor per row of tags, built with the notebook-level `classes` and `otoi`; unknown tags are ignored."
        #This is just for the sake of using one-hot encoded labels, but imagine we have a dataset where it's the case.
        tags = read_column(self.source, 'tags', delim=' ')
        labels = []
        for t in tags:
            x = torch.zeros(len(classes))
            x[[otoi[l] for l in t if l in otoi]] = 1.
            labels.append(x)
        return labels

In [None]:
# The labels of `PlanetData1` are already one-hot tensors, so encoding is switched off; `classes` is still
# required (see the assertion in `MultiCategoryEncoder.__init__`) and is what `undo` uses to get names back.
data = PlanetData1(tfms=tfms, tfms_y=MultiCategoryEncoder(do_encode=False, classes=classes))

In [None]:
data.train.show(0)

### Camvid

In [None]:
# export
class SegmentMask(Image):
    "`Image` item type for segmentation mask targets, with `cmap='tab20'` and `alpha=0.5` as defaults."
    def __init__(self, cmap='tab20', alpha=0.5): super().__init__(cmap=cmap, alpha=alpha)

In [None]:
class CamvidData(DataBlock):
    "Camvid tiny: `Image` inputs with `SegmentMask` targets."
    x_cls = Image
    y_cls = SegmentMask

    def get_source(self):        return untar_data(URLs.CAMVID_TINY)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        "Pair each image with the file `labels/<stem>_P<suffix>` of the source folder."
        # The unused `codes = np.loadtxt(...)` local was dropped: it read `codes.txt` without using it.
        path_lbl = self.source/'labels'
        return func_labeller(items, lambda x: path_lbl/f'{x.stem}_P{x.suffix}')

In [None]:
data = CamvidData(tfms=tfms, tfm_y=TfmY.Mask)

In [None]:
data.train.show(0)

### Biwi

In [None]:
import pickle

In [None]:
class PointScaler(Transform):
    "Convert point targets between pixel coordinates and coordinates scaled to [-1,1]."
    _order = -10 #Run before we apply any ImageTransform
    def __init__(self, do_scale=True, y_first=False):
        self.do_scale = do_scale
        self.y_first = y_first

    def __call__(self, o, tfm_y=TfmY.No):
        "Return `(x, pts)` where `pts` is a float tensor of shape (n,2), flipped unless `y_first` and rescaled with `x.size` if `do_scale`."
        img,pts = o
        if not isinstance(pts, torch.Tensor): pts = tensor(pts)
        pts = pts.view(-1, 2).float()
        if not self.y_first: pts = pts.flip(1)
        if self.do_scale:
            img_size = tensor(list(img.size)).float()
            pts = pts * 2/img_size - 1
        return (img,pts)

    def undo(self, o, tfm_y=TfmY.No):
        "Flip the coordinates and map them from [-1,1] back to pixels using `x.shape[:2]`."
        img,pts = o
        img_shape = tensor([img.shape[:2]]).float()
        pts = (pts.flip(1) + 1) * img_shape/2
        return (img,pts)

In [None]:
# export
class Points(Item):
    "Item type for point targets; declares `PointScaler` as its `default_tfms_xy`."
    default_tfms_xy = PointScaler

    def show(self, x, ax):
        "Scatter the rows of `x` on `ax` as small red dots, column 1 on the first axis and column 0 on the second."
        ax.scatter(x[:,1], x[:,0], s=10, marker='.', c='r')

In [None]:
class FakeImg():
    "Minimal stand-in for an image: `size` as given, and `shape` holding its first two values in reverse order."
    def __init__(self, size):
        self.size = size
        self.shape = (size[1],size[0])

#Points in pixel coordinates: the four corners and the center of the fake images used below
il = ItemList([[0,0], [120,0], [0,200], [120,200], [60,100]], item_type=Points())
#At the ItemList level, there is no transform happening.
test_eq(il[1], [120,0])
#The transform is applied when getting the items at the xy level.
ll = LabeledData(ItemList([FakeImg((200,120)) for _ in range(5)]), il, tfms=PointScaler())
#Points are flipped, then scaled with the image size: [120,0] becomes [-1,1] and the center [0,0]
test_eq(ll[1][1], tensor([[-1., 1.]]))
test_eq(ll[4][1], tensor([[0., 0.]]))
o = ll[2]
#Test deproc undoes the scaling and switching
test_eq(ll.deproc(o)[1], tensor([[0., 200.]]))

#Giving scaled points: same encoded values and same deproc output as the pixel version
il1 = ItemList([[-1.,-1.], [1.,-1.], [-1.,1.], [1.,1.], [0.,0.]], item_type=Points())
ll1 = LabeledData(ItemList([FakeImg((200,120)) for _ in range(5)]), il1, tfms=PointScaler(do_scale=False))
for i in range(5): 
    o,o1 = ll[i],ll1[i]
    test_eq(o[1], o1[1])
    test_eq(ll.deproc(o)[1], ll1.deproc(o1)[1])
    
#Giving scaled points with y_first=True: the coordinates are not flipped by the transform
il2 = ItemList([[-1.,-1.], [-1.,1.], [1.,-1.], [1.,1.], [0.,0.]], item_type=Points())
ll2 = LabeledData(ItemList([FakeImg((200,120)) for _ in range(5)]), il2, tfms=PointScaler(do_scale=False, y_first=True))
for i in range(5): 
    o,o2 = ll[i],ll2[i]
    test_eq(o[1], o2[1])
    test_eq(ll.deproc(o)[1], ll2.deproc(o2)[1])

In [None]:
class BiwiData(DataBlock):
    "Biwi sample: `Image` inputs with `Points` targets read from `centers.pkl`."
    x_cls = Image
    y_cls = Points

    def get_source(self):        return untar_data(URLs.BIWI_SAMPLE)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        "Label each image with the value stored under its file name in `centers.pkl`."
        # NOTE: `pickle.load` can execute arbitrary code; only use it on files from a trusted source.
        # The `with` block closes the file, which the bare `open(...)` call never did.
        with open(self.source/'centers.pkl', 'rb') as f: fn2ctr = pickle.load(f)
        return func_labeller(items, lambda o:fn2ctr[o.name])

In [None]:
data = BiwiData(tfms=tfms, tfm_y=TfmY.Point)

In [None]:
data.train.show(5)

### Coco

In [None]:
#export
from fastai.vision.data import get_annotations

In [None]:
class BBoxScaler(PointScaler):
    "`PointScaler` for targets given as a `(bboxes, labels)` pair: the boxes go through the point scaling, the labels are passed along."
    def __call__(self, o, tfm_y=TfmY.Bbox):
        "Apply `PointScaler.__call__` to the boxes `y[0]`, reshape the result to (n,4) and keep the labels `y[1]`."
        (x,y) = o
        # `PointScaler.__call__` returns an `(x, points)` tuple: unpack it before reshaping
        # (calling `.view` on the tuple itself raised an AttributeError).
        _,bbox = super().__call__((x,y[0]))
        return x, (bbox.view(-1,4),y[1])
    def undo(self, o, tfm_y=TfmY.Bbox):
        "Apply `PointScaler.undo` to the boxes viewed as (n,2) points, reshape the result to (n,4) and keep the labels `y[1]`."
        (x,y) = o
        _,bbox = super().undo((x,y[0].view(-1,2)))
        return x, (bbox.view(-1,4),y[1])

In [None]:
class BBoxEncoder(MultiCategoryEncoder):
    "`MultiCategoryEncoder` whose vocab is built from the second element of each item, with 'background' at index 0."
    def setup(self, items):
        "Build the vocab from `c[1]` for every `c` in `items`, insert 'background' first and rebuild `otoi`; no-op if a vocab exists."
        if self.vocab is not None: return
        labels = [c[1] for c in items]
        super().setup(labels)
        self.vocab.insert(0, 'background')
        # The insertion shifted every index by one, so the reverse lookup is rebuilt.
        self.otoi = dict(zip(self.vocab, range(len(self.vocab))))

In [None]:
#export 
from matplotlib import patches, patheffects

def _draw_outline(o, lw):
    "Give `o` a black stroke of width `lw` followed by the normal rendering as its path effects."
    stroke = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([stroke, patheffects.Normal()])

def _draw_rect(ax, b, color='white', text=None, text_size=14):
    "Add to `ax` an outlined, unfilled rectangle built from `b[:2]` and `b[-2:]`, plus an outlined `text` at `b[:2]` if given."
    rect = patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2)
    _draw_outline(ax.add_patch(rect), 4)
    if text is None: return
    label = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(label, 1)

In [None]:
# export
class BBox(Item):
    "Item type for bounding-box targets given as a `(bboxes, labels)` pair."
    # NOTE(review): `MultiCategory` and `Points` spell these attributes `default_tfms` / `default_tfms_xy`;
    # confirm which names `Item` actually reads.
    default_tfm = BBoxEncoder
    default_tfm_xy = BBoxScaler

    def show(self, x, ax):
        "Draw on `ax` every box of `x` whose label is not 'background', captioned with that label."
        bboxes,labels = x
        for box,lbl in zip(bboxes, labels):
            if lbl != 'background':
                rect = [box[1], box[0], box[3]-box[1], box[2]-box[0]]
                _draw_rect(ax, rect, text=lbl)

In [None]:
#export 
def bb_pad_collate(samples, pad_idx=0):
    "Collate `(image, (bboxes, labels))` samples into a batch, left-padding boxes with zeros and labels with `pad_idx`."
    n = len(samples)
    # Every target is padded to the largest number of labels found in the batch.
    max_len = max(len(s[1][1]) for s in samples)
    bboxes = torch.zeros(n, max_len, 4)
    labels = torch.zeros(n, max_len, dtype=torch.long) + pad_idx
    for i,(bbs,lbls) in enumerate(s[1] for s in samples):
        # Samples without any box keep a row made of padding only.
        if bbs.nelement() == 0: continue
        bboxes[i,-len(lbls):] = bbs
        labels[i,-len(lbls):] = tensor(lbls)
    imgs = torch.stack([s[0] for s in samples], 0)
    return imgs, (bboxes,labels)

In [None]:
class CocoData(DataBlock):
    "COCO tiny: `Image` inputs with `BBox` targets read from `train.json`."
    x_cls,y_cls = Image,BBox

    def get_source(self):
        "Path returned by `untar_data` for the COCO tiny archive."
        return untar_data(URLs.COCO_TINY)

    def get_items(self, source):
        "Image files found under `source/'train'`."
        return get_image_files(source/'train')

    def split(self, items):
        "Split `items` with `random_splitter`."
        return random_splitter(items)

    def label(self, items):
        "Label each item with the annotation that `train.json` associates with its file name."
        images, lbl_bbox = get_annotations(self.source/'train.json')
        img2bbox = {img:ann for img,ann in zip(images, lbl_bbox)}
        return func_labeller(items, lambda o:img2bbox[o.name])

    def databunch(self, bs=64, **kwargs):
        "Same as the parent `databunch`, with `collate_fn` always set to `bb_pad_collate`."
        return super().databunch(bs=bs, **{**kwargs, 'collate_fn': bb_pad_collate})

In [None]:
data = CocoData(tfms=tfms, tfm_y=TfmY.Bbox)

In [None]:
data.train.show(1)

## Export

In [None]:
! python notebook2script.py "200_datablock_config.ipynb"