# Dataset

In [1]:
# Names exported by a star-import of the dataset definitions that follow.
__all__ = ['Dataset', 'SimpleDataset', 'ArrayDataset', 'RecordFileDataset']

class Dataset(object):
    """Abstract base class for indexable datasets.

    Subclasses must implement ``__getitem__`` and ``__len__``.
    """

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def transform(self, fn, lazy=True):
        """Return a dataset whose samples are passed through ``fn``.

        Parameters
        ----------
        fn : callable
            Applied to each sample; tuple samples are unpacked into
            positional arguments by ``_LazyTransformDataset``.
        lazy : bool, default True
            If True, ``fn`` is applied on every access. If False, all
            samples are transformed once and stored in a ``SimpleDataset``.
        """
        trans = _LazyTransformDataset(self, fn)
        if lazy:
            return trans
        return SimpleDataset([i for i in trans])

    def transform_first(self, fn, lazy=True):
        """Return a dataset with ``fn`` applied to the first element only.

        For samples made of several elements (e.g. ``(data, label)``) the
        remaining elements are passed through unchanged. Defined on the base
        class because it only relies on ``self.transform``, so every dataset
        (not just lazily transformed ones) supports it.

        Parameters
        ----------
        fn : callable
            Applied to the first element of each sample.
        lazy : bool, default True
            Forwarded to ``transform``.
        """
        def base_fn(x, *args):
            if args:
                return (fn(x),) + args
            return fn(x)
        return self.transform(base_fn, lazy)
    
class SimpleDataset(Dataset):
    """Dataset that wraps any object supporting ``len()`` and indexing."""

    def __init__(self, data):
        # The wrapped container is stored as-is (no copy).
        self._data = data

    def __getitem__(self, idx):
        wrapped = self._data
        return wrapped[idx]

    def __len__(self):
        wrapped = self._data
        return len(wrapped)
    
class ArrayDataset(Dataset):
    """Dataset that combines several equally long array-like objects.

    The i-th sample is the tuple of the i-th elements of every array, or the
    bare element when only one array was given.

    Parameters
    ----------
    *args : one or more array-like objects
        All must have the same ``len()``. One-dimensional ``NDArray`` inputs
        are converted with ``asnumpy()`` before being stored.
    """

    def __init__(self, *args):
        assert len(args) > 0, "Needs at least 1 arrays"
        self._length = len(args[0])
        self._data = []
        for i, data in enumerate(args):
            # `enumerate` is zero-based, so `i` is already the position of
            # the offending array; no offset is needed in the message.
            assert len(data) == self._length, \
                "All arrays must have the same length; array[0] has length %d " \
                "while array[%d] has %d." % (self._length, i, len(data))
            if isinstance(data, ndarray.NDArray) and len(data.shape) == 1:
                data = data.asnumpy()
            self._data.append(data)

    def __getitem__(self, idx):
        if len(self._data) == 1:
            return self._data[0][idx]
        else:
            return tuple(data[idx] for data in self._data)

    def __len__(self):
        return self._length
    
class RecordFileDataset(Dataset):
    """Dataset backed by an indexed RecordIO file.

    The index file path is derived from ``filename`` by replacing its
    extension with ``.idx``.
    """

    def __init__(self, filename):
        self.filename = filename
        stem = os.path.splitext(filename)[0]
        self.idx_file = stem + '.idx'
        self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')

    def __len__(self):
        return len(self._record.keys)

    def __getitem__(self, idx):
        # Map the positional index to the record key, then read that record.
        reader = self._record
        return reader.read_idx(reader.keys[idx])

class _DownloadedDataset(Dataset):
    """Base class for datasets that populate ``_data``/``_label`` from ``root``.

    The constructor expands ``root``, creates the directory if it is missing
    and then calls ``_get_data``, which subclasses must implement.
    """

    def __init__(self, root, transform):
        super(_DownloadedDataset, self).__init__()
        self._transform = transform
        self._data = None
        self._label = None
        expanded = os.path.expanduser(root)
        self._root = expanded
        if not os.path.isdir(expanded):
            os.makedirs(expanded)
        self._get_data()

    def __len__(self):
        return len(self._label)

    def __getitem__(self, idx):
        if self._transform is None:
            return self._data[idx], self._label[idx]
        return self._transform(self._data[idx], self._label[idx])

    def _get_data(self):
        """Load the dataset into ``self._data`` and ``self._label``."""
        raise NotImplementedError
        
class _LazyTransformDataset(Dataset):
    """Dataset view that applies ``fn`` to a sample each time it is accessed."""

    def __init__(self, data, fn):
        self._data = data
        self._fn = fn

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        sample = self._data[idx]
        # Tuple samples are spread into positional arguments of the transform.
        if not isinstance(sample, tuple):
            return self._fn(sample)
        return self._fn(*sample)

    def transform_first(self, fn, lazy=True):
        """Transform only the first element of each sample with ``fn``."""
        def base_fn(x, *args):
            head = fn(x)
            if not args:
                return head
            return (head,) + args
        return self.transform(base_fn, lazy)

# DataLoader

In [2]:
import pickle
import io
import sys
import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
import threading
import numpy as np

class ConnectionWrapper(object):
    """Connection proxy that pickles with ``ForkingPickler`` on send.

    ``send`` serializes through ``ForkingPickler`` and ``recv`` deserializes
    with ``pickle.loads``; every other attribute is looked up on the wrapped
    connection.
    """

    def __init__(self, conn):
        self._conn = conn

    def send(self, obj):
        """Pickle ``obj`` and push the bytes through ``send_bytes``."""
        stream = io.BytesIO()
        pickler = ForkingPickler(stream, pickle.HIGHEST_PROTOCOL)
        pickler.dump(obj)
        payload = stream.getvalue()
        self.send_bytes(payload)

    def recv(self):
        """Read bytes with ``recv_bytes`` and unpickle them."""
        return pickle.loads(self.recv_bytes())

    def __getattr__(self, name):
        # Go through __dict__ so a missing `_conn` cannot recurse into
        # __getattr__ again.
        wrapped = self.__dict__.get('_conn', None)
        return getattr(wrapped, name)
    
class Queue(multiprocessing.queues.Queue):
    """``multiprocessing`` queue whose pipe ends are ``ConnectionWrapper`` objects."""

    def __init__(self, *args, **kwargs):
        if sys.version_info[0] > 2:
            # Python 3 queues take the multiprocessing context as a keyword.
            super(Queue, self).__init__(*args, ctx=multiprocessing.get_context(),
                                        **kwargs)
        else:
            super(Queue, self).__init__(*args, **kwargs)
        wrapped_reader = ConnectionWrapper(self._reader)
        wrapped_writer = ConnectionWrapper(self._writer)
        self._reader = wrapped_reader
        self._writer = wrapped_writer
        # Rebind the cached send/recv callables to the wrapped connections.
        self._send = wrapped_writer.send
        self._recv = wrapped_reader.recv
        
class SimpleQueue(multiprocessing.queues.SimpleQueue):
    """``multiprocessing`` simple queue whose pipe ends are ``ConnectionWrapper`` objects."""

    def __init__(self, *args, **kwargs):
        if sys.version_info[0] > 2:
            # Python 3 queues take the multiprocessing context as a keyword.
            super(SimpleQueue, self).__init__(*args, ctx=multiprocessing.get_context(),
                                              **kwargs)
        else:
            super(SimpleQueue, self).__init__(*args, **kwargs)
        wrapped_reader = ConnectionWrapper(self._reader)
        wrapped_writer = ConnectionWrapper(self._writer)
        self._reader = wrapped_reader
        self._writer = wrapped_writer
        # Rebind the cached send/recv callables to the wrapped connections.
        self._send = wrapped_writer.send
        self._recv = wrapped_reader.recv
        
def default_batchify_fn(data):
    """Collate a list of samples into a batch.

    NDArray samples are stacked, tuple samples are collated field by field
    (returning a list), and anything else goes through ``np.asarray`` before
    being turned into an NDArray of the same dtype.
    """
    if isinstance(data[0], nd.NDArray):
        return nd.stack(*data)
    if isinstance(data[0], tuple):
        return [default_batchify_fn(field) for field in zip(*data)]
    as_numpy = np.asarray(data)
    return nd.array(as_numpy, dtype=as_numpy.dtype)


def default_mp_batchify_fn(data):
    """Collate a list of samples into a batch placed in the ``cpu_shared`` context.

    Same dispatch as ``default_batchify_fn`` (NDArray / tuple / other), but
    the output arrays are created with ``context.Context('cpu_shared', 0)``.
    """
    if isinstance(data[0], nd.NDArray):
        stacked_shape = (len(data),) + data[0].shape
        out = nd.empty(stacked_shape, dtype=data[0].dtype,
                       ctx=context.Context('cpu_shared', 0))
        return nd.stack(*data, out=out)
    if isinstance(data[0], tuple):
        return [default_mp_batchify_fn(field) for field in zip(*data)]
    as_numpy = np.asarray(data)
    return nd.array(as_numpy, dtype=as_numpy.dtype,
                    ctx=context.Context('cpu_shared', 0))


def _as_in_context(data, ctx):
    """Recursively move NDArrays in ``data`` to ``ctx``.

    Lists and tuples are rebuilt as lists; objects of any other type are
    returned unchanged.
    """
    if isinstance(data, nd.NDArray):
        return data.as_in_context(ctx)
    if not isinstance(data, (list, tuple)):
        return data
    return [_as_in_context(element, ctx) for element in data]


def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
    """Serve batch requests from ``key_queue`` until a ``None`` index arrives.

    Each request is ``(idx, samples)``; the reply put on ``data_queue`` is
    ``(idx, batchify_fn([dataset[i] for i in samples]))``.
    """
    batch_id, indices = key_queue.get()
    while batch_id is not None:
        items = [dataset[i] for i in indices]
        data_queue.put((batch_id, batchify_fn(items)))
        batch_id, indices = key_queue.get()

def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
    """Move finished batches from ``data_queue`` into ``data_buffer``.

    Runs until an entry with a ``None`` index is received. Each batch is
    moved to ``context.cpu_pinned()`` when ``pin_memory`` is set, otherwise
    to ``context.cpu()``. Writes to ``data_buffer`` hold ``data_buffer_lock``
    when one is supplied.
    """
    batch_id, batch = data_queue.get()
    while batch_id is not None:
        if pin_memory:
            batch = _as_in_context(batch, context.cpu_pinned())
        else:
            batch = _as_in_context(batch, context.cpu())
        if data_buffer_lock is None:
            data_buffer[batch_id] = batch
        else:
            with data_buffer_lock:
                data_buffer[batch_id] = batch
        batch_id, batch = data_queue.get()

class _MultiWorkerIterV1(object):
    """Iterator that produces batches using worker processes and a fetcher thread.

    Batch index lists are sent to workers through ``_key_queue``; workers
    reply on ``_data_queue``; a daemon thread running ``fetcher_loop_v1``
    copies the replies into ``_data_buffer`` keyed by batch number, from
    which ``__next__`` returns them in order.
    """
    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
                 worker_fn=worker_loop_v1):
        assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
        self._num_workers = num_workers
        self._dataset = dataset
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._key_queue = Queue()
        self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue()

        # Filled by the fetcher thread, drained by __next__.
        self._data_buffer = {}
        self._data_buffer_lock = threading.Lock()

        # Number of batches returned so far / number of batches requested so far.
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        self._shutdown = False

        workers = []
        for _ in range(self._num_workers):
            worker = multiprocessing.Process(
                target=worker_fn,
                args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn))
            worker.daemon = True
            worker.start()
            workers.append(worker)
        self._workers = workers

        self._fetcher = threading.Thread(
            target=fetcher_loop_v1,
            args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
        self._fetcher.daemon = True
        self._fetcher.start()

        # pre-fetch
        for _ in range(2 * self._num_workers):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        # NOTE(review): if __init__ raised before `_shutdown` was assigned,
        # shutdown() would fail with AttributeError here.
        self.shutdown()

    def _push_next(self):
        """Request the next batch from the sampler, if any remain."""
        r = next(self._iter, None)
        if r is None:
            return
        self._key_queue.put((self._sent_idx, r))
        self._sent_idx += 1

    def __next__(self):
        """Return the next batch in order; shut down and stop when all are consumed."""
        assert not self._shutdown, "call __next__ after shutdown is forbidden"
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            self.shutdown()
            raise StopIteration

        # Spins (without sleeping) until the fetcher thread has stored the
        # batch numbered `_rcvd_idx`.
        # NOTE(review): the membership test runs without the lock, and the
        # loop has no exit if the expected batch never arrives.
        while True:
            if self._rcvd_idx in self._data_buffer:
                with self._data_buffer_lock:
                    batch = self._data_buffer.pop(self._rcvd_idx)
                self._rcvd_idx += 1
                # Keep the pipeline full: one new request per batch returned.
                self._push_next()
                return batch

    def next(self):
        # Python 2 iterator protocol.
        return self.__next__()

    def __iter__(self):
        return self

    def shutdown(self):
        """Stop the fetcher thread and the worker processes (runs once)."""
        if not self._shutdown:
            # A None index makes fetcher_loop_v1 exit its loop.
            self._data_queue.put((None, None))
            self._fetcher.join()
            # One None-index message per worker makes the default worker loop exit.
            for _ in range(self._num_workers):
                self._key_queue.put((None, None))
            for w in self._workers:
                if w.is_alive():
                    w.terminate()
            self._shutdown = True
            
            
class DataLoaderV1(object):
    """Iterable that yields batches of a dataset.

    Parameters
    ----------
    dataset : indexable object with ``len()``
    batch_size : int, required unless ``batch_sampler`` is given
    shuffle : bool, only valid when neither ``sampler`` nor ``batch_sampler`` is given
    sampler : iterable of sample indices
    last_batch : passed to ``BatchSampler``; ``'keep'`` is used when falsy
    batch_sampler : iterable of index lists; mutually exclusive with
        ``batch_size``, ``shuffle``, ``sampler`` and ``last_batch``
    batchify_fn : callable collating a list of samples; a default is chosen
        according to ``num_workers`` when omitted
    num_workers : int, negative values are treated as 0
    pin_memory : bool, move batches to ``context.cpu_pinned()``

    Raises
    ------
    ValueError
        If the argument combination is invalid.
    """

    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None, batchify_fn=None,
                 num_workers=0, pin_memory=False):
        self._dataset = dataset
        self._pin_memory = pin_memory

        if batch_sampler is not None:
            if batch_size is not None or shuffle or sampler is not None or \
                    last_batch is not None:
                raise ValueError("batch_size, shuffle, sampler and last_batch must " \
                                 "not be specified if batch_sampler is specified.")
        else:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless " \
                                 "batch_sampler is specified")
            if sampler is not None:
                if shuffle:
                    raise ValueError("shuffle must not be specified if sampler is specified")
            elif shuffle:
                sampler = _sampler.RandomSampler(len(dataset))
            else:
                sampler = _sampler.SequentialSampler(len(dataset))

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')

        self._batch_sampler = batch_sampler
        self._num_workers = num_workers if num_workers >= 0 else 0

        if batchify_fn is not None:
            self._batchify_fn = batchify_fn
        elif num_workers > 0:
            self._batchify_fn = default_mp_batchify_fn
        else:
            self._batchify_fn = default_batchify_fn

    def __iter__(self):
        if self._num_workers != 0:
            return _MultiWorkerIterV1(self._num_workers, self._dataset,
                                      self._batchify_fn, self._batch_sampler, self._pin_memory)

        def same_process_iter():
            # Single-process path: build each batch on demand in this process.
            for indices in self._batch_sampler:
                samples = [self._dataset[idx] for idx in indices]
                ret = self._batchify_fn(samples)
                if self._pin_memory:
                    ret = _as_in_context(ret, context.cpu_pinned())
                yield ret
        return same_process_iter()

    def __len__(self):
        return len(self._batch_sampler)

# Module-level dataset slot; assigned by `_worker_initializer` and read by
# `_worker_fn`.
_worker_dataset = None


def _worker_initializer(dataset):
    """Store ``dataset`` in the module-level ``_worker_dataset`` global."""
    global _worker_dataset
    _worker_dataset = dataset

def _worker_fn(samples, batchify_fn):
    """Build one batch from ``_worker_dataset`` and return it pickled.

    Parameters
    ----------
    samples : iterable of indices into the module-level ``_worker_dataset``
    batchify_fn : callable that collates the list of fetched samples

    Returns
    -------
    bytes
        The batch serialized with ``ForkingPickler``.
    """
    global _worker_dataset
    fetched = [_worker_dataset[i] for i in samples]
    stream = io.BytesIO()
    pickler = ForkingPickler(stream, pickle.HIGHEST_PROTOCOL)
    pickler.dump(batchify_fn(fetched))
    return stream.getvalue()

class _MultiWorkerIter(object):
    """Iterator that builds batches through a worker pool.

    Each batch request is submitted with ``worker_pool.apply_async`` and the
    pending result is kept in ``_data_buffer`` under its batch number;
    ``__next__`` unpickles and returns results in submission order.
    """

    def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
                 worker_fn=_worker_fn, prefetch=0):
        self._worker_pool = worker_pool
        self._worker_fn = worker_fn
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._pin_memory = pin_memory
        self._data_buffer = {}
        # Batches handed out so far / batches submitted so far.
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        # Submit the first `prefetch` requests up front.
        for _ in range(prefetch):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def _push_next(self):
        """Submit the next batch request to the pool, if any remain."""
        indices = next(self._iter, None)
        if indices is not None:
            pending = self._worker_pool.apply_async(
                self._worker_fn, (indices, self._batchify_fn))
            self._data_buffer[self._sent_idx] = pending
            self._sent_idx += 1

    def __next__(self):
        self._push_next()
        position = self._rcvd_idx
        if position == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            raise StopIteration

        assert position < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
        assert position in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
        pending = self._data_buffer.pop(position)
        batch = pickle.loads(pending.get())
        if self._pin_memory:
            batch = _as_in_context(batch, context.cpu_pinned())
        # A batch of length 1 is replaced by its only element.
        if len(batch) == 1:
            batch = batch[0]
        self._rcvd_idx += 1
        return batch

    def next(self):
        # Python 2 iterator protocol.
        return self.__next__()

    def __iter__(self):
        return self


class DataLoader(object):
    """Iterable that yields batches of a dataset, optionally via a worker pool.

    Parameters
    ----------
    dataset : indexable object with ``len()``
    batch_size : int, required unless ``batch_sampler`` is given
    shuffle : bool, only valid when neither ``sampler`` nor ``batch_sampler`` is given
    sampler : iterable of sample indices
    last_batch : passed to ``BatchSampler``; ``'keep'`` is used when falsy
    batch_sampler : iterable of index lists; mutually exclusive with
        ``batch_size``, ``shuffle``, ``sampler`` and ``last_batch``
    batchify_fn : callable collating a list of samples; a default is chosen
        according to ``num_workers`` when omitted
    num_workers : int, negative values are treated as 0; when positive a
        ``multiprocessing.Pool`` of that size is created
    pin_memory : bool, move batches to ``context.cpu_pinned()``
    prefetch : int or None, number of batches requested ahead; defaults to
        ``2 * num_workers`` and is clamped to be non-negative

    Raises
    ------
    ValueError
        If the argument combination is invalid.
    """

    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None, batchify_fn=None,
                 num_workers=0, pin_memory=False, prefetch=None):
        # Assigned before any validation so that __del__ always finds the
        # attribute, even when the constructor raises below.
        self._worker_pool = None
        self._dataset = dataset
        self._pin_memory = pin_memory

        if batch_sampler is None:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless " \
                                 "batch_sampler is specified")
            if sampler is None:
                if shuffle:
                    sampler = _sampler.RandomSampler(len(dataset))
                else:
                    sampler = _sampler.SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError("shuffle must not be specified if sampler is specified")

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')
        elif batch_size is not None or shuffle or sampler is not None or \
                last_batch is not None:
            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
                             "not be specified if batch_sampler is specified.")

        self._batch_sampler = batch_sampler
        self._num_workers = num_workers if num_workers >= 0 else 0
        self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
        if self._num_workers > 0:
            # Every pool process stores the dataset via _worker_initializer.
            self._worker_pool = multiprocessing.Pool(
                self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
        if batchify_fn is None:
            if num_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn

    def __iter__(self):
        if self._num_workers == 0:
            def same_process_iter():
                # Single-process path: build each batch on demand.
                for batch in self._batch_sampler:
                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
                    if self._pin_memory:
                        ret = _as_in_context(ret, context.cpu_pinned())
                    yield ret
            return same_process_iter()

        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
                                pin_memory=self._pin_memory, worker_fn=_worker_fn,
                                prefetch=self._prefetch)

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        # getattr guards against partially constructed instances; a
        # destructor must not raise because of a missing attribute.
        pool = getattr(self, '_worker_pool', None)
        if pool:
            pool.terminate()

# Sampler

In [3]:
class Sampler(object):
    """Abstract base class for samplers.

    Subclasses must implement ``__iter__`` and ``__len__``.
    """

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError


class SequentialSampler(Sampler):
    """Sampler yielding ``0 .. length-1`` in increasing order."""

    def __init__(self, length):
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        indices = range(self._length)
        return iter(indices)


class RandomSampler(Sampler):
    """Sampler yielding ``0 .. length-1`` in a freshly shuffled order.

    A new permutation is drawn with ``np.random.shuffle`` on every
    ``__iter__`` call.
    """

    def __init__(self, length):
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        order = np.arange(self._length)
        np.random.shuffle(order)  # in-place shuffle
        return iter(order)


class BatchSampler(Sampler):
    """Group the indices produced by another sampler into lists of ``batch_size``.

    Parameters
    ----------
    sampler : iterable of indices with ``len()``
    batch_size : int
    last_batch : {'keep', 'discard', 'rollover'}
        What to do with a final, incomplete batch: yield it, drop it, or
        carry it over to the start of the next iteration.
    """

    def __init__(self, sampler, batch_size, last_batch='keep'):
        self._sampler = sampler
        self._batch_size = batch_size
        self._last_batch = last_batch
        # Leftover indices carried between iterations in 'rollover' mode.
        self._prev = []

    def __iter__(self):
        # Start from whatever was rolled over by the previous iteration.
        pending, self._prev = self._prev, []
        for index in self._sampler:
            pending.append(index)
            if len(pending) == self._batch_size:
                yield pending
                pending = []

        if not pending:
            return

        mode = self._last_batch
        if mode == 'keep':
            yield pending
        elif mode == 'discard':
            return
        elif mode == 'rollover':
            self._prev = pending
        else:
            raise ValueError(
                "last_batch must be one of 'keep', 'discard', or 'rollover', " \
                "but got %s"%self._last_batch)

    def __len__(self):
        mode = self._last_batch
        size = self._batch_size
        if mode == 'keep':
            # Ceiling division: the partial batch counts.
            return (len(self._sampler) + size - 1) // size
        if mode == 'discard':
            return len(self._sampler) // size
        if mode == 'rollover':
            return (len(self._prev) + len(self._sampler)) // size
        raise ValueError(
            "last_batch must be one of 'keep', 'discard', or 'rollover', " \
            "but got %s"%self._last_batch)

# Transforms

In [4]:
from mxnet.gluon.block import Block, HybridBlock
from mxnet.gluon.nn import Sequential, HybridSequential

class Compose(Sequential):
    """Chain several transforms into one sequential block.

    Consecutive ``HybridBlock`` transforms are grouped: a run of one is added
    directly, a longer run is wrapped in a hybridized ``HybridSequential``.
    All other transforms are added individually, and ``None`` entries are
    skipped.

    Parameters
    ----------
    transforms : iterable of transforms
        The sequence is not modified.
    """

    def __init__(self, transforms):
        super(Compose, self).__init__()
        hybrid = []
        # Iterate over a copy with a trailing None sentinel: the sentinel
        # flushes the last run of hybrid blocks, and copying keeps the
        # caller's list untouched (appending in place would leave a stray
        # None in it and would fail for tuples).
        for i in list(transforms) + [None]:
            if isinstance(i, HybridBlock):
                hybrid.append(i)
                continue
            elif len(hybrid) == 1:
                self.add(hybrid[0])
                hybrid = []
            elif len(hybrid) > 1:
                hblock = HybridSequential()
                for j in hybrid:
                    hblock.add(j)
                hblock.hybridize()
                self.add(hblock)
                hybrid = []

            if i is not None:
                self.add(i)


class Cast(HybridBlock):
    """Transform that casts its input to ``dtype`` using ``F.cast``."""

    def __init__(self, dtype='float32'):
        super(Cast, self).__init__()
        self._dtype = dtype

    def hybrid_forward(self, F, x):
        target_dtype = self._dtype
        return F.cast(x, target_dtype)


class ToTensor(HybridBlock):
    """Transform that applies ``F.image.to_tensor`` to its input."""

    def __init__(self):
        super(ToTensor, self).__init__()

    def hybrid_forward(self, F, x):
        converted = F.image.to_tensor(x)
        return converted


class Normalize(HybridBlock):
    """Transform that applies ``F.image.normalize`` with a fixed mean and std."""

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        self._mean, self._std = mean, std

    def hybrid_forward(self, F, x):
        mean, std = self._mean, self._std
        return F.image.normalize(x, mean, std)


class RandomResizedCrop(Block):
    """Transform wrapping ``image.random_size_crop``.

    A numeric ``size`` is expanded to ``(size, size)``. Only the first
    element of the value returned by ``image.random_size_crop`` is kept.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0),
                 interpolation=1):
        super(RandomResizedCrop, self).__init__()
        target = (size, size) if isinstance(size, numeric_types) else size
        self._args = (target, scale, ratio, interpolation)

    def forward(self, x):
        result = image.random_size_crop(x, *self._args)
        return result[0]


class CenterCrop(Block):
    """Transform wrapping ``image.center_crop``.

    A numeric ``size`` is expanded to ``(size, size)``. Only the first
    element of the value returned by ``image.center_crop`` is kept.
    """

    def __init__(self, size, interpolation=1):
        super(CenterCrop, self).__init__()
        target = (size, size) if isinstance(size, numeric_types) else size
        self._args = (target, interpolation)

    def forward(self, x):
        result = image.center_crop(x, *self._args)
        return result[0]


class Resize(Block):
    """Transform that resizes its input with ``image.imresize``.

    Parameters
    ----------
    size : number or (width, height) pair
    keep_ratio : bool
        Only used when ``size`` is a number: the shorter side is set to
        ``size`` and the other side is scaled proportionally. Otherwise both
        sides are set to ``size``.
    interpolation : passed through to ``image.imresize``
    """

    def __init__(self, size, keep_ratio=False, interpolation=1):
        super(Resize, self).__init__()
        self._size = size
        self._keep = keep_ratio
        self._interpolation = interpolation

    def forward(self, x):
        size = self._size
        if not isinstance(size, numeric_types):
            wsize, hsize = size
        elif not self._keep:
            wsize = size
            hsize = size
        else:
            # Input shape is unpacked as (height, width, <third axis>).
            h, w, _ = x.shape
            if h > w:
                wsize = size
                hsize = int(h * wsize / w)
            else:
                hsize = size
                wsize = int(w * hsize / h)
        return image.imresize(x, wsize, hsize, self._interpolation)


class RandomFlipLeftRight(HybridBlock):
    """Transform that applies ``F.image.random_flip_left_right``."""

    def __init__(self):
        super(RandomFlipLeftRight, self).__init__()

    def hybrid_forward(self, F, x):
        flipped = F.image.random_flip_left_right(x)
        return flipped


class RandomFlipTopBottom(HybridBlock):
    """Transform that applies ``F.image.random_flip_top_bottom``."""

    def __init__(self):
        super(RandomFlipTopBottom, self).__init__()

    def hybrid_forward(self, F, x):
        flipped = F.image.random_flip_top_bottom(x)
        return flipped


class RandomBrightness(HybridBlock):
    """Transform that applies ``F.image.random_brightness``.

    The operator receives the bounds ``max(0, 1 - brightness)`` and
    ``1 + brightness``.
    """

    def __init__(self, brightness):
        super(RandomBrightness, self).__init__()
        lower = max(0, 1-brightness)
        upper = 1+brightness
        self._args = (lower, upper)

    def hybrid_forward(self, F, x):
        lower, upper = self._args
        return F.image.random_brightness(x, lower, upper)


class RandomContrast(HybridBlock):
    """Transform that applies ``F.image.random_contrast``.

    The operator receives the bounds ``max(0, 1 - contrast)`` and
    ``1 + contrast``.
    """

    def __init__(self, contrast):
        super(RandomContrast, self).__init__()
        lower = max(0, 1-contrast)
        upper = 1+contrast
        self._args = (lower, upper)

    def hybrid_forward(self, F, x):
        lower, upper = self._args
        return F.image.random_contrast(x, lower, upper)


class RandomSaturation(HybridBlock):
    """Transform that applies ``F.image.random_saturation``.

    The operator receives the bounds ``max(0, 1 - saturation)`` and
    ``1 + saturation``.
    """

    def __init__(self, saturation):
        super(RandomSaturation, self).__init__()
        lower = max(0, 1-saturation)
        upper = 1+saturation
        self._args = (lower, upper)

    def hybrid_forward(self, F, x):
        lower, upper = self._args
        return F.image.random_saturation(x, lower, upper)


class RandomHue(HybridBlock):
    """Transform that applies ``F.image.random_hue``.

    The operator receives the bounds ``max(0, 1 - hue)`` and ``1 + hue``.
    """

    def __init__(self, hue):
        super(RandomHue, self).__init__()
        lower = max(0, 1-hue)
        upper = 1+hue
        self._args = (lower, upper)

    def hybrid_forward(self, F, x):
        lower, upper = self._args
        return F.image.random_hue(x, lower, upper)


class RandomColorJitter(HybridBlock):
    """Transform that applies ``F.image.random_color_jitter``.

    ``brightness``, ``contrast``, ``saturation`` and ``hue`` are forwarded to
    the operator in that order.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super(RandomColorJitter, self).__init__()
        self._args = (brightness, contrast, saturation, hue)

    def hybrid_forward(self, F, x):
        brightness, contrast, saturation, hue = self._args
        return F.image.random_color_jitter(x, brightness, contrast, saturation, hue)


class RandomLighting(HybridBlock):
    """Transform that applies ``F.image.random_lighting`` with a fixed ``alpha``."""

    def __init__(self, alpha):
        super(RandomLighting, self).__init__()
        self._alpha = alpha

    def hybrid_forward(self, F, x):
        alpha = self._alpha
        return F.image.random_lighting(x, alpha)


# Datasets

In [5]:
import os
import gzip
import tarfile
import struct
import warnings
import numpy as np

class MNIST(_DownloadedDataset):
    """MNIST dataset loaded from gzip-compressed image and label files.

    Parameters
    ----------
    root : str
        Directory where the downloaded files are stored.
    train : bool
        Select the training files (True) or the ``t10k`` test files (False).
    transform : callable or None
        Optional ``transform(data, label)`` applied to every sample.
    """

    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'mnist'),
                 train=True, transform=None):
        # (file name, SHA-1 checksum) pairs; must be set before the base
        # constructor runs because it calls _get_data().
        self._train = train
        self._namespace = 'mnist'
        self._train_data = ('train-images-idx3-ubyte.gz',
                            '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d')
        self._train_label = ('train-labels-idx1-ubyte.gz',
                             '2a80914081dc54586dbdf242f9805a6b8d2a15fc')
        self._test_data = ('t10k-images-idx3-ubyte.gz',
                           'c3a25af1f52dad7f726cce8cacb138654b760d48')
        self._test_label = ('t10k-labels-idx1-ubyte.gz',
                            '763e7fa3757d93b0cdec073cef058b2004252c17')
        super(MNIST, self).__init__(root, transform)

    def _get_data(self):
        """Download the selected split and load it into ``_data``/``_label``."""
        if self._train:
            data, label = self._train_data, self._train_label
        else:
            data, label = self._test_data, self._test_label

        namespace = 'gluon/dataset/'+self._namespace
        data_url = _get_repo_file_url(namespace, data[0])
        data_file = download(data_url, path=self._root, sha1_hash=data[1])
        label_url = _get_repo_file_url(namespace, label[0])
        label_file = download(label_url, path=self._root, sha1_hash=label[1])

        with gzip.open(label_file, 'rb') as fin:
            # Consume the 8-byte header; its values are not used.
            struct.unpack(">II", fin.read(8))
            raw_labels = np.frombuffer(fin.read(), dtype=np.uint8)
            label = raw_labels.astype(np.int32)

        with gzip.open(data_file, 'rb') as fin:
            # Consume the 16-byte header; its values are not used.
            struct.unpack(">IIII", fin.read(16))
            pixels = np.frombuffer(fin.read(), dtype=np.uint8)
            data = pixels.reshape(len(label), 28, 28, 1)

        self._data = nd.array(data, dtype=data.dtype)
        self._label = label


class FashionMNIST(MNIST):
    """Fashion-MNIST dataset; reuses the ``MNIST`` loader with its own files.

    Parameters
    ----------
    root : str
        Directory where the downloaded files are stored.
    train : bool
        Select the training files (True) or the ``t10k`` test files (False).
    transform : callable or None
        Optional ``transform(data, label)`` applied to every sample.
    """

    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'fashion-mnist'), train=True, transform=None):
        # (file name, SHA-1 checksum) pairs read by MNIST._get_data().
        self._train = train
        self._namespace = 'fashion-mnist'
        self._train_data = ('train-images-idx3-ubyte.gz', '0cf37b0d40ed5169c6b3aba31069a9770ac9043d')
        self._train_label = ('train-labels-idx1-ubyte.gz', '236021d52f1e40852b06a4c3008d8de8aef1e40b')
        self._test_data = ('t10k-images-idx3-ubyte.gz', '626ed6a7c06dd17c0eec72fa3be1740f146a2863')
        self._test_label = ('t10k-labels-idx1-ubyte.gz', '17f9ab60e7257a1620f4ad76bbbaf857c3920701')
        # Deliberately skip MNIST.__init__ (it would overwrite the file
        # tables above) and continue with the next class in the MRO.
        super(MNIST, self).__init__(root, transform) # pylint: disable=bad-super-call


class CIFAR10(dataset._DownloadedDataset):
    """CIFAR-10 dataset loaded from the binary batch files.

    Parameters
    ----------
    root : str
        Directory where the archive is downloaded and extracted.
    train : bool
        Select the five training batches (True) or the test batch (False).
    transform : callable or None
        Stored by the base class together with ``root``.
    """

    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar10'), train=True, transform=None):
        # (file name, SHA-1 checksum) pairs; must be set before the base
        # constructor runs.
        self._train = train
        self._namespace = 'cifar10'
        self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891')
        self._train_data = [
            ('data_batch_1.bin', 'aadd24acce27caa71bf4b10992e9e7b2d74c2540'),
            ('data_batch_2.bin', 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795'),
            ('data_batch_3.bin', '1dd00a74ab1d17a6e7d73e185b69dbf31242f295'),
            ('data_batch_4.bin', 'aab85764eb3584312d3c7f65fd2fd016e36a258e'),
            ('data_batch_5.bin', '26e2849e66a845b7f1e4614ae70f4889ae604628'),
        ]
        self._test_data = [('test_batch.bin', '67eb016db431130d61cd03c7ad570b013799c88c')]
        super(CIFAR10, self).__init__(root, transform)

    def _read_batch(self, filename):
        """Read one binary batch file and return ``(images, labels)``.

        Each record is 1 label byte followed by 3072 image bytes; images are
        reshaped to (N, 3, 32, 32) and transposed to (N, 32, 32, 3).
        """
        with open(filename, 'rb') as fin:
            raw = fin.read()
        records = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3072+1)

        images = records[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        labels = records[:, 0].astype(np.int32)
        return images, labels

    def _get_data(self):
        """Fetch the archive if any batch file is missing or corrupt, then load."""
        needs_download = False
        for name, sha1 in self._train_data + self._test_data:
            path = os.path.join(self._root, name)
            if not (os.path.exists(path) and check_sha1(path, sha1)):
                needs_download = True
                break

        if needs_download:
            namespace = 'gluon/dataset/'+self._namespace
            filename = download(_get_repo_file_url(namespace, self._archive_file[0]),
                                path=self._root,
                                sha1_hash=self._archive_file[1])

            with tarfile.open(filename) as tar:
                tar.extractall(self._root)

        data_files = self._train_data if self._train else self._test_data
        batches = [self._read_batch(os.path.join(self._root, name))
                   for name, _ in data_files]
        data, label = zip(*batches)
        data = np.concatenate(data)
        label = np.concatenate(label)

        self._data = nd.array(data, dtype=data.dtype)
        self._label = label


class CIFAR100(CIFAR10):
    """CIFAR-100 dataset; reuses the ``CIFAR10`` loader with its own files.

    Parameters
    ----------
    root : str
        Directory where the archive is downloaded and extracted.
    fine_label : bool
        Selects which of the two leading label bytes of each record is used:
        byte 0 when False, byte 1 when True.
    train : bool
        Select ``train.bin`` (True) or ``test.bin`` (False).
    transform : callable or None
        Stored by the base class together with ``root``.
    """

    def __init__(self, root=os.path.join(base.data_dir(), 'datasets', 'cifar100'),
                 fine_label=False, train=True, transform=None):
        # (file name, SHA-1 checksum) pairs read by CIFAR10._get_data().
        self._train = train
        self._fine_label = fine_label
        self._namespace = 'cifar100'
        self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b')
        self._train_data = [('train.bin', 'e207cd2e05b73b1393c74c7f5e7bea451d63e08e')]
        self._test_data = [('test.bin', '8fb6623e830365ff53cf14adec797474f5478006')]
        # Deliberately skip CIFAR10.__init__ (it would overwrite the file
        # tables above) and continue with the next class in the MRO.
        super(CIFAR10, self).__init__(root, transform) # pylint: disable=bad-super-call

    def _read_batch(self, filename):
        """Read one binary file and return ``(images, labels)``.

        Each record is 2 label bytes followed by 3072 image bytes; images are
        reshaped to (N, 3, 32, 32) and transposed to (N, 32, 32, 3).
        """
        with open(filename, 'rb') as fin:
            raw = fin.read()
        records = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3072+2)

        images = records[:, 2:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        # `fine_label` is used as a 0/1 offset into the two label columns.
        labels = records[:, 0+self._fine_label].astype(np.int32)
        return images, labels


class ImageRecordDataset(dataset.RecordFileDataset):
    """Record-file dataset whose records are decoded into images.

    Parameters
    ----------
    filename : str
        Path handed to the parent record-file dataset.
    flag : passed to ``image.imdecode``
    transform : callable or None
        Optional ``transform(image, label)`` applied to every sample.
    """

    def __init__(self, filename, flag=1, transform=None):
        super(ImageRecordDataset, self).__init__(filename)
        self._transform = transform
        self._flag = flag

    def __getitem__(self, idx):
        raw = super(ImageRecordDataset, self).__getitem__(idx)
        header, img = recordio.unpack(raw)
        if self._transform is None:
            return image.imdecode(img, self._flag), header.label
        return self._transform(image.imdecode(img, self._flag), header.label)


class ImageFolderDataset(dataset.Dataset):
    """Dataset of images stored as ``root/<class folder>/<image file>``.

    Attributes set while scanning:
    ``synsets`` is the sorted list of class folder names, and ``items`` is a
    list of ``(file path, label)`` pairs where the label is the folder's
    position in ``synsets``.

    Parameters
    ----------
    root : str
        Top-level directory; ``~`` is expanded.
    flag : passed to ``image.imread``
    transform : callable or None
        Optional ``transform(image, label)`` applied to every sample.
    """

    def __init__(self, root, flag=1, transform=None):
        self._root = os.path.expanduser(root)
        self._flag = flag
        self._transform = transform
        self._exts = ['.jpg', '.jpeg', '.png']
        self._list_images(self._root)

    def _list_images(self, root):
        """Scan ``root`` and populate ``self.synsets`` and ``self.items``."""
        self.synsets = []
        self.items = []

        for folder in sorted(os.listdir(root)):
            folder_path = os.path.join(root, folder)
            if not os.path.isdir(folder_path):
                warnings.warn('Ignoring %s, which is not a directory.'%folder_path, stacklevel=3)
                continue
            # The label is the index this folder gets in `synsets`.
            label = len(self.synsets)
            self.synsets.append(folder)
            for entry in sorted(os.listdir(folder_path)):
                file_path = os.path.join(folder_path, entry)
                ext = os.path.splitext(file_path)[1]
                if ext.lower() in self._exts:
                    self.items.append((file_path, label))
                else:
                    warnings.warn('Ignoring %s of type %s. Only support %s'%(
                        file_path, ext, ', '.join(self._exts)))

    def __getitem__(self, idx):
        path, label = self.items[idx][0], self.items[idx][1]
        img = image.imread(path, self._flag)
        if self._transform is None:
            return img, label
        return self._transform(img, label)

    def __len__(self):
        return len(self.items)


NameError: name 'base' is not defined