In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

# demo

In [2]:
ds = tf.data.Dataset.range(10)
ds = ds.shuffle(10)
ds = ds.batch(3) # shuffle first, then batch
for d in ds:
    print(d)

Metal device set to: Apple M1 Pro
tf.Tensor([5 0 7], shape=(3,), dtype=int64)
tf.Tensor([9 2 6], shape=(3,), dtype=int64)
tf.Tensor([8 3 1], shape=(3,), dtype=int64)
tf.Tensor([4], shape=(1,), dtype=int64)


2022-06-05 14:47:16.148904: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-06-05 14:47:16.149051: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


# load datasets

In [3]:
def load_dataset(dataset_name, batch_size, transform_fn=None, data_dir='./tensorflow_datasets/'):
    ds_train, ds_info = tfds.load(dataset_name, split='train', shuffle_files=True,
                                  data_dir=data_dir, as_supervised=True, with_info=True)
    ds_test = tfds.load(dataset_name, split='test', shuffle_files=True,
                        data_dir=data_dir, as_supervised=True)
    if transform_fn is not None: # transform data element-wise
        ds_train = ds_train.map(transform_fn)
        ds_test = ds_test.map(transform_fn)
    train_full_size = ds_info.splits['train'].num_examples
    test_full_size = ds_info.splits['test'].num_examples
    ds_train = ds_train.shuffle(train_full_size)
    ds_test = ds_test.shuffle(test_full_size)
    ds_train = ds_train.batch(batch_size)
    ds_test = ds_test.batch(batch_size)
    return ds_train, ds_test  # return tf.data.Dataset as iterable

In [4]:
ds_train, ds_test = load_dataset('mnist', 128, transform_fn=None)

In [5]:
for imgs, labels in ds_train:
    print(tf.shape(imgs))
    print(tf.shape(labels))
    print(labels)
    break

2022-06-05 14:47:16.237537: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


tf.Tensor([128  28  28   1], shape=(4,), dtype=int32)
tf.Tensor([128], shape=(1,), dtype=int32)
tf.Tensor(
[4 2 5 0 9 8 9 2 8 1 7 7 3 1 2 0 2 5 6 1 6 0 7 9 1 0 0 7 2 1 5 5 8 8 9 1 7
 8 0 3 2 6 6 0 0 6 9 1 0 0 8 0 5 8 6 2 8 5 5 7 8 4 4 8 4 0 5 3 3 8 2 1 3 6
 3 8 1 3 2 6 0 7 0 2 0 3 9 6 3 4 0 8 2 0 7 1 0 1 3 3 7 2 6 9 1 8 9 6 3 8 9
 1 2 1 4 3 4 6 4 2 5 8 0 8 2 9 1 2], shape=(128,), dtype=int64)


In [6]:
def normalize_MNIST(img, label):
    img = tf.cast(tf.reshape(img, [28 * 28]), tf.float32) / 255. - 0.5
    label = tf.cast(tf.reshape(label, [1]), tf.float32)
    return img, label

In [7]:
ds_train, ds_test = load_dataset('mnist', 128, transform_fn=normalize_MNIST)

In [8]:
for imgs, labels in ds_train:
    print(tf.shape(imgs))
    print(tf.shape(labels))
    print(labels)
    break

tf.Tensor([128 784], shape=(2,), dtype=int32)
tf.Tensor([128   1], shape=(2,), dtype=int32)
tf.Tensor(
[[2.]
 [2.]
 [0.]
 [1.]
 [1.]
 [2.]
 [4.]
 [4.]
 [8.]
 [6.]
 [3.]
 [4.]
 [6.]
 [6.]
 [4.]
 [7.]
 [2.]
 [6.]
 [2.]
 [0.]
 [3.]
 [0.]
 [1.]
 [3.]
 [6.]
 [2.]
 [5.]
 [9.]
 [8.]
 [6.]
 [8.]
 [3.]
 [8.]
 [7.]
 [9.]
 [0.]
 [6.]
 [0.]
 [3.]
 [2.]
 [8.]
 [8.]
 [2.]
 [2.]
 [5.]
 [1.]
 [4.]
 [5.]
 [7.]
 [8.]
 [4.]
 [2.]
 [0.]
 [7.]
 [3.]
 [5.]
 [6.]
 [3.]
 [9.]
 [9.]
 [3.]
 [8.]
 [9.]
 [1.]
 [1.]
 [7.]
 [9.]
 [5.]
 [3.]
 [2.]
 [9.]
 [8.]
 [5.]
 [8.]
 [0.]
 [2.]
 [8.]
 [9.]
 [5.]
 [0.]
 [0.]
 [2.]
 [9.]
 [2.]
 [4.]
 [9.]
 [2.]
 [6.]
 [1.]
 [4.]
 [3.]
 [7.]
 [8.]
 [1.]
 [7.]
 [1.]
 [3.]
 [6.]
 [9.]
 [9.]
 [6.]
 [8.]
 [5.]
 [8.]
 [2.]
 [8.]
 [0.]
 [6.]
 [2.]
 [5.]
 [4.]
 [9.]
 [8.]
 [0.]
 [6.]
 [8.]
 [4.]
 [1.]
 [3.]
 [2.]
 [1.]
 [9.]
 [4.]
 [7.]
 [0.]
 [8.]
 [3.]
 [6.]], shape=(128, 1), dtype=float32)
