In [1]:
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [2]:
def model_fn(features, labels, mode, params):
    x = tf.keras.layers.Conv2D(32, kernel_size=(5, 5), activation='relu')(features['image'])
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = tf.keras.layers.Conv2D(64, (5, 5), activation='relu')(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(1000, activation='relu')(x)
    x_out = tf.keras.layers.Dense(10, activation= 'softmax')(x)
    prediction_cls = tf.argmax(x_out, -1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        spec = tf.estimator.EstimatorSpec(mode=mode, predictions= prediction_cls)
    else:
        one_hot_labels = tf.one_hot(labels, 10)
        cross_entropy = tf.keras.backend.categorical_crossentropy(one_hot_labels, x_out)
        loss = tf.reduce_mean(cross_entropy)
        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        metrics = {"accuracy": tf.metrics.accuracy(labels, prediction_cls)}
        spec = tf.estimator.EstimatorSpec(mode= mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
    return spec

In [3]:
def read_and_decode(record):
    keys_to_features = {
        "image_raw": tf.FixedLenFeature([], tf.string),
        "label":     tf.FixedLenFeature([], tf.int64)
    }
    parsed = tf.parse_single_example(record, keys_to_features)
    image = tf.decode_raw(parsed['image_raw'], tf.uint8)
    label = parsed['label']
    image = tf.cast(image, tf.float32)
    image = image / 255
    image = tf.reshape(image, [28, 28, 1])
    return {"image": image}, label

def _input_fn(mode, data_dir):
    dataset = tf.data.TFRecordDataset(data_dir, num_parallel_reads= 4)
    
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(60000))
    dataset = dataset.apply(tf.contrib.data.map_and_batch(read_and_decode, 128))
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
    return dataset

def train_input_fn():
    return _input_fn(tf.estimator.ModeKeys.TRAIN, data_dir = '/home/jenno/Desktop/data/mnist/train.tfrecords')

def eval_input_fn():
    return _input_fn(tf.estimator.ModeKeys.EVAL, data_dir = '/home/jenno/Desktop/data/mnist/test.tfrecords')
    
    

In [5]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    model = tf.estimator.Estimator(model_fn= model_fn)
    model.train(input_fn = train_input_fn, steps = 4688)
    results = model.evaluate(input_fn=eval_input_fn, steps = 100)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpnzwe1spy', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f7e243b82b0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpnzwe1spy/model.ckpt.
INFO:ten