# [Tutorial 5] Creating an Estimator from a Keras Model

In this tutorial, we walk through the process of building a model with Keras and turning it into an Estimator.

TensorFlow Estimators are fully supported in TensorFlow, and can be created from new or existing `tf.keras` models.

In [15]:
from __future__ import absolute_import, division, print_function, unicode_literals

import warnings
warnings.filterwarnings(action='ignore')

import tensorflow as tf

import numpy as np
import tensorflow_datasets as tfds

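# Table of Contents
1. Building a simple Keras model
2. Creating an input function
3. Creating an Estimator from a tf.keras model

### 1. Building a simple Keras model

In Keras, you assemble layers to build a model. A model is usually a graph of layers, and the most common type of model is a stack of layers.
Here we build the model with `tf.keras.Sequential`.

To build a simple, fully connected network (that is, a multilayer perceptron), do the following: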

In [2]:
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(4,), name='dense_input'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
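To make the layer stack concrete, here is a minimal pure-Python sketch of the computation this network performs at inference time. The weights are random placeholders (an assumption for illustration, not the values Keras would initialize or learn), and the helper names `dense`, `relu`, `sigmoid`, `random_matrix`, and `forward` are hypothetical.

```python
import math
import random


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dense(inputs, weights, biases, activation):
    """One fully connected layer: activation(inputs . weights + biases)."""
    outputs = []
    for j in range(len(biases)):
        z = sum(x * weights[i][j] for i, x in enumerate(inputs)) + biases[j]
        outputs.append(activation(z))
    return outputs


def random_matrix(rows, cols, rng):
    # Placeholder weights; a trained model would supply learned values.
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
w1, b1 = random_matrix(4, 16, rng), [0.0] * 16   # Dense(16, relu)
w2, b2 = random_matrix(16, 1, rng), [0.0]        # Dense(1, sigmoid)


def forward(features):
    hidden = dense(features, w1, b1, relu)
    # Dropout only zeroes activations during training;
    # at inference it passes values through unchanged.
    return dense(hidden, w2, b2, sigmoid)


sample = [6.1, 2.8, 4.7, 1.2]  # one iris example with four features
prediction = forward(sample)
print(prediction)
```

The output is a single value between 0 and 1, which is what a one-unit sigmoid layer produces for each example.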

Compile the model and print a summary to check its structure.

In [3]:
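# Note: iris has three classes with integer labels (0, 1, 2), so a 3-unit output
# with `sparse_categorical_crossentropy` is the more conventional pairing than
# a single sigmoid unit with `categorical_crossentropy`.
model.compile(loss='categorical_crossentropy', optimizer='adam')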
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 16)                80        
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
Total params: 97
Trainable params: 97
Non-trainable params: 0
_________________________________________________________________
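The parameter counts in the summary can be checked by hand: a dense layer has one weight per input-unit pair plus one bias per unit. The short sketch below reproduces the arithmetic; `dense_params` is a hypothetical helper, not part of Keras.

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units


# (layer name, number of inputs, number of units), taken from the model above.
layers = [("dense", 4, 16), ("dense_1", 16, 1)]

counts = {name: dense_params(n_in, n_units) for name, n_in, n_units in layers}
total = sum(counts.values())

print(counts)  # {'dense': 80, 'dense_1': 17}
print(total)   # 97
```

The dropout layer contributes nothing because it has no weights of its own.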


### 2. Creating an input function

Use the `Datasets` API to scale to large datasets or multi-device training.

Estimators need control over when and how their input pipeline is built. To allow this, they require an **input function that you write yourself**, or `input_fn`. The `Estimator` calls this function with no arguments. The `input_fn` must return a `tf.data.Dataset`.

In [4]:
def input_fn():
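    # Load the iris training split as (features, label) pairs.
    split = tfds.Split.TRAIN
    dataset = tfds.load('iris', split=split, as_supervised=True)
    # Key the features by the name of the Keras input layer ('dense_input').
    dataset = dataset.map(lambda features, labels: ({'dense_input': features}, labels))
    # Batch first, then repeat indefinitely; the Estimator stops after `steps` batches.
    dataset = dataset.batch(32).repeat()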
    return dataset
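Because the pipeline batches before it repeats, each pass over the data ends with a smaller final batch. The sketch below imitates that ordering in plain Python, without TensorFlow, assuming the iris training split holds 150 examples; `batch_then_repeat` is a hypothetical stand-in, not a `tf.data` function.

```python
from itertools import islice


def batch_then_repeat(num_examples, batch_size):
    """Yield batch sizes in the order `batch(batch_size).repeat()` would."""
    while True:  # repeat(): start over after every full pass
        for start in range(0, num_examples, batch_size):
            # The last batch of a pass holds whatever examples remain.
            yield min(batch_size, num_examples - start)


# Batch sizes seen during ten steps, the count passed to train() later on.
sizes = list(islice(batch_then_repeat(150, 32), 10))
print(sizes)  # [32, 32, 32, 32, 22, 32, 32, 32, 32, 22]
```

Under that assumption, ten steps correspond to two full passes over the data.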

Let's test the `input_fn`.

In [12]:
for features_batch, labels_batch in input_fn().take(1):
    print(features_batch)
    print(labels_batch)

{'dense_input': <tf.Tensor: id=2165, shape=(32, 4), dtype=float32, numpy=
array([[6.1, 2.8, 4.7, 1.2],
       [5.7, 3.8, 1.7, 0.3],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.9, 4.5, 1.5],
       [6.8, 2.8, 4.8, 1.4],
       [5.4, 3.4, 1.5, 0.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.9, 3.1, 5.1, 2.3],
       [6.2, 2.2, 4.5, 1.5],
       [5.8, 2.7, 3.9, 1.2],
       [6.5, 3.2, 5.1, 2. ],
       [4.8, 3. , 1.4, 0.1],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.1, 3.8, 1.5, 0.3],
       [6.3, 3.3, 4.7, 1.6],
       [6.5, 3. , 5.8, 2.2],
       [5.6, 2.5, 3.9, 1.1],
       [5.7, 2.8, 4.5, 1.3],
       [6.4, 2.8, 5.6, 2.2],
       [4.7, 3.2, 1.6, 0.2],
       [6.1, 3. , 4.9, 1.8],
       [5. , 3.4, 1.6, 0.4],
       [6.4, 2.8, 5.6, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [6.7, 2.5, 5.8, 1.8],
       [6.8, 3.2, 5.9, 2.3],
       [4.8, 3. , 1.4, 0.3],
       [4.8, 3.1, 1.6, 0.2],
       [4.6, 3.6, 1. , 0.2],
       [5.7, 4.4, 1.5, 0.4]

### 3. Creating an Estimator from a tf.keras model

A `tf.keras.Model` can be trained with the `tf.estimator` API by converting it to a `tf.estimator.Estimator` with `tf.keras.estimator.model_to_estimator`.

In [16]:
import tempfile
# Write checkpoints to a fresh temporary directory.
model_dir = tempfile.mkdtemp()

# Alternatively, use a fixed directory:
# model_dir = "/tmp/tfkeras_example/"
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)

INFO:tensorflow:Using default config.


INFO:tensorflow:Using the Keras model provided.


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpsiv74_2e', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fa8b86f2c88>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
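A directory created with `tempfile.mkdtemp()` is not removed automatically, so checkpoints accumulate under `/tmp` across runs. The following is a minimal sketch of cleaning up afterwards using only the standard library; the file name `model.ckpt-0` is a placeholder standing in for whatever the Estimator writes.

```python
import os
import shutil
import tempfile

model_dir = tempfile.mkdtemp()
try:
    # Stand-in for training: write a placeholder file into the model directory.
    ckpt_path = os.path.join(model_dir, "model.ckpt-0")
    with open(ckpt_path, "w") as f:
        f.write("placeholder")
    created = os.path.exists(ckpt_path)
finally:
    # Remove the directory and everything in it once it is no longer needed.
    shutil.rmtree(model_dir)

removed = not os.path.exists(model_dir)
print(created, removed)
```

Keep the directory instead if you want to resume training or inspect the checkpoints later.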


Train and evaluate the Estimator.

In [17]:
keras_estimator.train(input_fn=input_fn, steps=10)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpsiv74_2e/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpsiv74_2e/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpsiv74_2e/model.ckpt.
INFO:tensorflow:loss = 136.93942, step = 0
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpsiv74_2e/model.ckpt.
INFO:tensorflow:Loss for final step: 105.213394.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-15T02:12:04Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpsiv74_2e/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Finished evaluation at 2020-09-15-02:12:04
INFO:tensorflow:Saving dict for global step 10: global_step = 10, loss = 113.93606
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpsiv74_2e/model.ckpt-10


Eval result: {'loss': 113.93606, 'global_step': 10}


# Copyright 2019 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.