In [None]:
# default_exp data.mixed

# Mixed data

> DataLoader that can take data from multiple dataloaders with different types of data

In [None]:
#export
from tsai.imports import *

In [None]:
#export
# This implementation of a mixed dataloader is based on a great implementation created by Zach Mueller in this fastai thread:
# https://forums.fast.ai/t/combining-tabular-images-in-fastai2-and-should-work-with-almost-any-other-type/73197

from packaging import version
from fastai.data.load import _FakeLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
# Order matters: `MixedDataLoader.__iter__` indexes this tuple with the boolean `num_workers==0`,
# so index 0 (False) is the multi-process iterator and index 1 (True) is the single-process one.
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

class MixedDataLoader():
    """Combine several `DataLoader`s into one iterable that yields merged `(inputs, targets)` batches.

    Each iteration draws one batch from every wrapped loader, applies that loader's `after_batch`,
    groups the inputs per loader (see `_split_idxs`) and keeps only the targets selected by
    `_get_vals` (the first occurrence of each distinct target tensor).
    """
    def __init__(self, *loaders, path='.', shuffle=False, device=None, bs=None):
        """Accepts any number of `DataLoader` and a device

        `bs` defaults to the smallest batch size among `loaders`; `device` defaults to `default_device()`.
        Every loader is mutated in place: its `bs` and `shuffle_fn` are overwritten and it is moved to `device`.
        """
        # Fail early with a clear message instead of an opaque `min()`/attribute error further down
        if not loaders: raise ValueError("MixedDataLoader requires at least one DataLoader")
        self.path = path
        device = ifnone(device, default_device())
        self.device = device
        self.c = None
        bs = ifnone(bs, min([dl.bs for dl in loaders]))
        for i, dl in enumerate(loaders): # ensure all dls have the same bs
            if hasattr(dl, 'vars'): self.vars = dl.vars
            if hasattr(dl, 'len'): self.len = dl.len
            dl.bs = bs
            # All loaders share this object's `shuffle_fn` so they receive the same permutation
            dl.shuffle_fn = self.shuffle_fn
            # `c` is taken from the first loader that defines it
            if self.c is None and hasattr(dl, "c"): self.c = dl.c
            if i == 0: self.dataset = dl.dataset
            dl.to(device=device)
        self.shuffle = shuffle
        if not self.shuffle: self.rng = np.arange(len(self.dataset)).tolist()
        self.loaders = loaders
        self.count = 0  # number of `shuffle_fn` calls seen in the current round (see `shuffle_fn`)
        self.fake_l = _FakeLoader(self, False, 0, 0, 0) if version.parse(fastai.__version__) >= version.parse("2.1") else _FakeLoader(self, False, 0, 0)
        if sum([len(dl.dataset) for dl in loaders]) > 0: self._get_idxs() # Do not apply on an empty dataset

    def new(self, *args, **kwargs):
        "Call `new(*args, **kwargs)` on every wrapped loader and wrap the results, keeping `path` and `device`"
        # Only `path` and `device` are forwarded: the new object uses the default `shuffle` and recomputes `bs`
        loaders = [dl.new(*args, **kwargs) for dl in self.loaders]
        return type(self)(*loaders, path=self.path, device=self.device)

    def __len__(self): return len(self.loaders[0])

    def _get_vals(self, x):
        "Return the positions of the first occurrence of each distinct tensor in `x` (compared as flattened arrays)"
        idxs, seen = [], []
        for idx, o in enumerate(x):
            # Work on a local copy so the caller's collection is left untouched
            arr = o.cpu().numpy().flatten()
            if not self._arrayisin(arr, seen):
                idxs.append(idx)
                seen.append(arr)
        return idxs

    def _get_idxs(self):
        "Get `x` and `y` indices for batches of data"
        self.n_inps = [dl.n_inp for dl in self.loaders]
        self.x_idxs = self._split_idxs(self.n_inps)

        # Identify duplicate targets: draw one batch per loader and keep one index per distinct target
        outs = L([])
        for dl, n_inp in zip(self.loaders, self.n_inps):
            b = next(iter(dl))
            outs += L(b[n_inp:])
        self.y_idxs = self._get_vals(outs)

    def __iter__(self):
        "Yield `(inputs, targets)` built from one batch of every wrapped loader"
        # `num_workers==0` is used as a tuple index: True (1) picks the single-process iterator
        z = zip(*[_loaders[i.fake_l.num_workers==0](i.fake_l) for i in self.loaders])
        for b in z:
            inps = []
            outs = []
            if self.device is not None: b = to_device(b, self.device)
            for batch, dl in zip(b, self.loaders):
                batch = dl.after_batch(batch)
                inps += batch[:dl.n_inp]
                outs += batch[dl.n_inp:]
            # One group per loader: a tuple when the loader has several inputs, the bare item otherwise
            groups = [tuple(L(inps)[idx]) if isinstance(idx, list) else inps[idx] for idx in self.x_idxs]
            # With a single group return it unwrapped (previously this branch read from `outs` by mistake)
            inps = tuple(groups) if len(groups) > 1 else groups[0]
            outs = tuple(L(outs)[self.y_idxs]) if len(self.y_idxs) > 1 else L(outs)[self.y_idxs][0]
            yield inps, outs

    def one_batch(self):
        "Grab one batch of data"
        with self.fake_l.no_multiproc(): res = first(self)
        if hasattr(self, 'it'): delattr(self, 'it')
        return res

    def shuffle_fn(self, idxs):
        "Generate the same idxs for all dls in each batch when shuffled"
        # A new permutation is drawn on the first call of a round; the next `len(self.loaders) - 1`
        # calls return that same permutation, after which the counter resets.
        if self.count == 0: self.shuffled_idxs = np.random.permutation(idxs)
        self.count += 1
        if self.count == len(self.loaders): self.count = 0
        return self.shuffled_idxs

    def show_batch(self):
        "Show a batch of data"
        for dl in self.loaders: dl.show_batch()

    def to(self, device):
        "Set `device` on this loader and move every wrapped loader to it; returns `self`"
        self.device = device
        # Mirror `__init__`, which also moves each wrapped loader; skip when no device is requested
        if device is not None:
            for dl in self.loaders: dl.to(device=device)
        return self

    def _arrayisin(self, arr, arr_list):
        "Checks if `arr` is in `arr_list`"
        return any(np.array_equal(arr, a) for a in arr_list)

    def _split_idxs(self, a):
        "Split `range(sum(a))` into consecutive groups of sizes `a`: a list for sizes > 1, a bare int otherwise"
        b = np.arange(sum(a)).tolist()
        b_, start = [], 0
        for end in np.array(a).cumsum().tolist():
            b_.append(b[start:end] if end - start > 1 else b[start])
            start = end
        return b_


class MixedDataLoaders(DataLoaders):
    "`DataLoaders` subclass used as the container returned by `get_mixed_dls`; adds no behaviour of its own"

In [None]:
#export
def get_mixed_dls(*dls, device=None, shuffle_train=True, **kwargs):
    """Combine the train/valid loaders of several `DataLoaders` into one `MixedDataLoaders`.

    `dls`: objects exposing `.train` and `.valid` loaders.
    `device`: target device; falls back to `default_device()` when None.
    `shuffle_train`: `shuffle` value given to the mixed training loader (the validation one always gets False).
    `kwargs`: forwarded unchanged to both `MixedDataLoader` constructors (e.g. `bs`).
    """
    device = ifnone(device, default_device())
    train_dls = [dl.train for dl in dls]
    valid_dls = [dl.valid for dl in dls]
    # Hand the resolved device to each `MixedDataLoader` so the wrapped loaders are moved to the
    # requested device at construction, instead of to the default device.
    mixed_train_dl = MixedDataLoader(*train_dls, shuffle=shuffle_train, device=device, **kwargs)
    mixed_valid_dl = MixedDataLoader(*valid_dls, shuffle=False, device=device, **kwargs)
    return MixedDataLoaders(mixed_train_dl, mixed_valid_dl, device=device)

In [None]:
#hide
#original
# class MixedDataLoaders():
#     def __init__(self, *dls, device='cuda:0'):
#         "Accepts any number of `DataLoaders` and a device"
#         self.device = device
#         for dl in dls: dl.shuffle_fn = self.shuffle_fn
#         self.dls = dls
#         self.count = 0
#         self.fake_l = _FakeLoader(self, False, 0, 0, 0) if version.parse(fastai.__version__) >= version.parse("2.1") else _FakeLoader(self, False, 0, 0)
#         self._get_idxs()
        
#     def __len__(self): return len(self.dls[0])
    
#     def _get_vals(self, x):
#         "Checks for duplicates in batches"
#         idxs, new_x = [], []
#         for i, o in enumerate(x): x[i] = o.cpu().numpy().flatten()
#         for idx, o in enumerate(x):
#             if not _arrayisin(o, new_x):
#                 idxs.append(idx)
#                 new_x.append(o)
#         return idxs
    
    
#     def _get_idxs(self):
#         "Get `x` and `y` indices for batches of data"
#         dl_dict = dict(zip(range(0,len(self.dls)), [dl.n_inp for dl in self.dls]))
#         inps = L([])
#         outs = L([])
#         for key, n_inp in dl_dict.items():
#             b = next(iter(self.dls[key]))
#             inps += L(b[:n_inp])
#             outs += L(b[n_inp:])
#         self.x_idxs = self._get_vals(inps)
#         self.y_idxs = self._get_vals(outs)
        
    
#     def __iter__(self):
#         z = zip(*[_loaders[i.fake_l.num_workers==0](i.fake_l) for i in self.dls])
#         for b in z:   
#             inps = []
#             outs = []
#             if self.device is not None: 
#                 b = to_device(b, self.device)
#             for batch, dl in zip(b, self.dls):
#                 batch = dl.after_batch(batch)
#                 inps += batch[:dl.n_inp]
#                 outs += batch[dl.n_inp:]
#             inps = tuple(L(inps)[self.x_idxs])
#             outs = tuple(L(outs)[self.y_idxs])
#             yield (inps, outs)
                
#     def one_batch(self):
#         "Grab one batch of data"
#         with self.fake_l.no_multiproc(): res = first(self)
#         if hasattr(self, 'it'): delattr(self, 'it')
#         return res
    
#     def shuffle_fn(self, idxs):
#         "Shuffle the internal `DataLoaders`"
#         if self.count == 0:
#             self.rng = self.dls[0].rng.sample(idxs, len(idxs))
#             self.count += 1
#             return self.rng
#         if self.count == 1:
#             self.count = 0
#             return self.rng

        
#     def show_batch(self):
#         "Show a batch of data"
#         for dl in self.dls:
#             dl.show_batch()
            
#     def to(self, device): self.device = device
        
        
# def _arrayisin(arr, arr_list):
#     "Checks if `arr` is in `arr_list`"
#     for a in arr_list:
#         if np.array_equal(arr, a):
#             return True
#     return False

In [None]:
# Demo setup: build two tabular dataloaders over the same dataframe, target and splits.
from tsai.data.tabular import *

# Fetch the ADULT_SAMPLE dataset and read its csv file
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
target = 'salary'
# A single random split, reused by both `get_tabular_dls` calls below
splits = RandomSplitter()(range_of(df))

# First dataloaders: three categorical and two continuous columns, bs=512
cat_names = ['workclass', 'education', 'marital-status']
cont_names = ['age', 'fnlwgt']
dls1 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, y_names=target, splits=splits, bs=512)
dls1.show_batch()

# Second dataloaders: no categorical columns, one continuous column, bs=128
# (`cat_names` / `cont_names` are deliberately rebound here)
cat_names = None #['occupation', 'relationship', 'race']
cont_names = ['education-num']
dls2 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, y_names=target, splits=splits, bs=128)
dls2.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,Masters,Married-civ-spouse,39.0,505118.992245,>=50k
1,Self-emp-not-inc,Some-college,Divorced,34.0,168905.999588,<50k
2,Private,Some-college,Never-married,26.0,58097.997132,<50k
3,Private,Doctorate,Never-married,36.0,103109.997994,<50k
4,Local-gov,Some-college,Married-civ-spouse,50.0,177705.000226,<50k
5,Self-emp-not-inc,HS-grad,Widowed,63.000001,28611.999789,<50k
6,Local-gov,Some-college,Divorced,43.0,337468.993915,<50k
7,Private,HS-grad,Divorced,35.0,114605.000181,<50k
8,Private,Some-college,Never-married,20.0,205838.999256,<50k
9,Private,11th,Never-married,21.0,191459.999995,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,10.0,>=50k
1,False,4.0,<50k
2,False,11.0,>=50k
3,False,9.0,>=50k
4,False,12.0,>=50k
5,False,9.0,<50k
6,False,9.0,<50k
7,False,9.0,<50k
8,False,10.0,>=50k
9,False,16.0,>=50k


In [None]:
# Mix both sets of dataloaders with a common batch size of 8, then check that the result
# survives a `torch.save` / `torch.load` round trip.
dls = get_mixed_dls(dls1, dls2, bs=8)
# Pull one batch from each split; the results are discarded, so this only checks that iteration works
first(dls.train)
first(dls.valid)
# NOTE(review): assumes an `export/` directory already exists relative to the working directory — confirm
torch.save(dls,'export/mixed_dls.pth')
del dls
# NOTE(review): recent PyTorch releases presumably need `weights_only=False` to load a pickled object
# like this one — confirm against the installed torch version
dls = torch.load('export/mixed_dls.pth')
dls.train.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,HS-grad,Never-married,26.0,224361.001095,<50k
1,Private,HS-grad,Divorced,39.0,342642.000041,<50k
2,Private,Some-college,Never-married,19.0,28789.998927,<50k
3,Private,Bachelors,Married-civ-spouse,33.0,169878.999316,>=50k
4,Federal-gov,Assoc-voc,Never-married,32.0,72338.00316,<50k
5,Private,HS-grad,Married-civ-spouse,25.0,315643.000341,<50k
6,Private,HS-grad,Divorced,33.0,258932.002426,<50k
7,Private,HS-grad,Never-married,23.0,172047.000543,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,9.0,<50k
1,False,9.0,<50k
2,False,10.0,<50k
3,False,13.0,>=50k
4,False,11.0,<50k
5,False,9.0,<50k
6,False,9.0,<50k
7,False,9.0,<50k


In [None]:
# Take the first training batch and display its inputs: one group per mixed dataloader
xb, yb = first(dls.train)
xb

((tensor([[ 5, 12,  5],
          [ 5, 12,  1],
          [ 5, 16,  5],
          [ 5, 10,  3],
          [ 2,  9,  5],
          [ 5, 12,  3],
          [ 5, 12,  1],
          [ 5, 12,  5]]),
  tensor([[-0.9210,  0.3337],
          [ 0.0326,  1.4611],
          [-1.4344, -1.5303],
          [-0.4075, -0.1856],
          [-0.4809, -1.1153],
          [-0.9943,  1.2038],
          [-0.4075,  0.6632],
          [-1.1410, -0.1649]])),
 (tensor([[1],
          [1],
          [1],
          [1],
          [1],
          [1],
          [1],
          [1]]),
  tensor([[-0.4236],
          [-0.4236],
          [-0.0315],
          [ 1.1449],
          [ 0.3606],
          [-0.4236],
          [-0.4236],
          [-0.4236]])))

In [None]:
# Shapes of the individual input tensors: xs[i][j] is the j-th input of the i-th mixed dataloader
xs, ys = first(dls.train)
xs[0][0].shape, xs[0][1].shape, xs[1][0].shape, xs[1][1].shape

(torch.Size([8, 3]),
 torch.Size([8, 2]),
 torch.Size([8, 1]),
 torch.Size([8, 1]))

In [None]:
#hide
# Calls `create_scripts()` and passes its result to `beep`; both names come from imports outside this cell
# (presumably the nbdev export step for this notebook — confirm).
out = create_scripts(); beep(out)