In [7]:
import numpy as np
import tensorflow as tf

import iris_data

## 采用预创建的Estimator

1. 创建输入函数
2. 定义特征列
3. 实例化Estimator
4. 训练、评估和预测

In [5]:
def input_evaluation_set():
    """Return a tiny hand-written evaluation set of two iris examples.

    Requires ``numpy`` imported as ``np`` in the notebook's import cell.

    Returns:
        A ``(features, labels)`` tuple:
        - ``features``: dict mapping each of the four measurement names
          ('SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth') to a
          length-2 ``np.ndarray`` of floats, one entry per example.
        - ``labels``: length-2 ``np.ndarray`` of integer class ids aligned
          with the feature arrays.
    """
    features = {
        'SepalLength': np.array([6.4, 5.0]),
        'SepalWidth': np.array([2.8, 2.3]),
        'PetalLength': np.array([5.6, 3.3]),
        'PetalWidth': np.array([2.2, 1.0]),
    }
    labels = np.array([2, 1])
    return features, labels

In [3]:
def train_input_fn(features, labels, batch_size):
    """Build the input pipeline used for training.

    Args:
        features: mapping-like object of feature name -> column values;
            converted with ``dict()`` before slicing.
        labels: label values aligned row-for-row with ``features``.
        batch_size: number of examples per emitted batch.

    Returns:
        A ``tf.data.Dataset`` that shuffles examples with a 1000-element
        buffer, repeats indefinitely, and yields batches of ``batch_size``.
    """
    example_slices = (dict(features), labels)
    pipeline = tf.data.Dataset.from_tensor_slices(example_slices)

    # repeat() with no count makes the stream endless, so the caller must
    # bound training itself (e.g. via the steps argument to train()).
    pipeline = pipeline.shuffle(1000)
    pipeline = pipeline.repeat()
    return pipeline.batch(batch_size)

## 定义特征列

特征列是一个对象，用于说明模型应该如何使用特征字典中的原始输入数据。在构建 Estimator 模型时，您会向其传递一个特征列的列表，其中包含您希望模型使用的每个特征。`tf.feature_column` 模块提供了许多用于在模型中表示数据的选项。

对于鸢尾花问题，4 个原始特征都是数值，因此我们会构建一个特征列的列表，告知 Estimator 模型将这 4 个特征都表示为 32 位浮点值。下面的代码先加载训练集和测试集，再为每个特征创建一个数值特征列：

In [8]:
# Load the Iris train/test splits, with features and labels kept separate.
(train_x, train_y), (test_x, test_y) = iris_data.load_data()

# One numeric feature column per input column: this list tells the
# Estimator how to read each raw feature.
my_feature_columns = [
    tf.feature_column.numeric_column(key=column_name)
    for column_name in train_x.keys()
]

In [10]:
# Show the feature names the feature columns above were built from.
train_x.columns

Index(['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'], dtype='object')

## 实例化Estimator

鸢尾花问题是一个经典的分类问题。幸运的是，TensorFlow 提供了几个预创建的分类器 Estimator，其中包括：

- `tf.estimator.DNNClassifier`：适用于执行多类别分类的深度模型。
- `tf.estimator.DNNLinearCombinedClassifier`：适用于宽度与深度（Wide & Deep）组合模型。
- `tf.estimator.LinearClassifier`：适用于基于线性模型的分类器。

对于鸢尾花问题，`tf.estimator.DNNClassifier` 似乎是最好的选择。我们将如下所示地实例化此 Estimator：

In [12]:
# Canned DNN classifier: two hidden layers of 10 units each, choosing among
# 3 classes. Checkpoints and summaries are written under models/iris.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
)

In [13]:
# Train on shuffled, endlessly repeated batches of 100 examples.
# Use max_steps (an absolute global-step target) rather than steps (an
# increment). model_dir keeps checkpoints between runs, so with steps=200
# every re-run, including Restart & Run All, would train 200 *more* steps.
# With max_steps the cell is idempotent: once the checkpoint reaches step
# 200, training is skipped.
classifier.train(
    input_fn=lambda: train_input_fn(train_x, train_y, batch_size=100),
    max_steps=200,
)

W0911 21:09:08.057539 139799877797696 deprecation.py:323] From /home/tianqin/.conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
W0911 21:09:08.087918 139799877797696 deprecation.py:506] From /home/tianqin/.conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0911 21:09:08.756662 139799877797696 deprecation.py:323] From /home/tianqin/.conda/envs/tensorflow/lib/python3.7/site-packages/

<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f256609d1d0>

In [14]:
# Directory where this classifier's checkpoints live (set via model_dir above).
classifier.model_dir

'models/iris'

In [15]:
def eval_input_fn(features, labels, batch_size):
    """Build an input pipeline for evaluation or prediction.

    Args:
        features: mapping-like object of feature name -> column values;
            converted with ``dict()`` before slicing.
        labels: label values aligned with ``features``, or ``None`` for
            prediction, in which case only the features are sliced.
        batch_size: number of examples per batch; must not be ``None``.

    Returns:
        A batched ``tf.data.Dataset``. Unlike the training pipeline it is
        neither shuffled nor repeated, so each example is visited once,
        in its original order.

    Raises:
        AssertionError: if ``batch_size`` is ``None`` (when assertions are
            enabled).
    """
    feature_dict = dict(features)
    # Prediction has no labels, so the dataset then carries features only.
    sliced_inputs = feature_dict if labels is None else (feature_dict, labels)

    pipeline = tf.data.Dataset.from_tensor_slices(sliced_inputs)

    assert batch_size is not None, "batch_size must not be None"
    return pipeline.batch(batch_size)

In [16]:
# Evaluate the trained model on the test split. The eval pipeline is not
# repeated, so evaluation stops once every test example has been seen.
eval_result = classifier.evaluate(
    input_fn=lambda: eval_input_fn(test_x, test_y, batch_size=100))

print('Test set accuracy: {accuracy:0.3f}'.format(**eval_result))

W0911 21:16:15.029642 139799877797696 deprecation.py:323] From /home/tianqin/.conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


Tst set accuracy: 0.967


In [17]:
# Full metrics dict returned by classifier.evaluate (accuracy, losses, global_step).
eval_result

{'accuracy': 0.96666664,
 'average_loss': 0.09045978,
 'loss': 2.7137933,
 'global_step': 200}

In [38]:
# Generate predictions for three hand-picked flowers, one per species.
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

# Use the eval_input_fn defined in this notebook (labels=None -> features
# only) instead of the copy in iris_data, so prediction goes through the same
# pipeline that evaluation used.
# classifier.predict returns a one-shot generator. Materialise it so the
# reporting cell below can be re-run without re-running this one.
predictions = list(classifier.predict(
    input_fn=lambda: eval_input_fn(predict_x, labels=None, batch_size=5)))

In [39]:
template = '\nPrediction is "{}" ({:.1f}%), expected "{}"'

# For each example, report the top predicted species, its probability as a
# percentage, and the species we expected.
for pred_dict, expected_species in zip(predictions, expected):
    top_class = pred_dict['class_ids'][0]
    top_probability = pred_dict['probabilities'][top_class]
    predicted_species = iris_data.SPECIES[top_class]
    print(template.format(predicted_species, 100 * top_probability, expected_species))


Prediction is "Setosa" (98.5%), expected "Setosa"

Prediction is "Versicolor" (95.2%), expected "Versicolor"

Prediction is "Virginica" (91.5%), expected "Virginica"
