In [None]:
import tensorflow as tf
import numpy as np
import sys
import os
import matplotlib.pyplot as plt
from scipy import ndimage
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import vgg
import dataset_utils

%matplotlib inline
tf.logging.set_verbosity(tf.logging.INFO)

dataset_root = './dataset/'
pre_trained_model_path = './pre_trained_models/vgg_16.ckpt'
trained_model_path = './trained_model'

image_size = 224
num_channels = 3
num_classes = 3


In [None]:
# Step 1
# Convert the image dataset found under dataset_root to TFRecord format.
# NOTE(review): later cells read '/workspace/dataset/{train,validation,test}-0.tfrecords';
# presumably process_directory writes those files under dataset_root — confirm.
dataset_utils.process_directory(dataset_root)

In [None]:
# Step 2
# Create the model function: VGG-16 with pre-trained weights restored from the checkpoint and a re-trained last layer
def model_fn(images, labels, num_classes, mode):
    """Estimator model function: fine-tune the last layer of a pre-trained VGG-16.

    Every VGG-16 variable except ``vgg_16/fc8`` is warm-started from
    ``pre_trained_model_path``. ``vgg_16/fc8`` keeps its fresh initializer and is
    the only set of variables updated by the optimizer.

    Args:
        images: batch of preprocessed images fed to ``vgg.vgg_16``
            (presumably float32 of shape (batch, image_size, image_size, 3) — TODO confirm).
        labels: sparse integer class ids for the batch, as required by
            ``tf.losses.sparse_softmax_cross_entropy``; not used in PREDICT mode.
        num_classes: number of outputs of the new fc8 layer.
        mode: one of ``tf.estimator.ModeKeys``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for PREDICT, EVAL or TRAIN mode.
    """
    # Training-only behaviour inside vgg_16 (e.g. dropout) must be switched off
    # for evaluation and prediction.
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    with tf.contrib.slim.arg_scope(vgg.vgg_arg_scope()):
        logits, _ = vgg.vgg_16(images, num_classes, is_training=is_training)

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions['classes'])
    tf.summary.scalar('accuracy', accuracy[1])

    # Warm-start every variable except the new classification layer (fc8).
    # Assumes the checkpoint keys equal the graph variable names without the ':0' suffix.
    variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
    tf.train.init_from_checkpoint(
        pre_trained_model_path,
        {v.name.split(':')[0]: v for v in variables_to_restore})

    # fc8 is not in the assignment map above, so it keeps its default initializer;
    # it is also the only variable list handed to the optimizer below.
    fc8_variables = tf.contrib.framework.get_variables('vgg_16/fc8')

    loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
    # Regularization losses collected in the graph (presumably weight decay from vgg_arg_scope).
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_loss = tf.add_n([loss] + reg_losses, name='total_loss')

    if mode == tf.estimator.ModeKeys.EVAL:
        # The reported eval loss is the data loss only (no regularization terms).
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops={'eval_accuracy': accuracy})

    # TRAIN: re-train only the last layer of the model.
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(learning_rate=0.00016)
    train_op = optimizer.minimize(total_loss, global_step, var_list=fc8_variables)

    logging_hook = tf.train.LoggingTensorHook(
        {'loss': loss, 'total_loss': total_loss}, every_n_iter=10)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=total_loss,
        train_op=train_op,
        training_hooks=[logging_hook])



In [None]:
# Step 3 
# Create data loading pipeline based on TFRecords
def preprocess_image_record(record, size, num_channels, is_training=False):
    """Turn one serialized TFRecord example into an ``(image, label)`` pair.

    Args:
        record: one serialized example, parsed by
            ``dataset_utils.deserialize_image_record``.
        size: target height and width (pixels) of the output image.
        num_channels: number of colour channels to decode the JPEG into.
        is_training: when True, apply a random horizontal flip for augmentation.

    Returns:
        ``(image, label)``: the image is resized to ``size x size`` and then
        standardized with ``tf.image.per_image_standardization``. The label is
        returned as produced by ``deserialize_image_record``.
    """
    imgdata, label, _ = dataset_utils.deserialize_image_record(record)

    # Decode the JPEG bytes.
    image = tf.image.decode_jpeg(imgdata, channels=num_channels,
                                 fancy_upscaling=False,
                                 dct_method='INTEGER_FAST')

    # Resize every image to the requested size. The `size` argument is honoured
    # here; it was previously ignored in favour of the global image_size.
    image = tf.image.resize_images(image, [size, size],
                                   method=tf.image.ResizeMethod.BILINEAR,
                                   align_corners=False)

    # Augment the data for training only.
    if is_training:
        image = tf.image.random_flip_left_right(image)

    # NOTE(review): the pre-trained VGG-16 weights were presumably trained with
    # per-channel mean subtraction rather than per-image standardization — confirm
    # that this normalization matches the checkpoint's expectations.
    image = tf.image.per_image_standardization(image)
    return image, label


def load_tfrecord_dataset(filenames, batch_size, is_training, shuffle_buffer_size=1000):
    """Build a batched ``tf.data`` pipeline of ``(image, label)`` pairs from TFRecords.

    Args:
        filenames: list of TFRecord file paths.
        batch_size: number of examples per batch.
        is_training: when True, the data is shuffled, augmented and repeated
            indefinitely. When False, the dataset is a single pass in file order,
            so evaluation and ``Estimator.predict`` terminate at end of input and
            predictions line up with the records.
        shuffle_buffer_size: buffer size used for shuffling the training data.

    Returns:
        A ``tf.data.Dataset`` of batched ``(image, label)`` pairs.
    """
    # Load the dataset from TFRecord files.
    dataset = tf.data.TFRecordDataset(filenames)

    # Shuffle only the training data; eval/test keep their original order.
    if is_training:
        dataset = dataset.shuffle(shuffle_buffer_size)

    # Decode, resize, augment (training only) and normalize each record.
    dataset = dataset.map(
        map_func=lambda record: preprocess_image_record(
            record, image_size, num_channels, is_training))

    # Repeat indefinitely only for training. An endless eval/test dataset would make
    # Estimator.predict (and evaluate with steps=None) never finish.
    if is_training:
        dataset = dataset.repeat()

    # Create batches of data.
    return dataset.batch(batch_size)

In [None]:
# Step 4
# Create the classifier and train it, evaluating on the validation set after every epoch
classifier = tf.estimator.Estimator(
    model_fn=lambda features, labels, mode: model_fn(features, labels, num_classes, mode),
    model_dir=trained_model_path,
    config=tf.estimator.RunConfig(
        save_summary_steps=100,
        save_checkpoints_steps=500
    )
)

batch_size = 8
train_examples = 589  # presumably the number of records in the training split — TODO confirm
# max_steps must be an integer step count; round up so the last partial batch
# of each epoch is counted (train_examples / batch_size would be a float).
steps_per_epoch = (train_examples + batch_size - 1) // batch_size
num_epochs = 10

# NOTE(review): absolute paths; presumably these are the files Step 1 writes under
# dataset_root when the notebook runs from /workspace — confirm.
# Input function for training
train_filenames = ['/workspace/dataset/train-0.tfrecords']
train_input_fn = lambda: load_tfrecord_dataset(train_filenames, batch_size, is_training=True)

# Input function for validation
eval_filenames = ['/workspace/dataset/validation-0.tfrecords']
eval_input_fn = lambda: load_tfrecord_dataset(eval_filenames, batch_size, is_training=False)

# The evaluation spec is the same for every epoch, so build it once.
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

# Train for num_epochs, raising max_steps by one epoch each time, and evaluate after each.
for epoch in range(num_epochs):
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=steps_per_epoch * (epoch + 1))
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)

In [None]:
# Step 5
# Test the model on the test set.
# How can we modify model function to provide accuracy metric for predictions?
# -> model_fn already returns eval_metric_ops={'eval_accuracy': ...} in EVAL mode,
#    so classifier.evaluate(input_fn=test_input_fn) reports test accuracy directly.
# NOTE: predict() (and evaluate() with steps=None) only finishes if the dataset
#       built for is_training=False is finite, i.e. not repeat()-ed.
test_filenames = ['/workspace/dataset/test-0.tfrecords']


def test_input_fn():
    """Single-example batches over the test TFRecords."""
    return load_tfrecord_dataset(test_filenames, batch_size=1, is_training=False)


for test_prediction in classifier.predict(input_fn=test_input_fn):
    print(test_prediction)

In [None]:
# Step 6
# Test the model on new data.
# TODO: set predict_filenames to the TFRecord file(s) holding the new images,
#       or set predict_input_fn to a custom input function.
# NOTE(review): load_tfrecord_dataset assumes the same record schema as the
#       train/validation/test files (including a label) — adapt if new data differs.
predict_filenames = None
predict_input_fn = None

if predict_input_fn is None and predict_filenames:
    predict_input_fn = lambda: load_tfrecord_dataset(predict_filenames, batch_size=1, is_training=False)

# Estimator.predict(input_fn=None) raises, so skip the cell until an input is configured.
if predict_input_fn is None:
    print('Step 6 skipped: set predict_filenames (or predict_input_fn) to run predictions on new data.')
else:
    for prediction in classifier.predict(input_fn=predict_input_fn):
        print(prediction)