In [None]:
# default_exp data.mixed

# Mixed data

> DataLoader that can take data from multiple dataloaders with different types of data

In [None]:
#export
from tsai.imports import *

In [None]:
#export
# This implementation of a mixed dataloader is based on a great implementation created by Zach Mueller in this fastai thread:
# https://forums.fast.ai/t/combining-tabular-images-in-fastai2-and-should-work-with-almost-any-other-type/73197

from packaging import version
from fastai.data.load import _FakeLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
# Iterator classes selected with a boolean index in `MixedDataLoader.__iter__`:
# `_loaders[num_workers == 0]` picks index 1 (single-process) when a loader has no workers
# and index 0 (multi-process) otherwise, so the order of this tuple matters.
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

class MixedDataLoader():
    "Iterates several fastai `DataLoader`s in lockstep and yields combined `(inputs, targets)` batches."
    def __init__(self, *loaders, path='.', shuffle=False, device=None, bs=None):
        """Accepts any number of `DataLoader` and a device.

        Args:
            *loaders: the `DataLoader`s to combine; the first one provides `self.dataset`.
            path: stored as `self.path`.
            shuffle: stored as `self.shuffle`; when False, `self.rng` is set to the dataset's index list.
            device: device batches are moved to (defaults to `default_device()`).
            bs: batch size forced onto every loader (defaults to the smallest `bs` among them).
        """
        self.path = path
        device = ifnone(device, default_device())
        self.device = device
        self.c = None
        bs = ifnone(bs, min([dl.bs for dl in loaders]))
        for i, dl in enumerate(loaders): # ensure all dls have the same bs
            if hasattr(dl, 'vars'): self.vars = dl.vars
            if hasattr(dl, 'len'): self.len = dl.len
            dl.bs = bs
            # all loaders share this object's `shuffle_fn`, so they all receive the same permutation
            dl.shuffle_fn = self.shuffle_fn
            if self.c is None and hasattr(dl, "c"): self.c = dl.c
            if i == 0: self.dataset = dl.dataset
            dl.to(device=device)
        self.shuffle = shuffle
        if not self.shuffle: self.rng = np.arange(len(self.dataset)).tolist()
        self.loaders = loaders
        self.count = 0
        # the `_FakeLoader` constructor takes one more positional argument from fastai 2.1 on
        self.fake_l = _FakeLoader(self, False, 0, 0, 0) if version.parse(fastai.__version__) >= version.parse("2.1") else _FakeLoader(self, False, 0, 0)
        if sum([len(dl.dataset) for dl in loaders]) > 0: self._get_idxs() # Do not apply on an empty dataset

    def new(self, *args, **kwargs):
        "Build a new `MixedDataLoader` from `dl.new(*args, **kwargs)` of every loader, keeping `path`, `device` and (unless overridden) `shuffle`."
        loaders = [dl.new(*args, **kwargs) for dl in self.loaders]
        return type(self)(*loaders, path=self.path, shuffle=kwargs.get('shuffle', self.shuffle), device=self.device)

    def __len__(self): return len(self.loaders[0])

    def _get_vals(self, x):
        "Return the positions of the first occurrence of each distinct tensor in `x` (used to drop duplicated targets)."
        # work on flattened numpy copies instead of overwriting the caller's container
        flat = [o.cpu().numpy().flatten() for o in x]
        idxs, seen = [], []
        for idx, o in enumerate(flat):
            if not self._arrayisin(o, seen):
                idxs.append(idx)
                seen.append(o)
        return idxs

    def _get_idxs(self):
        "Get `x` (inputs grouped per loader) and `y` (de-duplicated targets) indices for batches of data"
        self.n_inps = [dl.n_inp for dl in self.loaders]
        self.x_idxs = self._split_idxs(self.n_inps)

        # Identify duplicate targets: identical target tensors coming from different loaders are kept only once
        outs = L([])
        for dl, n_inp in zip(self.loaders, self.n_inps):
            b = next(iter(dl))
            outs += L(b[n_inp:])
        self.y_idxs = self._get_vals(outs)

    def __iter__(self):
        "Yield `(inps, outs)` built from one batch of every loader."
        z = zip(*[_loaders[dl.fake_l.num_workers==0](dl.fake_l) for dl in self.loaders])
        for b in z:
            inps, outs = [], []
            if self.device is not None: b = to_device(b, self.device)
            for batch, dl in zip(b, self.loaders):
                batch = dl.after_batch(batch)
                inps += batch[:dl.n_inp]
                outs += batch[dl.n_inp:]
            # group the inputs per source loader (a tuple when a loader has several inputs)
            x_parts = [tuple(L(inps)[idx]) if isinstance(idx, list) else inps[idx] for idx in self.x_idxs]
            # a single loader's inputs are returned unwrapped, the same way a single target is below
            inps = tuple(x_parts) if len(x_parts) > 1 else x_parts[0]
            outs = tuple(L(outs)[self.y_idxs]) if len(self.y_idxs) > 1 else L(outs)[self.y_idxs][0]
            yield inps, outs

    def one_batch(self):
        "Grab one batch of data without spawning worker processes"
        from contextlib import ExitStack
        # `__iter__` picks each sub-loader's iterator from that loader's own `fake_l.num_workers`,
        # so multiprocessing must be disabled on every sub-loader, not only on `self.fake_l`.
        with ExitStack() as stack:
            stack.enter_context(self.fake_l.no_multiproc())
            for dl in self.loaders: stack.enter_context(dl.fake_l.no_multiproc())
            res = first(self)
        if hasattr(self, 'it'): delattr(self, 'it')
        return res

    def shuffle_fn(self, idxs):
        "Generate the same idxs for all dls in each batch when shuffled"
        # a new permutation is drawn on the first call of a round; the remaining loaders reuse it
        if self.count == 0: self.shuffled_idxs = np.random.permutation(idxs)
        self.count += 1
        if self.count == len(self.loaders): self.count = 0
        return self.shuffled_idxs

    def show_batch(self, **kwargs):
        "Show a batch of data from every loader; `kwargs` are forwarded to each loader's `show_batch`"
        for dl in self.loaders: dl.show_batch(**kwargs)

    def to(self, device):
        "Set the device batches are moved to, move every sub-loader there as well, and return `self`."
        self.device = device
        # sub-loaders' `to` also moves their batch transforms; skip when no device is given
        if device is not None:
            for dl in self.loaders: dl.to(device=device)
        return self

    def _arrayisin(self, arr, arr_list):
        "Checks if `arr` is in `arr_list`"
        return any(np.array_equal(arr, a) for a in arr_list)

    def _split_idxs(self, a):
        """Split the flat input positions `0 .. sum(a)-1` into one group per loader.

        `a` holds each loader's number of inputs. A loader with exactly one input gets a bare
        int position; any other count (including 0) gets a list of positions.
        """
        groups, start = [], 0
        for n in a:
            positions = list(range(start, start + n))
            groups.append(positions[0] if n == 1 else positions)
            start += n
        return groups


class MixedDataLoaders(DataLoaders):
    "`DataLoaders` subclass that holds the train and valid `MixedDataLoader`s built by `get_mixed_dls`."

In [None]:
#export
def get_mixed_dls(*dls, device=None, shuffle_train=True, **kwargs):
    """Combine the train and valid splits of several `DataLoaders` into one `MixedDataLoaders`.

    Args:
        *dls: `DataLoaders` objects; all their `.train` loaders are mixed together, and so are their `.valid` loaders.
        device: device for the mixed loaders and their sub-loaders (defaults to `default_device()`).
        shuffle_train: `shuffle` value given to the mixed training loader (the validation one always gets `False`).
        **kwargs: forwarded to both `MixedDataLoader` constructors (e.g. `bs`, `path`).

    Returns:
        A `MixedDataLoaders` holding the mixed train and valid loaders.
    """
    device = ifnone(device, default_device())
    train_dls = [dl.train for dl in dls]
    valid_dls = [dl.valid for dl in dls]
    # pass `device` down so the sub-loaders are moved to the requested device, not to `default_device()`
    mixed_train_dl = MixedDataLoader(*train_dls, shuffle=shuffle_train, device=device, **kwargs)
    mixed_valid_dl = MixedDataLoader(*valid_dls, shuffle=False, device=device, **kwargs)
    return MixedDataLoaders(mixed_train_dl, mixed_valid_dl, device=device)

In [None]:
from tsai.data.tabular import *

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
target = 'salary'
# seed the split so the train/valid partition is the same on every Restart & Run All
splits = RandomSplitter(seed=42)(range_of(df))

# First loader: 3 categorical + 2 continuous features
cat_names = ['workclass', 'education', 'marital-status']
cont_names = ['age', 'fnlwgt']
dls1 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, y_names=target, splits=splits, bs=512)
dls1.show_batch()

# Second loader: no explicit categorical features and 1 continuous feature
# (set e.g. cat_names = ['occupation', 'relationship', 'race'] to add categorical features)
cat_names = None
cont_names = ['education-num']
dls2 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, y_names=target, splits=splits, bs=128)
dls2.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,7th-8th,Married-civ-spouse,43.0,316182.999841,<50k
1,Federal-gov,Bachelors,Never-married,35.0,185052.999815,<50k
2,Self-emp-inc,Some-college,Married-civ-spouse,49.0,143482.000349,>=50k
3,Private,HS-grad,Never-married,27.0,215873.000615,<50k
4,Private,Prof-school,Never-married,45.0,327886.000393,>=50k
5,Private,Assoc-voc,Married-civ-spouse,39.0,119097.997807,<50k
6,Self-emp-inc,Assoc-acdm,Married-civ-spouse,62.999999,96930.003163,>=50k
7,Private,Some-college,Married-civ-spouse,45.0,252078.998013,>=50k
8,Private,Some-college,Divorced,46.0,169953.000109,<50k
9,?,HS-grad,Never-married,39.0,103985.999645,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,9.0,<50k
1,False,9.0,<50k
2,False,13.0,>=50k
3,False,10.0,<50k
4,False,9.0,<50k
5,False,14.0,>=50k
6,False,13.0,<50k
7,False,9.0,<50k
8,False,10.0,>=50k
9,False,9.0,<50k


In [None]:
from pathlib import Path

dls = get_mixed_dls(dls1, dls2, bs=8)
first(dls.train)  # smoke test: one training batch
first(dls.valid)  # smoke test: one validation batch
# round-trip the mixed DataLoaders through torch.save / torch.load
export_dir = Path('export')
export_dir.mkdir(exist_ok=True)  # the folder may not exist on a fresh checkout
torch.save(dls, export_dir/'mixed_dls.pth')
del dls
# NOTE: torch.load unpickles arbitrary objects, so only load files you created yourself.
# Newer PyTorch releases default to `weights_only=True`, which rejects full objects like this one;
# pass `weights_only=False` there.
dls = torch.load(export_dir/'mixed_dls.pth')
dls.train.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,9th,Never-married,18.999999,175081.000655,<50k
1,?,10th,Separated,19.999999,114813.000655,<50k
2,Private,12th,Never-married,37.0,301567.998732,<50k
3,Private,Assoc-voc,Divorced,49.0,156925.999686,>=50k
4,State-gov,Some-college,Separated,52.0,303461.998782,<50k
5,Private,10th,Never-married,24.0,280134.002665,<50k
6,Private,Bachelors,Married-civ-spouse,37.0,105021.002051,>=50k
7,Private,10th,Married-civ-spouse,64.0,180401.000212,>=50k


Unnamed: 0,education-num_na,education-num,salary
0,False,5.0,<50k
1,False,6.0,<50k
2,False,8.0,<50k
3,False,11.0,>=50k
4,False,10.0,<50k
5,False,6.0,<50k
6,False,13.0,>=50k
7,False,6.0,>=50k


In [None]:
# inputs of the first training batch: one (categorical, continuous) tuple per tabular sub-loader
xb, yb = first(iter(dls.train))
xb

((tensor([[ 5,  7,  5],
          [ 1,  1,  6],
          [ 5,  3,  5],
          [ 5,  9,  1],
          [ 8, 16,  6],
          [ 5,  1,  5],
          [ 5, 10,  3],
          [ 5,  1,  3]]),
  tensor([[-1.4364, -0.1444],
          [-1.3631, -0.7122],
          [-0.1165,  1.0471],
          [ 0.7635, -0.3155],
          [ 0.9835,  1.0649],
          [-1.0698,  0.8452],
          [-0.1165, -0.8044],
          [ 1.8635, -0.0943]])),
 (tensor([[1],
          [1],
          [1],
          [1],
          [1],
          [1],
          [1],
          [1]]),
  tensor([[-1.9840],
          [-1.5930],
          [-0.8109],
          [ 0.3621],
          [-0.0289],
          [-1.5930],
          [ 1.1442],
          [-1.5930]])))

In [None]:
xs, ys = first(dls.train)
# shape of every input tensor, sub-loader by sub-loader: (cat, cont) of loader 0, then (cat, cont) of loader 1
tuple(t.shape for part in xs for t in part)

(torch.Size([8, 3]),
 torch.Size([8, 2]),
 torch.Size([8, 1]),
 torch.Size([8, 1]))

In [None]:
#hide
# export the notebook's #export cells to the library, then signal completion
out = create_scripts()
beep(out)