# Keras Model to tf.Estimator conversion

This notebook demonstrates how to convert a Keras model to an Estimator. This is handy when we want to build the model with Keras's simple, clean API and then convert it to an Estimator to take advantage of distributed training that scales across multiple GPUs or TPUs.

# Step 1: Load and Prepare the Data

In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import tensorboard as tb
import matplotlib.pyplot as plt

In [2]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

In [3]:
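# One-hot encode the integer labels (0-9) to match categorical_crossentropy
train_labels_cat = tf.keras.utils.to_categorical(train_labels)
test_labels_cat = tf.keras.utils.to_categorical(test_labels)

# Scale pixel values from [0, 255] to [0, 1]
train_images_scaled = train_images / 255.0
test_images_scaled = test_images / 255.0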
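`to_categorical` turns each integer class ID into a one-hot row, and dividing by 255 maps raw `uint8` pixels into [0, 1]. A minimal NumPy-only sketch of both steps, assuming the labels are integers 0-9 as in Fashion-MNIST; the helper names `one_hot` and `scale_pixels` are hypothetical, and the `float32` output dtype is a choice made here (the cell above produces `float64` images).

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Map integer class IDs to one-hot float32 rows (like to_categorical)."""
    labels = np.asarray(labels, dtype=np.int64)
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes, dtype=np.float32)[labels]

def scale_pixels(images):
    """Map uint8 pixel values in [0, 255] to floats in [0, 1]."""
    return np.asarray(images, dtype=np.float32) / 255.0

labels = np.array([9, 0, 3])
print(one_hot(labels))

images = np.array([[0, 128, 255]], dtype=np.uint8)
print(scale_pixels(images))
```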

In [4]:
print(tf.__version__)

1.14.0


# Step 2: Define Keras Model

In [5]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))        # 28x28 image -> 784-vector
model.add(tf.keras.layers.Dense(100, activation=tf.nn.relu))    # hidden layer
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))  # one probability per class
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 100)               78500     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1010      
=================================================================
Total params: 79,510
Trainable params: 79,510
Non-trainable params: 0
_________________________________________________________________
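The parameter counts in the summary follow from the layer shapes: each Dense layer holds an inputs-by-units weight matrix plus one bias per unit, so 784 × 100 + 100 = 78,500 and 100 × 10 + 10 = 1,010. A minimal NumPy sketch of the same Flatten, Dense(relu), Dense(softmax) forward pass, assuming randomly initialised stand-in weights (hypothetical, not the trained ones) with the shapes this model's `get_weights()` would return.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in weights with the same shapes as model.get_weights()
# for this Sequential model: [W1 (784, 100), b1 (100,), W2 (100, 10), b2 (10,)]
W1 = rng.normal(scale=0.05, size=(784, 100)).astype(np.float32)
b1 = np.zeros(100, dtype=np.float32)
W2 = rng.normal(scale=0.05, size=(100, 10)).astype(np.float32)
b2 = np.zeros(10, dtype=np.float32)

def forward(images):
    """Flatten -> Dense(100, relu) -> Dense(10, softmax)."""
    x = images.reshape(len(images), -1)          # Flatten: (N, 28, 28) -> (N, 784)
    h = np.maximum(x @ W1 + b1, 0.0)             # Dense + ReLU
    logits = h @ W2 + b2                         # Dense (pre-softmax scores)
    logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)      # softmax: each row sums to 1

# Each Dense layer contributes inputs * units weights plus units biases
param_count = sum(w.size for w in (W1, b1, W2, b2))
print(param_count)  # 79510

probs = forward(rng.random((2, 28, 28), dtype=np.float32))
print(probs.shape)  # (2, 10)
```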


In [6]:
history = model.fit(train_images_scaled, train_labels_cat, epochs=5, batch_size=32, verbose=2)

Epoch 1/5
60000/60000 - 2s - loss: 0.5073 - acc: 0.8228
Epoch 2/5
60000/60000 - 2s - loss: 0.3768 - acc: 0.8628
Epoch 3/5
60000/60000 - 2s - loss: 0.3380 - acc: 0.8770
Epoch 4/5
60000/60000 - 2s - loss: 0.3165 - acc: 0.8837
Epoch 5/5
60000/60000 - 2s - loss: 0.2982 - acc: 0.8899


In [7]:
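# Evaluate on the scaled test images so the preprocessing matches training.
# The output below was recorded with the unscaled `test_images`, which is why
# the loss is far larger than the training loss.
test_loss, test_acc = model.evaluate(test_images_scaled, test_labels_cat)
print('model accuracy is', test_acc)
print('model loss is', test_loss)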

model accuracy is 0.8609
model loss is 58.33825239868164
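`model.evaluate` reports the loss and metric chosen in `model.compile`: categorical cross-entropy and accuracy. A minimal NumPy sketch of their textbook definitions, assuming one-hot labels and softmax probabilities; the `eps` clip is a numerical guard added here and is not necessarily what Keras does internally. Because cross-entropy is unbounded, a handful of confidently wrong predictions can inflate the loss much more than they lower the accuracy.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over the batch of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

def accuracy(y_true, y_pred):
    """Fraction of rows whose highest-probability class matches the label."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))

y_true = np.array([[0, 1, 0], [1, 0, 0]], dtype=np.float32)
y_pred = np.array([[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]], dtype=np.float32)

print(accuracy(y_true, y_pred))  # 0.5
print(categorical_crossentropy(y_true, y_pred))
```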


# Step 3: Convert the Keras Model to a tf.Estimator

In [8]:
# No model_dir is given, so the Estimator writes checkpoints to a temporary folder
tf_classifier = tf.keras.estimator.model_to_estimator(keras_model=model)

W0908 23:06:50.650623  6300 estimator.py:1811] Using temporary folder as model directory: C:\Users\ADITYA~1\AppData\Local\Temp\tmpm47ya3vj


In [9]:
# The feature-dict key must match the Keras model's input name
print(model.input_names)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'flatten_input': train_images_scaled},
    y=train_labels_cat,
    batch_size=32,
    num_epochs=5,
    shuffle=True
)

tf_classifier.train(input_fn=train_input_fn)


['flatten_input']




<tensorflow_estimator.python.estimator.estimator.Estimator at 0x1b3b2bb27b8>
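`numpy_input_fn` wraps in-memory arrays as a function that feeds the Estimator `(features_dict, labels)` batches, where the dictionary key matches the Keras input name reported by `model.input_names` (`'flatten_input'`). The pure-NumPy generator below is a rough sketch of that behaviour, not its actual queue-based implementation: the helper `numpy_batches` is hypothetical, reshuffles once per epoch, and never mixes samples from two epochs in one batch, which may differ from what TensorFlow does.

```python
import numpy as np

def numpy_batches(features, labels, batch_size=32, num_epochs=1, shuffle=True, seed=0):
    """Yield (features_dict, labels) batches, roughly like numpy_input_fn (hypothetical)."""
    rng = np.random.default_rng(seed)
    n = len(labels)
    for _ in range(num_epochs):
        # New random order each epoch when shuffling, otherwise sequential
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]  # last batch may be smaller
            yield {key: value[idx] for key, value in features.items()}, labels[idx]

# Tiny stand-in data shaped like Fashion-MNIST: 100 images of 28x28
x = np.zeros((100, 28, 28), dtype=np.float32)
y = np.eye(10, dtype=np.float32)[np.arange(100) % 10]

batches = list(numpy_batches({'flatten_input': x}, y, batch_size=32, num_epochs=2))
print(len(batches))                           # 8 (4 batches per epoch x 2 epochs)
print(batches[-1][0]['flatten_input'].shape)  # (4, 28, 28)
```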

In [10]:
# One pass over the scaled test set, using the same feature key as training
evaluate_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'flatten_input': test_images_scaled},
    y=test_labels_cat,
    num_epochs=1,
    shuffle=True
)

tf_classifier.evaluate(input_fn=evaluate_input_fn)



{'acc': 0.8695, 'loss': 0.36077964, 'global_step': 9375}
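The `global_step` in the evaluation result counts the training batches the Estimator processed. Assuming one optimizer step per batch, 60,000 training images in batches of 32 give 1,875 steps per epoch, and five epochs give 9,375 steps, matching the value reported above.

```python
import math

num_examples = 60000  # Fashion-MNIST training set size
batch_size = 32       # as passed to numpy_input_fn
num_epochs = 5

# One optimizer step per batch; a partial final batch would still count as a step
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 1875 9375
```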

# Step 4: Export the Model

In [11]:
# Save the trained Keras model (architecture, weights, and optimizer state) to an HDF5 file
model.save('fashion-mnist-keras-model.h5')