In [1]:
import tensorflow as tf
from tensorflow import keras

In [2]:
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()

### MNIST

In [3]:
type(x), type(y), x.shape, y.shape

(numpy.ndarray, numpy.ndarray, (60000, 28, 28), (60000,))

In [4]:
x.min(), x.max(), x.mean()

(0, 255, 33.318421449829934)

In [5]:
y.min(), y.max()

(0, 9)

In [6]:
x_test.shape, y_test.shape

((10000, 28, 28), (10000,))

In [7]:
y[:4]

array([5, 0, 4, 1], dtype=uint8)

In [8]:
# One-hot encode the integer class labels into 10 classes; the result is a
# float32 tensor with one extra trailing axis of length 10.
y_onehot = tf.one_hot(y, depth=10)
y_onehot[:2]

<tf.Tensor: shape=(2, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

### CIFAR10/100

In [9]:
(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()
# If this fails with: CERTIFICATE_VERIFY_FAILED
# Fix:
# before starting jupyter notebook, run this command:
# /Applications/Python\ 3.6/Install\ Certificates.command

In [10]:
x.shape, y.shape, x_test.shape, y_test.shape

((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))

In [14]:
x.min(), x.max(), y.min(), y.max()

(0, 255, 0, 9)

In [15]:
y[:4]

array([[6],
       [9],
       [9],
       [4]], dtype=uint8)

### tf.data.Dataset
- from_tensor_slices()

In [16]:
(x,y),(x_test,y_test) = keras.datasets.cifar10.load_data()

In [17]:
db = tf.data.Dataset.from_tensor_slices(x_test)

In [18]:
next(iter(db)).shape

TensorShape([32, 32, 3])

In [19]:
db = tf.data.Dataset.from_tensor_slices((x_test, y_test))

In [20]:
next(iter(db))[0].shape

TensorShape([32, 32, 3])

### .shuffle
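随机打乱：`shuffle(buffer_size)` randomly shuffles the elements of the dataset; a buffer at least as large as the dataset gives a full uniform shuffle.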

In [21]:
db = tf.data.Dataset.from_tensor_slices((x_test, y_test))

In [22]:
# Buffer size 10000 equals the number of test samples in x_test, so the
# shuffle draws from the whole dataset rather than a sliding window.
db = db.shuffle(10000)

### .map

In [23]:
def preprocess(x, y):
    """Convert one (image, label) pair into model-ready tensors.

    Args:
        x: Image tensor; cast to float32 and divided by 255.
        y: Integer class label; cast to int32 and one-hot encoded with 10 classes.

    Returns:
        Tuple ``(image, label)`` of float32 tensors.

    Note:
        ``tf.one_hot`` appends a new trailing axis of length 10, so a label
        of shape ``(1,)`` (as CIFAR-10 labels are after slicing) comes out as
        ``(1, 10)``. Apply ``tf.squeeze`` to the label beforehand if a flat
        ``(10,)`` vector is wanted.
    """
    image = tf.cast(x, dtype=tf.float32) / 255
    label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return image, label

In [24]:
db2 = db.map(preprocess)

In [25]:
res = next(iter(db2))

In [26]:
res[0].shape, res[1].shape

(TensorShape([32, 32, 3]), TensorShape([1, 10]))

In [27]:
res[1][:2]

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)>

### .batch

In [31]:
db3 = db2.batch(32)

In [32]:
res = next(iter(db3))

In [33]:
res[0].shape, res[1].shape

(TensorShape([32, 32, 32, 3]), TensorShape([32, 1, 10]))

### StopIteration

In [34]:
db_iter = iter(db3)

In [35]:
# Deliberately drain the iterator: once every batch of db3 has been consumed,
# next() raises StopIteration, which is what this cell demonstrates.
while True:
    next(db_iter)

StopIteration: 

### .repeat()

In [36]:
# With no count argument, repeat() cycles through the dataset indefinitely,
# so iterating db4 never raises StopIteration.
db4 = db3.repeat()

In [37]:
# repeat(2) yields the batches of db3 twice in total and then stops.
db4 = db3.repeat(2)

### For example

In [39]:
def prepare_mnist_features_and_labels(x, y):
    """Cast one (features, labels) pair to the dtypes used for training.

    Args:
        x: Feature tensor; cast to float32 and divided by 255.0.
        y: Label tensor; cast to int64.

    Returns:
        Tuple ``(features, labels)`` with the converted tensors.
    """
    features = tf.cast(x, tf.float32)
    labels = tf.cast(y, tf.int64)
    return features / 255.0, labels

def mnist_dataset(batch_size=100):
    """Build shuffled, batched training and validation pipelines for Fashion-MNIST.

    Labels are one-hot encoded (10 classes) before slicing. Every element is
    then passed through ``prepare_mnist_features_and_labels``, and each
    pipeline is shuffled and batched.

    Args:
        batch_size: Number of examples per batch. Defaults to 100, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(ds, ds_val)`` of ``tf.data.Dataset`` objects for the
        training and validation splits.
    """
    # Reach the loader through `keras.datasets`: the notebook's imports only
    # bring `tf` and `keras` into scope, so a bare `datasets` is a NameError.
    (x, y), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()
    y = tf.one_hot(y, depth=10)
    y_val = tf.one_hot(y_val, depth=10)

    # Shuffle buffers of 60000 / 10000 correspond to the documented sizes of
    # the Fashion-MNIST training / test splits.
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(prepare_mnist_features_and_labels)
    ds = ds.shuffle(60000).batch(batch_size)
    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(10000).batch(batch_size)
    return ds, ds_val