In [3]:
import mxnet as mx
from mxnet import gluon, nd
import os

# Root folder (relative to the notebook's working directory) under which the
# sample .ndarray files are generated, one sub-folder per class label.
dirs = '../../tmp/NDArrayFileDataset'

def mkdir_if_not_exist(path):
    """Create a directory, including missing parents, unless it already exists.

    Parameters
    ----------
    path : list or tuple of str
        Path components; they are combined with ``os.path.join``.

    Raises
    ------
    OSError
        If the directory cannot be created, or if the target already exists
        but is not a directory.
    """
    import errno  # local import: the notebook's import cell only brings in ``os``

    target = os.path.join(*path)
    try:
        # Create unconditionally instead of "check, then create": the separate
        # existence check could race with another process creating the folder.
        os.makedirs(target)
    except OSError as err:
        # An already existing directory is the success case; anything else
        # (permissions, a regular file sitting at ``target``, ...) is an error.
        if err.errno != errno.EEXIST or not os.path.isdir(target):
            raise


## 1. Create NDArray data files

In [2]:

# Make the root folder and one sub-folder per class.
mkdir_if_not_exist([dirs])
for class_dir in ('label0', 'label1'):
    mkdir_if_not_exist([dirs, class_dir])

# Write ten random (3, 32, 32) arrays per class. Element [0, 0, 0] is
# overwritten with a unique id (0-9 for label0, 10-19 for label1) so that
# individual samples can be recognised after loading.
for prefix, first_id in (('/label0/data_', 0), ('/label1/data_', 10)):
    for i in range(10):
        data = nd.random.uniform(shape=(3, 32, 32))
        data[0, 0, 0] = i + first_id
        nd.save(dirs + prefix + str(i) + ".ndarray", data)

In [3]:
# nd.load returns a container of arrays; the file holds a single array, so take element 0.
data = nd.load(dirs + '/label0/data_0.ndarray')[0]
# Single-argument print(...) emits the same text as the old Python 2
# statement form and is also valid on Python 3.
print('%s %s' % (data.shape, data[0, 0, 0]))

(3L, 32L, 32L) 
[ 0.]
<NDArray 1 @cpu(0)>


## 2. Create a dataset and use a data loader

In [4]:
import warnings
from mxnet import gluon, nd
class NDArrayFolderDataset(gluon.data.dataset.Dataset):
    """A dataset for loading ndarray files stored in a folder structure like::

        root/car/0001.ndarray
        root/car/xxxa.ndarray
        root/car/yyyb.ndarray
        root/bus/123.ndarray
        root/bus/023.ndarray
        root/bus/wwww.ndarray

    Each sub-folder of ``root`` is one class. Sub-folders are visited in
    sorted order and numbered from 0. Entries whose name starts with ``.``
    are skipped, so hidden folders never become classes.

    Parameters
    ----------
    root : str
        Path to root directory. A leading ``~`` is expanded.
    transform : callable, default None
        A function that takes data and label and transforms them:
    ::

        transform = lambda data, label: (data.astype('float32')/255, label)

    Attributes
    ----------
    synsets : list
        List of class names. `synsets[i]` is the name for the integer label `i`
    items : list of tuples
        List of all ndarray files in (filename, label) pairs.
    """
    def __init__(self, root, transform=None):
        self._root = os.path.expanduser(root)
        self._transform = transform
        self._exts = ['.ndarray']
        self._list_images(self._root)

    def _list_images(self, root):
        """Scan ``root`` and (re)build ``self.synsets`` and ``self.items``."""
        self.synsets = []
        self.items = []

        for folder in sorted(os.listdir(root)):
            # Hidden entries (e.g. ``.ipynb_checkpoints``) are not classes.
            # Without this guard such a folder sorts first and shifts every
            # real label by one.
            if folder.startswith('.'):
                continue
            path = os.path.join(root, folder)
            if not os.path.isdir(path):
                warnings.warn('Ignoring %s, which is not a directory.'%path, stacklevel=3)
                continue
            label = len(self.synsets)
            self.synsets.append(folder)
            for filename in sorted(os.listdir(path)):
                filename = os.path.join(path, filename)
                ext = os.path.splitext(filename)[1]
                if ext.lower() not in self._exts:
                    warnings.warn('Ignoring %s of type %s. Only support %s'%(
                        filename, ext, ', '.join(self._exts)))
                    continue
                self.items.append((filename, label))

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(data, label)``.

        When a transform was given, its return value is returned instead.
        """
        filename, label = self.items[idx]
        # Only the first array stored in the file is used.
        data = nd.load(filename)[0]
        if self._transform is not None:
            return self._transform(data, label)
        return data, label

    def __len__(self):
        """Number of samples found under the root folder."""
        return len(self.items)

In [5]:
# Wrap the generated folder tree in a dataset (no transform) and iterate it
# in shuffled mini-batches of four samples.
ndarrayds = NDArrayFolderDataset(dirs, transform=None)
train_data = gluon.data.DataLoader(
    ndarrayds, batch_size=4, shuffle=True)

In [6]:
# Show, for every batch, the id stored at [0, 0, 0] of each sample next to
# its label: ids 0-9 were written for label0 and ids 10-19 for label1.
for data, label in train_data:
    # Single-argument print(...) emits the same text as the old Python 2
    # statement form and is also valid on Python 3.
    print('%s %s' % (data[:, 0, 0, 0], label))


[ 14.   8.   6.  11.]
<NDArray 4 @cpu(0)> 
[1 0 0 1]
<NDArray 4 @cpu(0)>

[ 15.  18.   0.   1.]
<NDArray 4 @cpu(0)> 
[1 1 0 0]
<NDArray 4 @cpu(0)>

[  7.  19.  16.   2.]
<NDArray 4 @cpu(0)> 
[0 1 1 0]
<NDArray 4 @cpu(0)>

[ 10.   4.   3.  12.]
<NDArray 4 @cpu(0)> 
[1 0 0 1]
<NDArray 4 @cpu(0)>

[  5.   9.  17.  13.]
<NDArray 4 @cpu(0)> 
[0 0 1 1]
<NDArray 4 @cpu(0)>


In [7]:
import warnings
from mxnet import gluon, nd, image
class MultiFolderDataset(gluon.data.dataset.Dataset):
    """A dataset for loading ndarray files or image files stored in a folder structure like::

        roots[0]/car/0001.ndarray
        roots[0]/car/xxxa.ndarray
        roots[0]/car/yyyb.ndarray
        roots[0]/bus/123.ndarray
        roots[0]/bus/023.ndarray
        roots[0]/bus/wwww.ndarray

        roots[1]/car/0001.ndarray
        roots[1]/car/xxxa.ndarray
        roots[1]/car/yyyb.ndarray
        roots[1]/bus/123.ndarray
        roots[1]/bus/023.ndarray
        roots[1]/bus/wwww.ndarray

    Class folders with the same name under different roots share one integer
    label. Labels are assigned in the order in which a folder name is first
    seen (roots in the given order, folders sorted within each root).
    Entries whose name starts with ``.`` are skipped.

    Parameters
    ----------
    roots : list of str
        Paths to the root directories. A leading ``~`` is expanded.
    flag : int, default 1
        Passed unchanged as the second argument to ``mxnet.image.imread``
        when an image file is loaded; see that function for its meaning.
        Not used for ``.ndarray`` files.
    transform : callable, default None
        A function that takes data and label and transforms them:
    ::

        transform = lambda data, label: (data.astype('float32')/255, label)

    Attributes
    ----------
    synsets : list
        List of class names. `synsets[i]` is the name for the integer label `i`
    items : list of tuples
        List of all samples in (filename, label) pairs.
    """
    def __init__(self, roots, flag=1, transform=None):
        self._roots = [os.path.expanduser(root) for root in roots]
        self._flag = flag
        self._transform = transform
        self._exts = ['.ndarray', '.jpeg', '.jpg', '.png']
        # Maps a class folder name to its integer label; shared by all roots.
        self._label_dict = {}
        self.synsets = []
        self.items = []
        for root in self._roots:
            self._list_images(root)

    def _list_images(self, root):
        """Append the samples found under ``root`` to ``self.items``."""
        for folder in sorted(os.listdir(root)):
            # Skip hidden entries so they never become classes.
            if folder.startswith('.'):
                continue
            path = os.path.join(root, folder)
            if not os.path.isdir(path):
                warnings.warn('Ignoring %s, which is not a directory.'%path, stacklevel=3)
                continue

            # ``dict.has_key`` no longer exists on Python 3; ``in`` works on
            # both Python 2 and Python 3.
            if folder not in self._label_dict:
                self._label_dict[folder] = len(self.synsets)
                self.synsets.append(folder)
            label = self._label_dict[folder]

            for filename in sorted(os.listdir(path)):
                filename = os.path.join(path, filename)
                ext = os.path.splitext(filename)[1]
                if ext.lower() not in self._exts:
                    warnings.warn('Ignoring %s of type %s. Only support %s'%(
                        filename, ext, ', '.join(self._exts)))
                    continue
                self.items.append((filename, label))

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(data, label)``.

        ``.ndarray`` files are read with ``nd.load`` (first stored array
        only); every other supported extension is read with
        ``image.imread``. When a transform was given, its return value is
        returned instead.
        """
        filename, label = self.items[idx]
        if os.path.splitext(filename)[1].lower() == '.ndarray':
            data = nd.load(filename)[0]
        else:
            data = image.imread(filename, self._flag)
        if self._transform is not None:
            return self._transform(data, label)
        return data, label

    def __len__(self):
        """Total number of samples collected from all roots."""
        return len(self.items)

In [8]:
# Second data root: an image folder tree (presumably the Kaggle CIFAR-10
# train/valid split, judging by the path). The location is machine-specific,
# so it can be overridden through the CIFAR10_TRAIN_VALID_DIR environment
# variable; the default is the original hard-coded path.
cifar10_root = os.environ.get(
    'CIFAR10_TRAIN_VALID_DIR',
    '/home/hui/dataset/CIFAR10_kaggle/train_valid_test/train_valid')
multifolderds = MultiFolderDataset([dirs, cifar10_root], transform=None)
train_data = gluon.data.DataLoader(multifolderds, batch_size=4, shuffle=True)

In [9]:
# Number of batches the loader yields; print(...) with one argument behaves
# the same on Python 2 and Python 3.
print(len(train_data))

12505
