In [None]:
from __future__ import division

import os
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import tensorflow.contrib.slim as slim

import flowers_dataset
import inception_preprocessing
import inception_v1
import model_deploy

In [None]:
# --- Deployment ---------------------------------------------------------
# Id of the replica running this training job, followed by the total
# number of worker replicas.
task_id, worker_replicas = 0, 1

# --- Dataset ------------------------------------------------------------
# All paths hang off <home>/data/flowers: `dataset_dir` is handed to the
# dataset loader and `train_dir` is used as the training log directory.
dataset_split_name = "train"
home_dir = os.path.expanduser('~')
base_data_dir = os.path.join(home_dir, "data/flowers")
dataset_dir = os.path.join(base_data_dir, "tfrecords")
train_dir = os.path.join(base_data_dir, "train_logs/v003")

# --- Training -----------------------------------------------------------
batch_size = 32
num_preprocessing_threads = 4
# Summaries and checkpoints are written on the same interval (seconds).
save_summaries_secs = save_interval_secs = 600
log_every_n_steps = 10

# --- Warm start from a pre-trained checkpoint ---------------------------
# Leaving `checkpoint_path` as None means no init function is built.
checkpoint_path = checkpoint_exclude_scopes = None
ignore_missing_vars = False

In [None]:
def configure_learning_rate(num_samples_per_epoch, global_step,
                            initial_learning_rate=0.01,
                            num_epochs_per_decay=2,
                            learning_rate_decay_factor=0.94):
    """Build a staircase exponential-decay learning-rate schedule.

    The schedule hyper-parameters used to be hard-coded locals; they are now
    keyword arguments whose defaults reproduce the previous values, so
    existing two-argument calls behave exactly as before.

    Args:
        num_samples_per_epoch: Number of training samples in one epoch.
        global_step: The global step passed through to
            `tf.train.exponential_decay`.
        initial_learning_rate: Learning rate before any decay is applied.
        num_epochs_per_decay: Number of epochs between two decay steps.
        learning_rate_decay_factor: Multiplicative decay factor.

    Returns:
        The result of `tf.train.exponential_decay` for these settings.
    """
    # `batch_size` is the module-level training flag. The product is
    # truncated to a whole number of steps.
    steps_per_epoch = num_samples_per_epoch / batch_size
    decay_steps = int(steps_per_epoch * num_epochs_per_decay)
    return tf.train.exponential_decay(initial_learning_rate,
                                      global_step,
                                      decay_steps,
                                      learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')


def configure_optimizer(learning_rate, optimizer_type='rmsprop',
                        adam_beta1=0.9, adam_beta2=0.999, opt_epsilon=1.0,
                        rmsprop_momentum=0.9):
    """Create the optimizer used for training.

    With only `learning_rate` supplied this returns the same RMSProp
    optimizer as before (momentum 0.9, epsilon 1.0).

    Args:
        learning_rate: Learning rate (scalar or tensor) for the optimizer.
        optimizer_type: Either 'adam' or 'rmsprop'.
        adam_beta1: `beta1` for the Adam optimizer.
        adam_beta2: `beta2` for the Adam optimizer.
        opt_epsilon: Epsilon term passed to the selected optimizer.
        rmsprop_momentum: Momentum for the RMSProp optimizer.

    Returns:
        A `tf.train.AdamOptimizer` or `tf.train.RMSPropOptimizer`.

    Raises:
        ValueError: If `optimizer_type` is not a recognized name.
    """
    if optimizer_type == 'adam':
        # These values used to be read from an undefined `FLAGS` object, so
        # this branch could not run; they are explicit arguments now.
        return tf.train.AdamOptimizer(
                learning_rate,
                beta1=adam_beta1,
                beta2=adam_beta2,
                epsilon=opt_epsilon)
    if optimizer_type == 'rmsprop':
        return tf.train.RMSPropOptimizer(
                learning_rate,
                momentum=rmsprop_momentum,
                epsilon=opt_epsilon)
    # Format the message eagerly: exceptions do not interpolate extra
    # positional arguments the way logging calls do.
    raise ValueError('Optimizer [%s] was not recognized' % optimizer_type)


def get_init_fn(checkpoint_path, checkpoint_exclude_scopes, ignore_missing_vars):
    """Build the function that warm-starts the model from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file, a directory containing
            checkpoints, or None to skip warm-starting.
        checkpoint_exclude_scopes: Comma-separated name prefixes of model
            variables that must NOT be restored, or None/empty to restore
            every model variable. Blank entries are ignored.
        ignore_missing_vars: Forwarded to `slim.assign_from_checkpoint_fn`.

    Returns:
        None when `checkpoint_path` is None, otherwise the result of
        `slim.assign_from_checkpoint_fn` for the selected variables.

    Raises:
        ValueError: If `checkpoint_path` is a directory in which
            `tf.train.latest_checkpoint` finds no checkpoint.
    """
    if checkpoint_path is None:
        return None

    # Blank entries (e.g. from a trailing comma) are dropped: an empty prefix
    # matches every variable name and would exclude the whole model.
    exclusions = ()
    if checkpoint_exclude_scopes:
        stripped = (scope.strip() for scope in checkpoint_exclude_scopes.split(','))
        exclusions = tuple(scope for scope in stripped if scope)

    # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)]

    if tf.gfile.IsDirectory(checkpoint_path):
        checkpoint_dir = checkpoint_path
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint_path is None:
            # Fail here with a clear message instead of passing None on.
            raise ValueError(
                'No checkpoint found in directory %s' % checkpoint_dir)

    tf.logging.info('Fine-tuning from %s' % checkpoint_path)

    return slim.assign_from_checkpoint_fn(
            checkpoint_path,
            variables_to_restore,
            ignore_missing_vars=ignore_missing_vars)

In [None]:
# Emit INFO-level TensorFlow log messages.
tf.logging.set_verbosity(tf.logging.INFO)
# Everything below, including the call that runs training, happens with a
# freshly created tf.Graph installed as the default graph.
with tf.Graph().as_default():
    # Configure model_deploy: a single clone, not forced onto the CPU, with
    # this task as one of `worker_replicas` replicas and no parameter-server
    # tasks.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=1,
        clone_on_cpu=False,
        replica_id=task_id,
        num_replicas=worker_replicas,
        num_ps_tasks=0)

    # Create global_step on the device model_deploy designates for variables.
    with tf.device(deploy_config.variables_device()):
        global_step = slim.create_global_step()

    # Select the dataset split to train on.
    dataset = flowers_dataset.get_dataset(dataset_split_name, dataset_dir)

    # Build the input pipeline on the device model_deploy designates for
    # inputs: provider -> preprocessing -> batching -> prefetch queue.
    with tf.device(deploy_config.inputs_device()):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=4,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size)
        [image, label] = provider.get(['image', 'label'])

        # Preprocess the images in training mode, using the network's default
        # image size for both size arguments.
        train_image_size = inception_v1.inception_v1.default_image_size
        image = inception_preprocessing.preprocess_image(image, train_image_size, train_image_size, is_training=True)

        # Create training batches, one-hot encode the labels and put the
        # result behind a prefetch queue sized at two batches per clone.
        images, labels = tf.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocessing_threads,
            capacity=5 * batch_size)
        labels = slim.one_hot_encoding(labels, dataset.num_classes)
        batch_queue = slim.prefetch_queue.prefetch_queue([images, labels], capacity=2 * deploy_config.num_clones)

    # Define the model
    def clone_fn(batch_queue):
        """Build one model clone fed from `batch_queue`.

        Dequeues an (images, labels) batch, builds inception_v1 in training
        mode under its arg scope (weight_decay=0.00004) and registers a
        softmax cross-entropy loss on the logits through slim.losses.

        Args:
            batch_queue: Queue whose dequeue() yields (images, labels).

        Returns:
            The `end_points` returned by `inception_v1.inception_v1`.
        """
        images, labels = batch_queue.dequeue()
        
        with slim.arg_scope(inception_v1.inception_v1_arg_scope(weight_decay=0.00004)):
            logits, end_points = inception_v1.inception_v1(
                images, num_classes=dataset.num_classes, is_training=True)

        # The return value is not used here; total_loss is obtained later
        # from model_deploy.optimize_clones.
        slim.losses.softmax_cross_entropy(logits, labels, label_smoothing=0.0, weight=1.0)
        return end_points

    # Gather the summaries that exist BEFORE the clones are created; the
    # first clone's own summaries are merged into this set further down.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points: an activation histogram and the fraction
    # of zeros for every end point of the first clone.
    end_points = clones[0].outputs
    for end_point in end_points:
        x = end_points[end_point]
        summaries.add(tf.histogram_summary('activations/' + end_point, x))
        summaries.add(tf.scalar_summary('sparsity/' + end_point,
                                                                        tf.nn.zero_fraction(x)))

    # Add summaries for the losses collected under the first clone's scope.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
        summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))

    # Add a histogram summary for every model variable.
    for variable in slim.get_model_variables():
        summaries.add(tf.histogram_summary(variable.op.name, variable))

    # Configure the optimization procedure (learning-rate schedule and
    # optimizer) on the device model_deploy designates for the optimizer.
    with tf.device(deploy_config.optimizer_device()):
        learning_rate = configure_learning_rate(dataset.num_samples, global_step)
        optimizer = configure_optimizer(learning_rate)
        summaries.add(tf.scalar_summary('learning_rate', learning_rate, name='learning_rate'))

    # Variables to train: every trainable variable in the graph.
    variables_to_train = tf.trainable_variables()

    # optimize_clones returns the total loss together with the gradients that
    # are applied to `variables_to_train` below.
    total_loss, clones_gradients = model_deploy.optimize_clones(clones, optimizer, var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.scalar_summary('total_loss', total_loss, name='total_loss'))

    # Create gradient updates; applying them also receives global_step.
    grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step)
    update_ops.append(grad_updates)

    # Group the gradient update with the first clone's update ops and make
    # total_loss depend on that group: train_tensor yields the loss value and
    # carries the updates as control dependencies.
    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.merge_summary(list(summaries), name='summary_op')

    # Kicks off the training. Task 0 is the chief; get_init_fn returns None
    # when checkpoint_path is None, so init_fn is then passed as None.
    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master="",
        is_chief=(task_id == 0),
        init_fn=get_init_fn(checkpoint_path, checkpoint_exclude_scopes, ignore_missing_vars),
        summary_op=summary_op,
        log_every_n_steps=log_every_n_steps,
        save_summaries_secs=save_summaries_secs,
        save_interval_secs=save_interval_secs)