# In [None]:
import os
import tensorflow as tf
from tensorflow.python.lib.io import file_io



# --- Data locations ----------------------------------------------------------
# TFRecord file names; the input functions prefix them with BUCKET.
TFR_TRAIN, TFR_VALID, TFR_TEST = 'train.tfrecord', 'valid.tfrecord', 'test.tfrecord'
BUCKET = 'gs://robolab/'

# --- Image geometry and label space ------------------------------------------
# Images are decoded as IMG_HEIGHT x IMG_WIDTH single-channel float32 arrays.
NUM_CLASSES = 2
IMG_HEIGHT, IMG_WIDTH = 80, 71

# --- Model directory (checkpoints / summaries) -------------------------------
OUTDIR = BUCKET + 'output_TPU'

# --- Hyperparameters -----------------------------------------------------------
# BATCH_SIZE is the global batch handed to TPUEstimator; it splits it across
# the TPU shards configured below.
BATCH_SIZE = 800
TRAIN_STEPS = 10000   # global step at which training stops (max_steps)
EVAL_STEPS = 10       # number of batches consumed by one evaluation
LR = 1e-4             # Adam learning rate



# --- TPU runtime configuration -----------------------------------------------
# tpu / zone / project are all None, so the resolver must discover the TPU from
# the environment. NOTE(review): confirm this matches the runtime being used.
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    None, zone=None, project=None)

# Eight shards (cores); each host feeds its cores with the PER_HOST_V2 input
# pipeline. iterations_per_loop=2 returns to the host every 2 steps.
# NOTE(review): that is unusually small for TPU throughput; confirm it's intended.
tpu_config = tf.contrib.tpu.TPUConfig(
    num_shards=8,
    iterations_per_loop=2,
    per_host_input_for_training=(
        tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

run_config = tf.contrib.tpu.RunConfig(
    tpu_config=tpu_config,
    model_dir=OUTDIR,
    cluster=tpu_cluster_resolver)



def parser(serialized_example):
    """Decode one labelled TFRecord example into an ``(image, label)`` pair.

    The record must carry two byte-string features:
      * ``'image_raw'`` -- raw float32 pixels, IMG_HEIGHT * IMG_WIDTH of them;
      * ``'label'``     -- raw bytes of exactly one int32 value.

    Returns:
      image: float32 tensor of shape [IMG_HEIGHT, IMG_WIDTH, 1].
      label: int32 tensor of shape [1].
    """
    schema = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=schema)

    # Pin the static length so the reshape (and TPU compilation) sees a fully
    # defined shape.
    pixels = tf.decode_raw(parsed['image_raw'], tf.float32)
    pixels.set_shape([IMG_HEIGHT * IMG_WIDTH])
    # Append a single channel axis: [H, W] -> [H, W, 1].
    image = tf.expand_dims(tf.reshape(pixels, [IMG_HEIGHT, IMG_WIDTH]), axis=2)

    label = tf.decode_raw(parsed['label'], tf.int32)
    label.set_shape([1])

    return image, label

def test_parser(serialized_example):
    """Decode one unlabelled (test) TFRecord example into an image tensor.

    ``'image_id'`` is part of the parse schema, so it must be present in every
    record, but it is not returned here; ``get_image_id`` extracts it.

    Returns:
      float32 tensor of shape [IMG_HEIGHT, IMG_WIDTH, 1].
    """
    schema = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'image_id': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=schema)

    flat = tf.decode_raw(parsed['image_raw'], tf.float32)
    flat.set_shape([IMG_HEIGHT * IMG_WIDTH])
    # [H*W] -> [H, W] -> [H, W, 1]
    return tf.expand_dims(tf.reshape(flat, [IMG_HEIGHT, IMG_WIDTH]), axis=2)


def train_input_fn(params):
    """Build the endlessly repeating, shuffled training pipeline.

    Args:
      params: dict supplied by TPUEstimator; ``params['batch_size']`` is the
        batch size it derives from ``train_batch_size``.

    Returns:
      tf.data.Dataset of ``(image, label)`` batches (see ``parser``).
    """
    per_call_batch = params['batch_size']

    records = tf.data.TFRecordDataset(BUCKET + TFR_TRAIN)
    decode_and_batch = tf.contrib.data.map_and_batch(
        parser,
        batch_size=per_call_batch,
        num_parallel_batches=8,
        # TPUs need a static batch dimension, so a ragged final batch is dropped.
        drop_remainder=True)

    return (records
            .repeat()
            .shuffle(1024)
            .apply(decode_and_batch)
            .prefetch(tf.contrib.data.AUTOTUNE))

def eval_input_fn(params):
    """Build the evaluation pipeline over the validation split.

    Previously this read ``TFR_TRAIN``, so ``train_and_evaluate`` reported
    accuracy on the training data. It now reads the held-out ``TFR_VALID`` file.

    The dataset repeats forever, so the caller must bound evaluation with
    ``steps``. ``drop_remainder=True`` keeps the batch dimension static, as
    TPU evaluation requires.

    Args:
      params: dict supplied by TPUEstimator; ``params['batch_size']`` is the
        batch size it derives from ``eval_batch_size``.

    Returns:
      tf.data.Dataset of ``(image, label)`` batches (see ``parser``).
    """
    batch_size = params['batch_size']

    dataset = tf.data.TFRecordDataset(BUCKET + TFR_VALID)

    # Repeat so a fixed number of eval steps never runs out of data.
    dataset = dataset.repeat()
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            parser,
            batch_size=batch_size,
            num_parallel_batches=8,
            drop_remainder=True))

    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

    return dataset

def valid_input_fn(params):
    """Build a single pass over the validation split, batched.

    Args:
      params: dict with ``'batch_size'`` (supplied by TPUEstimator).

    Returns:
      tf.data.Dataset of ``(image, label)`` batches (see ``parser``). The last
      batch may be smaller than ``batch_size``.
      NOTE(review): with ``eval_on_tpu=True`` a partial, dynamically shaped
      batch may be rejected; confirm before using this on TPU.
    """
    return (tf.data.TFRecordDataset(BUCKET + TFR_VALID)
            .map(parser)
            .batch(params['batch_size'])
            .repeat(1))

def predict_input_fn(params=None):
    """Build a single pass over the unlabelled test split, batched.

    The original referenced an undefined ``params`` and raised NameError on
    every call. It also returned iterator output, but TPUEstimator.predict
    requires the input_fn to return a ``tf.data.Dataset``.

    Args:
      params: optional dict. When TPUEstimator supplies it,
        ``params['batch_size']`` (derived from ``predict_batch_size``) is used.
        When called without arguments, ``BATCH_SIZE`` is used.

    Returns:
      tf.data.Dataset of float32 image batches (see ``test_parser``). The last
      batch may be partial.
    """
    batch_size = BATCH_SIZE if params is None else params['batch_size']

    dataset = tf.data.TFRecordDataset(BUCKET + TFR_TEST)
    dataset = dataset.map(test_parser)
    dataset = dataset.batch(batch_size)

    return dataset

def get_image_id(serialized_example):
    """Return the ``'image_id'`` string feature of one serialized test record.

    ``'image_raw'`` is also in the schema, so records lacking it fail to parse,
    just as in ``test_parser``.
    """
    schema = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'image_id': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=schema)
    return parsed['image_id']



def metrics_fn(classes, labels):
    """Metric function for ``TPUEstimatorSpec.eval_metrics``.

    TPUEstimator calls it positionally with the tensors listed next to it in
    ``eval_metrics``, so callers must list predicted classes first, then labels.

    Args:
      classes: predicted class ids.
      labels: ground-truth class ids.

    Returns:
      dict mapping ``'accuracy'`` to the ``(value, update_op)`` pair from
      ``tf.metrics.accuracy``.
    """
    acc = tf.metrics.accuracy(labels=labels, predictions=classes)
    metrics = {'accuracy': acc}
    return metrics


def cnn_model_fn(features, labels, mode, params):
    """TPUEstimator model_fn for a small CNN classifier.

    Architecture: conv(8, 2x2, relu) -> maxpool(2) -> conv(32, 2x2, relu)
    -> maxpool(2) -> flatten -> dense(256, relu) -> dropout(0.2, TRAIN only)
    -> dense(NUM_CLASSES) logits.

    Args:
      features: image batch from the input_fns. Presumably
        [batch, IMG_HEIGHT, IMG_WIDTH, 1] float32 per ``parser``; confirm.
      labels: integer class ids (None in PREDICT mode).
      mode: a ``tf.estimator.ModeKeys`` value.
      params: TPUEstimator params dict (not used here).

    Returns:
      ``tf.contrib.tpu.TPUEstimatorSpec`` for the requested mode.
    """
    # Layer creation order is kept as-is so variable names (and therefore
    # existing checkpoints) stay compatible.
    conv_layer_1 = tf.layers.conv2d(
        inputs=features,
        filters=8,
        kernel_size=[2, 2],
        padding='same',
        activation=tf.nn.relu)

    pool_layer_1 = tf.layers.max_pooling2d(
        inputs=conv_layer_1,
        pool_size=[2, 2],
        strides=2,
        padding='same')

    conv_layer_2 = tf.layers.conv2d(
        inputs=pool_layer_1,
        filters=32,
        kernel_size=[2, 2],
        padding='same',
        activation=tf.nn.relu)

    pool_layer_2 = tf.layers.max_pooling2d(
        inputs=conv_layer_2,
        pool_size=[2, 2],
        strides=2,
        padding='same')

    reshape_layer = tf.layers.flatten(pool_layer_2)

    dense_layer = tf.layers.dense(
        inputs=reshape_layer,
        units=256,
        activation=tf.nn.relu)

    # Dropout is active only while training.
    is_train = mode == tf.estimator.ModeKeys.TRAIN

    dropout_layer = tf.layers.dropout(
        inputs=dense_layer,
        rate=0.2,
        training=is_train)

    logits = tf.layers.dense(
        inputs=dropout_layer,
        units=NUM_CLASSES)

    classes = tf.argmax(logits, axis=1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # The original used an undefined name `estimator` here (NameError).
        # TPUEstimator expects a TPUEstimatorSpec from its model_fn.
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={'classes': classes,
                         'probabilities': tf.nn.softmax(logits, axis=1)})

    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels,
        logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # CrossShardOptimizer aggregates gradients across the TPU shards.
        optimizer = tf.contrib.tpu.CrossShardOptimizer(
            tf.train.AdamOptimizer(learning_rate=LR))
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step())

        return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                               loss=loss,
                                               train_op=train_op)

    # EVAL mode. eval_metrics tensors are passed positionally to
    # metrics_fn(classes, labels), so predictions come first.
    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=(metrics_fn, [classes, labels]))



def train_and_evaluate(estimator):
    """Train ``estimator`` up to TRAIN_STEPS global steps, then evaluate it.

    Evaluation consumes EVAL_STEPS batches from ``eval_input_fn``.

    Args:
      estimator: an Estimator/TPUEstimator built around ``cnn_model_fn``.

    Returns:
      The metrics dict returned by ``estimator.evaluate``. Previously it was
      discarded and None was returned; existing callers that ignore the result
      are unaffected.
    """
    estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
    return estimator.evaluate(input_fn=eval_input_fn, steps=EVAL_STEPS)



# TPUEstimator around cnn_model_fn. predict_batch_size is set so that
# cnn_classifier.predict(predict_input_fn) works; without it TPUEstimator
# refuses to run prediction.
cnn_classifier = tf.contrib.tpu.TPUEstimator(
    model_fn=cnn_model_fn,
    config=run_config,
    use_tpu=True,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=BATCH_SIZE,
    predict_batch_size=BATCH_SIZE,
    eval_on_tpu=True)

# Only kick off training when run as a script or notebook (__name__ is
# '__main__' in both), not when this module is imported.
if __name__ == '__main__':
    train_and_evaluate(cnn_classifier)