In [4]:
import torch
from torch.utils.data import Dataset
import numpy as np

In [3]:
class CompositeDataset(Dataset):
    """Dataset that concatenates several Dataset objects end to end.

    Useful for combining splits of a dataset. Indices address the wrapped
    datasets in the order they were passed: ``0 .. len(d0) - 1`` map to the
    first dataset, the next ``len(d1)`` indices map to the second, and so on.
    Negative indices count back from the end of the whole composite, as with
    a Python sequence.

    Args:
        *datasets: Objects supporting ``len()`` and integer indexing.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, item):
        """Return the element at composite index ``item``.

        Raises:
            IndexError: If ``item`` is outside ``[-len(self), len(self))``.
        """
        if item < 0:
            # Resolve negative indices against the composite length; passing
            # them straight through would index the *first* dataset from its
            # own end instead of the composite's end.
            item += len(self)
            if item < 0:
                raise IndexError('Index out of range for composite dataset')
        for d in self.datasets:
            size = len(d)
            if item < size:
                return d[item]
            # Skip past this dataset and re-base the index onto the next one.
            item -= size
        raise IndexError('Index too large for composite dataset')

    def __len__(self):
        """Return the total number of items across all wrapped datasets."""
        return sum(map(len, self.datasets))

In [5]:
class DumpDataset(Dataset):
    """Toy dataset whose items are the integers ``0 .. count - 1``.

    Backed by ``np.arange(count)``, so each item is a NumPy integer scalar.

    Args:
        count: Number of items in the dataset.
    """

    def __init__(self, count):
        self.datas = np.arange(count)

    def __getitem__(self, item):
        """Return the item at index ``item``; IndexError if out of range."""
        # ndarrays are indexed with [], not called -- calling one raises
        # TypeError ('numpy.ndarray' object is not callable).
        return self.datas[item]

    def __len__(self):
        """Return the number of items (``count``)."""
        return len(self.datas)

In [6]:
# Three-item toy dataset (items 0, 1, 2).
a = DumpDataset(count=3)

In [7]:
# Notebook display of the dataset length (recorded output: 3).
len(a)

3

In [8]:
# Four-item toy dataset (items 0, 1, 2, 3).
b = DumpDataset(count=4)

In [10]:
# Concatenate a and b into a single dataset.
c = CompositeDataset(a, b)

In [12]:
# Small scratch dict used to try removing keys inside a loop.
d = {1: 2, 2: 4}

In [14]:
# Iterate over a snapshot of the keys so the dict can shrink safely
# while the loop runs; removing from the live dict view would raise
# RuntimeError ("dictionary changed size during iteration").
for key in tuple(d):
    del d[key]

In [15]:
# Notebook display of d after the removal loop (recorded output: {}).
d

{}