In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, metrics, Sequential

In [5]:
(x_train, y_train), (x_val, y_val) = datasets.fashion_mnist.load_data()

In [49]:
# 处理数据，由于在高阶api中不能在处理x和y的数据格式
# 所以我们需要在这个函数中转换x的shape，以及对y进行one-hot
def process(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28*28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y

In [50]:
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))

In [51]:
train_db = train_db.map(process).shuffle(10000).batch(128)
val_db = val_db.map(process).shuffle(10000).batch(128)

In [52]:
network = Sequential([
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])

network.build(input_shape=(None, 28*28))
network.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_25 (Dense)             multiple                  401920    
_________________________________________________________________
dense_26 (Dense)             multiple                  131328    
_________________________________________________________________
dense_27 (Dense)             multiple                  32896     
_________________________________________________________________
dense_28 (Dense)             multiple                  8256      
_________________________________________________________________
dense_29 (Dense)             multiple                  650       
Total params: 575,050
Trainable params: 575,050
Non-trainable params: 0
_________________________________________________________________


In [53]:
# Sequential这个类集成自Model类，Model类都有compile方法
# 用于编排网络的一些属性
# optimazer用于指定网络的优化器
# loss用于指定损失函数
# metrics指定训练时要统计的指标有哪些
network.compile(optimizer=optimizers.Adam(lr=1e-2),
               loss=tf.losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])

In [55]:
# 之后通过fit函数，就可以对网络进行训练
# 参数有input，训练的epoch数，validation数据，以及每过多少个epoch进行一次validation的测试（validation_freq）
network.fit(train_db, epochs=10, validation_data=val_db, validation_freq=1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f770ba7d7f0>

In [58]:
# 也可以单独使用这个函数进行validation数据的测试
network.evaluate(val_db)



[0.387256896571268, 0.8685]

In [57]:
val_iter = iter(val_db)
sample = next(val_iter)

# 使用predict函数直接进行预测，没什么大用
pred = network.predict(sample[0])

pred = tf.argmax(pred, axis=1)

print('prediction:{}'.format(pred))

prediction:[0 4 6 3 1 7 7 3 9 1 6 5 0 9 1 1 1 3 2 1 8 8 7 0 0 6 3 1 2 9 8 9 8 0 6 2 3
 4 9 4 1 2 0 9 1 9 6 4 7 3 0 1 2 0 6 2 7 8 2 6 5 3 3 0 5 9 7 8 3 1 8 1 8 0
 7 5 8 9 6 5 7 5 2 4 6 7 0 9 6 4 0 7 3 3 7 1 9 9 0 5 7 7 5 5 3 5 0 6 0 5 2
 4 2 3 2 8 8 4 7 3 7 5 6 4 5 2 7 9]
