# 多分类问题

手写数字图像识别，可应用场景：打标签、手势识别等

## 1. 导入包，并打印版本信息

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib
from matplotlib import pyplot as plt

# Short module-level aliases for the tf.keras sub-modules used by later cells.
mnist, models, layers = tf.keras.datasets.mnist, tf.keras.models, tf.keras.layers
activations, optimizers, losses = tf.keras.activations, tf.keras.optimizers, tf.keras.losses

# Report library versions so a run can be reproduced later.
print(f'tensorflow: {tf.__version__}')
print(f'numpy: {np.__version__}')
print(f'matplotlib: {matplotlib.__version__}')


## 2. 准备样本数据
tf.keras 自带一些常用数据集，这里直接加载 MNIST 手写数字数据集（已划分为训练集和测试集）。

In [None]:
# load_data() yields two (images, labels) pairs: the training split, then the test split.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
del train_split, test_split  # drop the temporary tuples; only the unpacked arrays are used

## 3. 打印数据形状

In [None]:
# Inspect the array shapes of the training and test splits.
print(f'x_train.shape: {x_train.shape}')
print(f'y_train.shape: {y_train.shape}')

print(f'x_test.shape: {x_test.shape}')
print(f'y_test.shape: {y_test.shape}')

# Dump the raw training labels.
print(y_train)

## 4. 使用matplotlib显示图片数据

In [None]:
# Preview the first 25 training images in a 5x5 grid, captioned with index and label.
plt.figure(figsize=(10, 10))
for idx in range(25):
    image, label = x_train[idx], y_train[idx]
    plt.subplot(5, 5, idx + 1)
    # Hide tick marks; only the picture itself is of interest.
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image, cmap=plt.cm.binary)
    plt.grid(True)
    plt.xlabel(f'idx:{idx}, val:{label}')
plt.show()

## 5. 处理数据

像素灰度值为 0–255 的整数，除以 255.0 将其归一化为 [0, 1] 区间内的浮点数。

In [None]:
# Convert integer grey levels to floats in [0, 1]
# (assumes pixel values span 0..255 -- true for MNIST, confirm for other data).
x_train = x_train / 255.0
x_test = x_test / 255.0

## 6. 创建模型

In [None]:
 model = models.Sequential(
        [
            layers.Flatten(input_shape=(28,28)),
            layers.Dense(512, activation=activations.relu),
            layers.Dense(10, activation=activations.softmax)
        ])

## 7. 编译模型

In [None]:
# Sparse categorical cross-entropy takes the integer labels directly
# (no one-hot encoding), matching the 10-way softmax output layer.
model.compile(
    loss=losses.sparse_categorical_crossentropy,
    optimizer=optimizers.Adam(),
    metrics=['accuracy'],
)

## 8. 训练模型

In [None]:
# Train for 5 full passes over the normalized training set.
model.fit(x=x_train, y=y_train, epochs=5)

## 9. 使用测试集评估模型

In [None]:
# Report loss and accuracy on the held-out test split.
model.evaluate(x=x_test, y=y_test)

## 10. 使用模型预测结果

对测试集前100个样本做预测

In [None]:
# Use the first 100 test images (and their labels) as a small demo batch.
y_samples = y_test[:100]
x_samples = x_test[:100]

# NOTE: the misspelled name "predications" is kept because later cells read it.
predications = model.predict(x_samples)

## 11. 简单看一下预测结果

In [None]:
# Each row of the prediction array holds the 10 softmax scores for one sample.
print(f'shape: {predications.shape}')
print(f'predications[0]: {predications[0]}')

## 12. 使用图形显示100个预测结果（预测错误的以红色标注）

In [None]:
# Show the sample images in a 10x10 grid, captioned with predicted vs. actual label.
# Mispredictions are captioned in red.
plt.figure(figsize=(20, 20))

# Predicted class = index of the highest score in each row; computed once, vectorized,
# instead of calling np.argmax repeatedly per sample inside the loop.
predicted_labels = np.argmax(predications, axis=1)

# The 10x10 grid holds at most 100 images; also guard against a shorter sample set.
n_shown = min(100, len(x_samples))
for i in range(n_shown):
    plt.subplot(10, 10, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(x_samples[i], cmap=plt.cm.binary)
    plt.grid(True)
    predicted, actual = predicted_labels[i], y_samples[i]
    # Pass `color` only for wrong predictions, so correct ones keep the default
    # text colour (passing color=None explicitly is not a valid matplotlib colour).
    label_kwargs = {'color': 'red'} if predicted != actual else {}
    plt.xlabel(f'predict:{predicted}, actual:{actual}', **label_kwargs)
plt.show()