In [15]:
import tensorflow as tf
import pandas as pd
from IPython.display import clear_output

In [3]:
train_path = tf.keras.utils.get_file(
    "iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
    "iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")

In [4]:
columns = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
labels = ['Setosa', 'Versicolor', 'Virginica']
train = pd.read_csv(train_path, header=0, names=columns)
test = pd.read_csv(test_path, header=0, names=columns)

In [5]:
train

Unnamed: 0,SepalLength,SepalWidth,PetalLength,PetalWidth,Species
0,6.4,2.8,5.6,2.2,2
1,5.0,2.3,3.3,1.0,1
2,4.9,2.5,4.5,1.7,2
3,4.9,3.1,1.5,0.1,0
4,5.7,3.8,1.7,0.3,0
...,...,...,...,...,...
115,5.5,2.6,4.4,1.2,1
116,5.7,3.0,4.2,1.2,1
117,4.4,2.9,1.4,0.2,0
118,4.8,3.0,1.4,0.1,0


In [6]:
train_y = train.pop('Species')
test_y = test.pop('Species')

In [10]:
feature = []  # 特征列
for col in train.columns:
    feature.append(tf.feature_column.numeric_column(col))
feature

[NumericColumn(key='SepalLength', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),
 NumericColumn(key='SepalWidth', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),
 NumericColumn(key='PetalLength', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),
 NumericColumn(key='PetalWidth', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None)]

In [18]:
def input_fn(x, y=None, training=True, batch_size=256):
    ds = tf.data.Dataset.from_tensor_slices((dict(x), y))
    if training:
        ds = ds.shuffle(1000).repeat()  # 用steps控制终止
    return ds.batch(batch_size)

In [23]:
classifier = tf.estimator.DNNClassifier([64, 32], feature, n_classes=3)
clear_output()

In [24]:
classifier.train(lambda: input_fn(train, train_y), steps=5000)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into C:\Users\yy\AppData\Local\Temp\tmp09muw3er\model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.1503642, step = 0
INFO:tensorflow:global_step/sec: 280.602
INFO:tensorflow:loss = 0.90847456, step = 100 (0.357 sec)
INFO:tensorflow:global_step/sec: 437.701
INFO:tensorflow:loss = 0.8047812, step = 200 (0.228 sec)
INFO:tensorflow:global_step/sec: 436.197
INFO:tensorflow:loss = 0.74282455, step = 300 (0.229 sec)
INFO:tensorflow:global_step/sec: 431.409
INFO:tensorflow:loss = 0.69320524, step = 400 (0.232 sec)
INFO:tensorflow:global_step/sec: 434.77
INFO:tensorflow:loss = 0.6603484, s

<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x16d6a722250>

In [27]:
classifier.evaluate(lambda: input_fn(test, test_y, training=False))

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-07-06T16:14:14
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from C:\Users\yy\AppData\Local\Temp\tmp09muw3er\model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.25187s
INFO:tensorflow:Finished evaluation at 2022-07-06-16:14:14
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.96666664, average_loss = 0.25931346, global_step = 5000, loss = 0.25931346
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: C:\Users\yy\AppData\Local\Temp\tmp09muw3er\model.ckpt-5000


{'accuracy': 0.96666664,
 'average_loss': 0.25931346,
 'loss': 0.25931346,
 'global_step': 5000}

In [32]:
# 自定义例子输入
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}
result = list(classifier.predict(lambda: input_fn(predict_x, training=False)))
clear_output()
result[0]

{'logits': array([ 5.2090673,  2.749156 , -0.8937769], dtype=float32),
 'probabilities': array([0.9193889 , 0.07855491, 0.00205621], dtype=float32),
 'class_ids': array([0], dtype=int64),
 'classes': array([b'0'], dtype=object),
 'all_class_ids': array([0, 1, 2]),
 'all_classes': array([b'0', b'1', b'2'], dtype=object)}

In [34]:
def print_def(dic):  # 自定义预测显示函数
    class_id = dic['class_ids'][0]
    probability = dic['probabilities'][class_id]
    print('预测为 "{}" ({:.2f}%)'.format(labels[class_id], probability))
for dic in result:
    print_def(dic)  # 与expected完全符合

预测为 "Setosa" (0.92%)
预测为 "Versicolor" (0.71%)
预测为 "Virginica" (0.72%)
