In [1]:
import tensorflow as tf
import numpy as np

  from ._conv import register_converters as _register_converters


In [2]:
def model_fn(features, labels, mode, params):
    """Estimator model function for a small convolutional classifier.

    The network is: Conv2D(32, 5x5, relu) -> MaxPooling2D(2x2, stride 2) ->
    Conv2D(64, 5x5, relu) -> MaxPooling2D(2x2) -> Flatten ->
    Dense(1000, relu) -> Dense(num_classes, softmax).

    Args:
        features: Input tensor passed straight into the first convolution.
            NOTE(review): presumably a batch of NHWC images — confirm against
            the input function.
        labels: Class indices. Only read outside PREDICT mode, where they are
            one-hot encoded for the loss and compared with the argmax of the
            network output for the accuracy metric.
        mode: One of the `tf.estimator.ModeKeys` values.
        params: Optional mapping of hyper-parameters. The only key read is
            "num_classes" (defaults to 10 when absent or when `params` is
            falsy).

    Returns:
        A `tf.estimator.EstimatorSpec`:
          * PREDICT: `predictions` is the argmax class index tensor.
          * TRAIN / EVAL: `loss` is the mean categorical cross-entropy and
            `eval_metric_ops` holds an "accuracy" metric. `train_op` is an
            Adam minimisation step in TRAIN mode and None otherwise.
    """
    num_classes = (params or {}).get("num_classes", 10)

    x = tf.keras.layers.Conv2D(32, kernel_size=(5, 5), activation='relu')(features)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = tf.keras.layers.Conv2D(64, (5, 5), activation='relu')(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(1000, activation='relu')(x)
    x_out = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
    prediction_cls = tf.argmax(x_out, -1)

    # Prediction needs neither labels nor a loss, so return early.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=prediction_cls)

    one_hot_labels = tf.one_hot(labels, num_classes)
    cross_entropy = tf.keras.backend.categorical_crossentropy(one_hot_labels, x_out)
    loss = tf.reduce_mean(cross_entropy)
    metrics = {"accuracy": tf.metrics.accuracy(labels, prediction_cls)}

    # Only build the optimizer (and its extra variables) when actually
    # training; evaluation just needs the loss and the metrics.
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, eval_metric_ops=metrics
    )

In [3]:
def _input_fn(mode, features, labels):
    """Build the `tf.data` pipeline consumed by the Estimator.

    Args:
        mode: A `tf.estimator.ModeKeys` value; only TRAIN changes the pipeline.
        features: Feature data sliced along its first dimension.
        labels: Label data sliced in lockstep with `features`.

    Returns:
        A dataset of `(features, labels)` batches of up to 128 elements.
        In TRAIN mode the elements are first shuffled (buffer of 60000) and
        repeated; every mode ends with an auto-tuned prefetch.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    pipeline = tf.data.Dataset.from_tensor_slices((features, labels))
    if is_training:
        shuffle_repeat = tf.contrib.data.shuffle_and_repeat(60000)
        pipeline = pipeline.apply(shuffle_repeat)

    return pipeline.batch(128).prefetch(tf.contrib.data.AUTOTUNE)

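# NOTE: the two wrappers below are disabled leftovers. They pass a `data_dir`
# keyword pointing at TFRecord files, but `_input_fn(mode, features, labels)`
# takes no such argument, so they would fail if simply uncommented.
#def train_input_fn():
    #return _input_fn(tf.estimator.ModeKeys.TRAIN, data_dir = '/home/jenno/Desktop/data/mnist/train.tfrecords')

#def eval_input_fn():
    #return _input_fn(tf.estimator.ModeKeys.EVAL, data_dir = '/home/jenno/Desktop/data/mnist/test.tfrecords')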

In [4]:
# Load the MNIST train/test splits through the Keras dataset helper.
# NOTE(review): the path is an absolute, machine-specific location — adjust
# it when running elsewhere.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='/home/jenno/Desktop/data/mnist/mnist.npz'
)

In [5]:
# Append a trailing axis of length 1 to both image arrays.
# NOTE(review): presumably this is the channel axis the Conv2D layers
# expect — confirm.
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]



In [6]:
# Convert both image arrays to float32 and divide by 255.
# NOTE(review): assumes the raw values are 0-255 intensities, so the result
# lies in [0, 1] — confirm.
x_train = x_train.astype(np.float32) / 255
x_test = x_test.astype(np.float32) / 255


In [7]:
# The Estimator creates and manages its own graph and session for every
# train/evaluate/predict call, so no outer tf.Session or variable
# initializer is needed here.
model = tf.estimator.Estimator(model_fn=model_fn)

# NOTE(review): with the batch size of 128 used by `_input_fn`, 4688 steps is
# presumably ~10 passes over 60000 training examples — confirm.
model.train(
    input_fn=lambda: _input_fn(tf.estimator.ModeKeys.TRAIN, x_train, y_train),
    steps=4688,
)

results = model.evaluate(
    input_fn=lambda: _input_fn(tf.estimator.ModeKeys.EVAL, x_test, y_test),
    steps=80,
)

# `predict` yields results lazily; materialise them so that the printed value
# is the predictions themselves rather than a generator object, and show only
# a short preview instead of the whole list.
q = list(model.predict(input_fn=lambda: _input_fn(tf.estimator.ModeKeys.EVAL, x_test, y_test)))
print(q[:10])

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3fi2y1d7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f4821532e10>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3fi2y1d7/model.ckpt.
INFO:ten

In [12]:
# Drain the prediction generator into a list, one entry per test example.
q = [
    predicted
    for predicted in model.predict(
        input_fn=lambda: _input_fn(tf.estimator.ModeKeys.EVAL, x_test, y_test)
    )
]

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp3fi2y1d7/model.ckpt-4688
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


In [13]:
# Bare expression: the notebook renders the full list of predictions as this
# cell's output.
q

[7,
 2,
 1,
 0,
 4,
 1,
 4,
 9,
 5,
 9,
 0,
 6,
 9,
 0,
 1,
 5,
 9,
 7,
 3,
 4,
 9,
 6,
 6,
 5,
 4,
 0,
 7,
 4,
 0,
 1,
 3,
 1,
 3,
 4,
 7,
 2,
 7,
 1,
 2,
 1,
 1,
 7,
 4,
 2,
 3,
 5,
 1,
 2,
 4,
 4,
 6,
 3,
 5,
 5,
 6,
 0,
 4,
 1,
 9,
 5,
 7,
 8,
 9,
 3,
 7,
 4,
 6,
 4,
 3,
 0,
 7,
 0,
 2,
 9,
 1,
 7,
 3,
 2,
 9,
 7,
 7,
 6,
 2,
 7,
 8,
 4,
 7,
 3,
 6,
 1,
 3,
 6,
 9,
 3,
 1,
 4,
 1,
 7,
 6,
 9,
 6,
 0,
 5,
 4,
 9,
 9,
 2,
 1,
 9,
 4,
 8,
 7,
 3,
 9,
 7,
 4,
 4,
 4,
 9,
 2,
 5,
 4,
 7,
 6,
 7,
 9,
 0,
 5,
 8,
 5,
 6,
 6,
 5,
 7,
 8,
 1,
 0,
 1,
 6,
 4,
 6,
 7,
 3,
 1,
 7,
 1,
 8,
 2,
 0,
 2,
 9,
 9,
 5,
 5,
 1,
 5,
 6,
 0,
 3,
 4,
 4,
 6,
 5,
 4,
 6,
 5,
 4,
 5,
 1,
 4,
 4,
 7,
 2,
 3,
 2,
 7,
 1,
 8,
 1,
 8,
 1,
 8,
 5,
 0,
 8,
 9,
 2,
 5,
 0,
 1,
 1,
 1,
 0,
 9,
 0,
 3,
 1,
 6,
 4,
 2,
 3,
 6,
 1,
 1,
 1,
 3,
 9,
 5,
 2,
 9,
 4,
 5,
 9,
 3,
 9,
 0,
 3,
 6,
 5,
 5,
 7,
 2,
 2,
 7,
 1,
 2,
 8,
 4,
 1,
 7,
 3,
 3,
 8,
 8,
 7,
 9,
 2,
 2,
 4,
 1,
 5,
 9,
 8,
 7,
 2,
 3,
 0,
 2,
 4,
 2,
