In [1]:
#export
from fastai2.torch_basics import *
from fastai2.data.core import *
from fastai2.data.load import *
from fastai2.data.external import *
from fastai2.data.transforms import *

In [2]:
# Fetch the MNIST_TINY sample dataset with `untar_data`.
# NOTE(review): `dest` is a hard-coded, machine-specific absolute path; change it
# (or make it configurable) before running this notebook on another machine.
path = untar_data(URLs.MNIST_TINY,dest='/media/puneet/Data')
# List the contents of the dataset directory.
path.ls()

(#5) [/media/puneet/Data/mnist_tiny/labels.csv,/media/puneet/Data/mnist_tiny/models,/media/puneet/Data/mnist_tiny/test,/media/puneet/Data/mnist_tiny/train,/media/puneet/Data/mnist_tiny/valid]

# TfmDS / DataSource

A `DataSource` creates a tuple from each of the *items* by applying each list of transforms (or pipeline) in `tfms`. Note that if `tfms` contains only one list of transforms, the items returned by the *DataSource* will be tuples of one element.

In [44]:
class _IntFloatTfm(Transform):
    "Toy transform: wraps the value in `Int` when encoding and in `Float` when decoding."
    def encodes(self, o):
        return Int(o)
    def decodes(self, o):
        return Float(o)

int2f_tfm = _IntFloatTfm()

def _neg(o):
    "Return the arithmetic negation of `o`."
    return -o

# Negation undoes itself, so the same function is passed as encoder and decoder.
neg_tfm = Transform(_neg, _neg)

In [47]:
items = [1,2,3,4]
# Two transform lists are given, so every item becomes a 2-tuple:
# (negate then int2f_tfm, add 1).
dsrc = DataSource(items, [[neg_tfm,int2f_tfm], [add(1)]])

# First item: 1 -> -1 through the first list and 1 -> 2 through the second.
t = dsrc[0]
test_eq(t,(-1,2))
# Indexing with several indices gives one tuple per requested item.
test_eq(dsrc[0,1,2],[(-1, 2),(-2,3),(-3,4)])

In [58]:
class Norm(Transform):
    "Standardize values with the mean and standard deviation captured in `setups`."
    def setups(self, items):
        # Store the statistics on the instance; `encodes`/`decodes` read them back.
        # NOTE: `std()` of a single item is NaN, which makes `encodes` return NaN.
        vals = tensor(items).float()
        self.m = vals.mean()
        self.s = vals.std()
    def encodes(self, o):
        return (o - self.m) / self.s
    def decodes(self, o):
        return (o * self.s) + self.m
        

In [61]:
items = [1,2,3,4]
nrm = Norm()
# The second transform list negates each item and then normalizes it with `nrm`.
dsrc = DataSource(items, [[neg_tfm, int2f_tfm],[neg_tfm, nrm]])
# Iterate over the whole DataSource and unzip the tuples into the two outputs.
x, y = zip(*dsrc)
# No splits are given; the expected mean is that of all negated items: (-1-2-3-4)/4 = -2.5.
test_eq(nrm.m,(-2.5))

In [71]:
# Checking splits:
# DataSource contains a different TfmdList based on split_idx.

items = [1,2,3,4]
nrm = Norm()
# splits[0] = [3] is the training split and splits[1] = [0,1,2] the validation split
# (see the `dsrc.train` / `dsrc.valid` outputs that follow).
dsrc = DataSource(items, [[neg_tfm, int2f_tfm],[neg_tfm, nrm]], splits=[[3],[0,1,2]])
x, y = zip(*dsrc)
# The expected mean equals the single negated training item (-4), i.e. only index 3 was used.
test_eq(nrm.m,(-4))

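Here the mean is automatically computed only on the training indices, because `DataSource` calls `setup` only on the training split. Note that the training split contains a single item here, so its standard deviation is NaN, which is why the normalized values shown below are `nan`.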

In [72]:
# Validation subset: the three items at indices [0,1,2] given in `splits`.
dsrc.valid

(#3) [(-1, tensor(nan)),(-2, tensor(nan)),(-3, tensor(nan))]

In [73]:
# Training subset: the single item at index 3 given in `splits`.
dsrc.train

(#1) [(-4, tensor(nan))]

# TfmdList

It is a basic collection of items and transforms. If we pass splits to a `DataSource`, it creates two `TfmdList`s: one representing the training set and the other the validation set.

In [92]:
items = [1,2,3,4]
nrm = Norm()
# Build the two pipelines as separate TfmdLists over the same items and splits.
# NOTE(review): `do_setup` presumably controls whether the transforms' setup is run
# when the TfmdList is constructed — confirm against the TfmdList implementation.
tn= TfmdList(items, tfms=[neg_tfm, int2f_tfm],splits=[[3],[0,1,2]],do_setup=False)
tv= TfmdList(items, tfms=[neg_tfm, nrm],splits=[[3],[0,1,2]],do_setup=True)


In [None]:
# add splits to TfmdList
# Define the inputs in this cell so it runs top-to-bottom on a fresh kernel:
# `tfms` was previously only defined in a later cell, and `items` must be an `L`
# (not a plain list) because it is indexed with a list of indices below.
items = L([1.,2.,3.])
tfms = [neg_tfm, int2f_tfm]
splits = [[0,2],[1]]
tl = TfmdList(items, tfms=tfms, splits=splits)
# One subset per split; `train` / `valid` are subsets 0 and 1.
test_eq(tl.n_subsets, 2)
test_eq(tl.train, tl.subset(0))
test_eq(tl.valid, tl.subset(1))
# Each subset only holds the items selected by its split.
test_eq(tl.train.items, items[splits[0]])
test_eq(tl.valid.items, items[splits[1]])
# The subset's transforms carry the index of the split they belong to.
test_eq(tl.train.tfms.split_idx, 0)
test_eq(tl.valid.tfms.split_idx, 1)
test_eq_type(tl.splits, L(splits))
assert not tl.overlapping_splits()

This is what happens in `DataSource`: it creates two `TfmdList`s separately, one for each transformation pipeline. The transforms know how to handle splits.

> The only difference between `DataSource` and `TfmdList`: `DataSource` creates tuples from multiple pipelines.

# TfmdDL

`TfmdDL` is a special data loader that understands transforms. It inherits from `DataLoader` but has additional methods to decode when necessary.

A TfmdDL is a DataLoader that creates Pipeline from a list of Transforms for the callbacks after_item, before_batch and after_batch. As a result, it can decode or show a processed batch.

> It has `decode_batch`, `decode`, and `_retain_dl`.

In [10]:
#Test retain type
class NegTfm(Transform):
    def encodes(self, x): return torch.neg(x)
    def decodes(self, x): return torch.neg(x)
    
tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)
b = tdl.one_batch()
test_eq(type(b[0]), TensorImage)
b = (tensor([1.,1.,1.,1.]),)
test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)

In [28]:
# Shape of the first element of a batch: 4 items (bs=4), each holding one value.
tdl.one_batch()[0].shape

torch.Size([4, 1])

# Data Bunch

It is just a collection of several data loaders.

In [33]:
# The same TfmdDL is passed twice, so both loaders of the DataBunch are `tdl`.
dbch = DataBunch(tdl , tdl)
x = dbch.train_dl.one_batch()
# The training loader is expected to give the same first batch as `tdl` itself.
x2 = first(tdl)
test_eq(x, x2)
# `one_batch` called on the DataBunch is expected to give that same batch too.
x2 = dbch.one_batch()
test_eq(x, x2)

# It uses `GetAttr`, which forwards attribute requests to the `default` attribute.

In [35]:
# `int2f_tfm` and `neg_tfm` are already defined in the cell near the top of this
# notebook; reuse them here instead of re-declaring identical copies that would
# silently shadow the earlier definitions.
tfms = [neg_tfm, int2f_tfm]

# `L` (rather than a plain list) so the items can be indexed with a list of indices.
items = L([1.,2.,3.])
