In [None]:
# default_exp data.mixed

# Mixed data

> DataLoader that can take data from multiple dataloaders with different types of data

In [None]:
#export
from tsai.imports import *

In [None]:
#export
from packaging import version
from fastai.data.load import _FakeLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
#export
# This implementation of a mixed dataloader is based on a great implementation created by Zach Mueller in this fastai thread:
# https://forums.fast.ai/t/combining-tabular-images-in-fastai2-and-should-work-with-almost-any-other-type/73197

class MixedDL():
    "Iterates several dataloaders in lockstep and merges their batches into a single `(inputs, targets)` pair"
    def __init__(self, *dls, device=None):
        "Accepts any number of `DataLoaders` and a device"
        if not dls: raise ValueError("MixedDL needs at least one dataloader")
        device = ifnone(device, default_device())
        self.device = device
        self.c = []
        bs = min([dl.bs for dl in dls])
        for i, dl in enumerate(dls):
            dl.bs = bs # ensure all dls have the same bs
            # keep a handle on the dataset of the first dataloader
            if i == 0: self.train_ds = dl.dataset
            # every dl asks this object for its shuffled idxs, so all of them see the same order
            dl.shuffle_fn = self.shuffle_fn
            if self.c == [] and hasattr(dl, "c"): self.c = dl.c
            dl.to(device=device)
        self.dls = dls
        self.count = 0
        # `_FakeLoader` takes one more positional argument from fastai 2.1 onwards
        self.fake_l = _FakeLoader(self, False, 0, 0, 0) if version.parse(fastai.__version__) >= version.parse("2.1") else _FakeLoader(self, False, 0, 0)
        self._get_idxs()

    def __len__(self): return len(self.dls[0])

    def _get_vals(self, x):
        "Returns the positions of the first occurrence of each distinct tensor in `x` (duplicates are dropped)"
        # compare on flattened numpy copies; the caller's list is left untouched
        flat = [o.detach().cpu().numpy().flatten() for o in x]
        idxs, seen = [], []
        for idx, o in enumerate(flat):
            if not self._arrayisin(o, seen):
                idxs.append(idx)
                seen.append(o)
        return idxs

    def _get_idxs(self):
        "Get `x` and `y` indices for batches of data"
        self.n_inps = [dl.n_inp for dl in self.dls]
        self.x_idxs = self._split_idxs(self.n_inps)

        # Identify duplicate targets: draw one batch per dl and keep only the distinct targets
        outs = L([])
        for dl, n_inp in zip(self.dls, self.n_inps):
            b = next(iter(dl))
            outs += L(b[n_inp:])
        self.y_idxs = self._get_vals(outs)

    def __iter__(self):
        "Yields `(inputs, targets)`, where inputs are grouped per dataloader and duplicated targets are removed"
        # `_loaders` is (multi-process, single-process): index with True (1) when the dl has no workers
        z = zip(*[_loaders[i.fake_l.num_workers==0](i.fake_l) for i in self.dls])
        for b in z:
            inps, outs = [], []
            if self.device is not None: b = to_device(b, self.device)
            for batch, dl in zip(b, self.dls):
                batch = dl.after_batch(batch)
                inps += batch[:dl.n_inp]
                outs += batch[dl.n_inp:]
            # Remove duplicates and split inputs and outputs
            if len(self.x_idxs) > 1: inps = [L(inps)[idx] for idx in self.x_idxs]
            # a single dl: its inputs (not its targets) are returned without the outer list
            else: inps = L(inps)[self.x_idxs[0]]
            outs = L(outs)[self.y_idxs] if len(self.y_idxs) > 1 else L(outs)[self.y_idxs][0]
            yield (inps, outs)

    def one_batch(self):
        "Grab one batch of data"
        with self.fake_l.no_multiproc(): res = first(self)
        if hasattr(self, 'it'): delattr(self, 'it')
        return res

    def shuffle_fn(self, idxs):
        "Generate the same idxs for all dls in each batch"
        # only the first call of every round of `len(self.dls)` calls draws a new permutation;
        # the remaining calls of the round reuse it (note `self.rng` stores the shuffled idxs)
        if self.count == 0: self.rng = self.dls[0].rng.sample(idxs, len(idxs))
        self.count += 1
        if self.count == len(self.dls): self.count = 0
        return self.rng

    def show_batch(self):
        "Show a batch of data"
        for dl in self.dls: dl.show_batch()

    def to(self, device):
        "Sets the device batches are moved to and returns `self` so calls can be chained"
        self.device = device
        return self

    def _arrayisin(self, arr, arr_list):
        "Checks if `arr` is in `arr_list`"
        return any(np.array_equal(arr, a) for a in arr_list)

    def _split_idxs(self, a):
        "Maps the counts in `a` to consecutive positions: an int for a count of 1, a list of ints otherwise"
        splits, start = [], 0
        for n in a:
            end = start + int(n)
            # a count of 0 yields an empty list instead of pointing at the next group's first position
            splits.append(start if end - start == 1 else list(range(start, end)))
            start = end
        return splits

In [None]:
#export
def get_mixed_dls(*dls, device=None):
    "Builds a `DataLoaders` whose train/valid loaders are `MixedDL`s over the train/valid loaders of `dls`"
    device = ifnone(device, default_device())
    # pass the resolved device down so the wrapped loaders end up on the same device
    # as the `DataLoaders` that is returned
    mixed_train_dl = MixedDL(*[dl.train for dl in dls], device=device)
    mixed_valid_dl = MixedDL(*[dl.valid for dl in dls], device=device)
    return DataLoaders(mixed_train_dl, mixed_valid_dl, device=device)

In [None]:
from tsai.data.tabular import *

# Load the ADULT_SAMPLE csv; it is used to build two tabular `DataLoaders` over the same rows
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
target = 'salary'
# a single split is created once and reused below, so both dls share the same train/valid rows
splits = RandomSplitter()(range_of(df))

# first dls: 3 categorical + 2 continuous variables, bs=512
cat_names = ['workclass', 'education', 'marital-status']
cont_names = ['age', 'fnlwgt']
dls1 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, target=target, splits=splits, bs=512)
dls1.show_batch()

# second dls: a single continuous variable and a different batch size (bs=128), same target and splits
cat_names = None #['occupation', 'relationship', 'race']
cont_names = ['education-num']
dls2 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, target=target, splits=splits, bs=128)
dls2.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,12th,Married-civ-spouse,23.000001,201679.999957,<50k
1,State-gov,11th,Never-married,19.0,431745.004269,<50k
2,Private,10th,Married-civ-spouse,60.999999,68267.994495,<50k
3,Federal-gov,Bachelors,Never-married,28.0,381788.996004,<50k
4,Private,Some-college,Never-married,21.0,138768.000104,<50k
5,Private,HS-grad,Married-civ-spouse,39.0,51100.001895,>=50k
6,Local-gov,HS-grad,Married-civ-spouse,31.0,224233.998719,<50k
7,Private,Masters,Married-civ-spouse,49.0,101824.999443,>=50k
8,Private,HS-grad,Never-married,19.0,198662.999953,<50k
9,Private,Some-college,Married-civ-spouse,39.0,165186.000027,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,14.0,<50k
1,False,10.0,<50k
2,False,9.0,<50k
3,False,8.0,<50k
4,False,10.0,<50k
5,False,9.0,>=50k
6,False,10.0,<50k
7,False,4.0,<50k
8,False,10.0,<50k
9,False,9.0,<50k


In [None]:
# Take the first training batch of `dls2` and show the shapes of its first two elements
b = first(dls2.train)
b[0].shape, b[1].shape

(torch.Size([128, 1]), torch.Size([128, 1]))

In [None]:
# Merge both tabular `DataLoaders`; `MixedDL.show_batch` displays one batch per wrapped dataloader
dls = get_mixed_dls(dls1, dls2)
dls.train.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,Bachelors,Never-married,30.0,66193.99789,<50k
1,Private,10th,Never-married,16.999999,239345.999274,<50k
2,Federal-gov,Bachelors,Married-civ-spouse,43.0,197069.000368,>=50k
3,Private,Bachelors,Married-civ-spouse,57.000001,124506.99804,>=50k
4,Private,Bachelors,Never-married,24.0,163664.999487,<50k
5,?,9th,Married-civ-spouse,66.000001,108185.002139,<50k
6,Private,Bachelors,Never-married,28.0,338376.000524,<50k
7,Private,Assoc-voc,Divorced,29.0,177118.99975,<50k
8,Private,Some-college,Never-married,21.999999,383603.000333,<50k
9,Private,Some-college,Never-married,34.0,591710.997851,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,13.0,<50k
1,False,6.0,<50k
2,False,13.0,>=50k
3,False,13.0,>=50k
4,False,13.0,<50k
5,False,5.0,<50k
6,False,13.0,<50k
7,False,11.0,<50k
8,False,10.0,<50k
9,False,10.0,<50k


In [None]:
# `xs` has one entry per wrapped dataloader (each holding that dl's inputs); `ys` is the single
# target left after `MixedDL` drops the duplicated one. All share the smallest batch size.
xs, ys = first(dls.train)
xs[0][0].shape, xs[0][1].shape, xs[1][0].shape, xs[1][1].shape, ys.shape

(torch.Size([128, 3]),
 torch.Size([128, 2]),
 torch.Size([128, 1]),
 torch.Size([128, 1]),
 torch.Size([128, 1]))

In [None]:
#hide
# NOTE(review): presumably exports the `#export` cells to the library and signals completion;
# `create_scripts` and `beep` are defined outside this notebook — TODO confirm
beep(create_scripts())