In [None]:
# default_exp data.mixed

# Mixed data

> DataLoader that can take data from multiple dataloaders with different types of data

In [None]:
#export
from tsai.imports import *

In [None]:
#export
from packaging import version
from fastai.data.load import _FakeLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
#export
# This implementation of a mixed dataloader is based on an implementation created by Zach Mueller in this fastai thread:
# https://forums.fast.ai/t/combining-tabular-images-in-fastai2-and-should-work-with-almost-any-other-type/73197

class MixedDL():
    """Zip several fastai `DataLoader`s into a single loader yielding `(inputs, targets)`.

    Every wrapped loader is given the batch size of the first one and a shared
    `shuffle_fn`, so shuffled loaders draw their items in the same order.
    Targets that are identical across loaders are yielded only once.
    """
    def __init__(self, *dls, device=None):
        "Accepts any number of `DataLoaders` and a device (defaults to `default_device()`)"
        device = ifnone(device, default_device())
        self.device = device
        self.c = []
        bs = None
        for i, dl in enumerate(dls):
            # Ensure all dls have the same bs: the first loader's value is imposed on the rest
            if i == 0: bs = dl.bs
            else: dl.bs = bs
            # All loaders share one shuffle function so they sample items in the same order
            dl.shuffle_fn = self.shuffle_fn
            # `c` is taken from the first loader that exposes it
            if self.c == [] and hasattr(dl, "c"): self.c = dl.c
            dl.to(device=device)
        self.dls = dls
        self.count = 0
        # `_FakeLoader` is built with one more positional argument from fastai 2.1 onwards
        self.fake_l = _FakeLoader(self, False, 0, 0, 0) if version.parse(fastai.__version__) >= version.parse("2.1") else _FakeLoader(self, False, 0, 0)
        self._get_idxs()

    def __len__(self): return len(self.dls[0])

    def _get_vals(self, x):
        "Return the positions of the first occurrence of each distinct tensor in `x` (duplicates are dropped)"
        # Work on flattened numpy copies so the caller's collection is left untouched
        flat = [o.cpu().numpy().flatten() for o in x]
        idxs, seen = [], []
        for idx, o in enumerate(flat):
            if not self._arrayisin(o, seen):
                idxs.append(idx)
                seen.append(o)
        return idxs

    def _get_idxs(self):
        "Get `x` and `y` indices for batches of data"
        self.n_inps = [dl.n_inp for dl in self.dls]
        self.x_idxs = self._split_idxs(self.n_inps)

        # Identify duplicate targets by drawing one batch from every loader
        outs = L([])
        for dl, n_inp in zip(self.dls, self.n_inps):
            b = next(iter(dl))
            outs += L(b[n_inp:])
        self.y_idxs = self._get_vals(outs)

    def __iter__(self):
        # `_loaders` is (multi-process, single-process): the boolean picks the single-process
        # iterator when a loader has no workers
        z = zip(*[_loaders[i.fake_l.num_workers==0](i.fake_l) for i in self.dls])
        for b in z:
            inps = []
            outs = []
            if self.device is not None: b = to_device(b, self.device)
            for batch, dl in zip(b, self.dls):
                batch = dl.after_batch(batch)
                inps += batch[:dl.n_inp]
                outs += batch[dl.n_inp:]
            # Group the inputs per loader; with a single loader return its inputs directly
            # (the inputs, not the targets, must be indexed here)
            inps = [L(inps)[idx] for idx in self.x_idxs] if len(self.x_idxs) > 1 else L(inps)[self.x_idxs[0]]
            # Remove duplicated targets; a single remaining target is returned unwrapped
            outs = L(outs)[self.y_idxs] if len(self.y_idxs) > 1 else L(outs)[self.y_idxs][0]
            yield (inps, outs)

    def one_batch(self):
        "Grab one batch of data"
        with self.fake_l.no_multiproc(): res = first(self)
        if hasattr(self, 'it'): delattr(self, 'it')
        return res

    def shuffle_fn(self, idxs):
        "Generate the same idxs for all dls in each batch"
        # A new permutation is sampled only on the first call of each round of `len(self.dls)`
        # calls; the remaining calls of the round reuse it
        if self.count == 0: self.rng = self.dls[0].rng.sample(idxs, len(idxs))
        self.count += 1
        if self.count == len(self.dls): self.count = 0
        return self.rng

    def show_batch(self):
        "Show a batch of data"
        for dl in self.dls: dl.show_batch()

    def to(self, device):
        "Set the device batches are moved to during iteration; returns `self` so calls can be chained"
        self.device = device
        return self

    def _arrayisin(self, arr, arr_list):
        "Checks if `arr` is in `arr_list`"
        return any(np.array_equal(arr, a) for a in arr_list)

    def _split_idxs(self, a):
        "Split `range(sum(a))` into consecutive groups of sizes `a`; a group of size 1 is a bare int"
        groups, start = [], 0
        for n in a:
            end = start + int(n)
            # An empty group stays an empty list instead of borrowing the next loader's index
            groups.append(start if end - start == 1 else list(range(start, end)))
            start = end
        return groups

In [None]:
# First tabular source: three categorical and two continuous columns of ADULT_SAMPLE
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
procs = [Categorify, FillMissing, Normalize]
cat_names, cont_names, y_names = ['workclass', 'education', 'marital-status'], ['age', 'fnlwgt'], ['salary']
# A float target is handled as regression, anything else as classification
if isinstance(df['salary'].values[0], float):
    y_block = RegressionBlock()
else:
    y_block = CategoryBlock()
splits = RandomSplitter()(range_of(df))
pd.options.mode.chained_assignment = None
to0 = TabularPandas(
    df,
    procs=procs,
    cat_names=cat_names,
    cont_names=cont_names,
    y_names=y_names,
    y_block=y_block,
    splits=splits,
    inplace=True,
    reduce_memory=False,
)
to0.show(5)

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
29090,Private,HS-grad,Never-married,35.0,23621.0,<50k
16290,?,HS-grad,Married-civ-spouse,80.0,172826.0,<50k
30261,Private,10th,Never-married,17.0,202521.0,<50k
10243,Private,HS-grad,Widowed,47.0,223342.0,<50k
28925,Self-emp-not-inc,HS-grad,Married-civ-spouse,56.0,335605.0,>=50k


In [None]:
# Second tabular source: a single continuous column of the same file; it reuses the
# `splits` already in the namespace so both sources are split identically
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
procs = [Categorify, FillMissing, Normalize]
cat_names = None  # alternative: ['occupation', 'relationship', 'race']
cont_names, y_names = ['education-num'], ['salary']
# A float target is handled as regression, anything else as classification
if isinstance(df['salary'].values[0], float):
    y_block = RegressionBlock()
else:
    y_block = CategoryBlock()
pd.options.mode.chained_assignment = None
to1 = TabularPandas(
    df,
    procs=procs,
    cat_names=cat_names,
    cont_names=cont_names,
    y_names=y_names,
    y_block=y_block,
    splits=splits,
    inplace=True,
    reduce_memory=False,
)
to1.show(5)

Unnamed: 0,education-num_na,education-num,salary
29090,False,9.0,<50k
16290,False,9.0,<50k
30261,False,6.0,<50k
10243,False,9.0,<50k
28925,False,9.0,>=50k


In [None]:
# One `DataLoaders` (train + valid) per tabular source
dls0, dls1 = to0.dataloaders(), to1.dataloaders()

In [None]:
# Pair the train loaders and the valid loaders of both sources, then wrap them together
mixed_train_dl = MixedDL(*(d.train for d in (dls0, dls1)))
mixed_valid_dl = MixedDL(*(d.valid for d in (dls0, dls1)))
dls = DataLoaders(mixed_train_dl, mixed_valid_dl)
# Displays one batch per wrapped loader
dls.train.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,HS-grad,Married-civ-spouse,34.0,255693.003055,<50k
1,Self-emp-inc,Some-college,Married-civ-spouse,28.0,142711.999766,<50k
2,Private,Some-college,Married-civ-spouse,34.0,113838.000577,>=50k
3,Private,Assoc-acdm,Never-married,40.0,70760.994379,<50k
4,?,Some-college,Never-married,68.0,170181.999792,<50k
5,Private,HS-grad,Divorced,33.0,213226.00034,<50k
6,Local-gov,Masters,Divorced,50.0,120189.997195,<50k
7,?,Some-college,Never-married,18.000001,184101.000187,<50k
8,Private,HS-grad,Divorced,46.0,185673.000122,<50k
9,Self-emp-not-inc,Masters,Married-civ-spouse,42.0,206066.000638,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,9.0,<50k
1,False,10.0,<50k
2,False,10.0,>=50k
3,False,12.0,<50k
4,False,10.0,<50k
5,False,9.0,<50k
6,False,14.0,<50k
7,False,10.0,<50k
8,False,9.0,<50k
9,False,14.0,<50k


In [None]:
# Inspect the structure of one mixed batch: number of input groups, size of each group,
# and number of target items
xs, ys = first(dls.train)
tuple(len(o) for o in (xs, xs[0], xs[1], ys))

(2, 2, 2, 64)

In [None]:
#hide
# NOTE(review): presumably exports the notebooks to library scripts and signals completion
# with a sound — confirm against the `create_scripts` / `beep` helpers (not defined in this file)
beep(create_scripts())