# Functions used while training

1. loss
2. optimizer and training
3. evaluation

In [1]:
import tensorflow as tf

## 1. loss function

Compute loss from logits and labels

In [2]:
def losses(logits, labels):
    '''Mean sparse softmax cross-entropy between logits and integer labels.

    Args:
        logits: logits tensor, float, [batch_size, n_classes]
        labels: label tensor, tf.int32, [batch_size]

    Returns:
        Scalar float tensor: the per-example cross-entropy averaged over
        every example in the batch.
    '''
    with tf.variable_scope('loss'):
        # One cross-entropy value per example; labels are class indices,
        # not one-hot vectors.
        per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=labels,
            name='xentropy_per_example',
        )
        # A tf.summary.scalar for this value can be registered here if
        # TensorBoard logging is wanted.
        return tf.reduce_mean(per_example_xent, name='loss')

## 2. optimizer and training

Build the training op. The op returned by this function is what must be passed to a `sess.run()` call to make the model train.

In [3]:
def trainning(loss, learning_rate):
    '''Build the op that applies one Adam update step to minimize `loss`.

    Args:
        loss: loss tensor, from losses()
        learning_rate: learning rate handed to tf.train.AdamOptimizer

    Returns:
        train_op: the op to pass to sess.run() to perform one training step;
        each run also increments the step counter created here.
    '''
    with tf.name_scope('optimizer'):
        adam = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # Non-trainable counter; minimize() increments it on every update.
        step_counter = tf.Variable(0, name='global_step', trainable=False)
        return adam.minimize(loss, global_step=step_counter)

## 3. evaluation

Evaluate the quality of the logits at predicting the label.

In [4]:
def evaluation(logits, labels):
    '''Compute the top-1 classification accuracy of a batch.

    Args:
        logits: Logits tensor, float - [batch_size, NUM_CLASSES].
        labels: Labels tensor, int32 - [batch_size], with values in the
        range [0, NUM_CLASSES).
    Returns:
        A scalar float32 tensor with the fraction of examples (out of
        batch_size) whose label is the top-1 prediction, i.e. a value
        between 0 and 1 rather than a count.
    '''
    with tf.variable_scope('accuracy'):
        # One boolean per example: True where the label is the top-1 class.
        is_correct = tf.nn.in_top_k(logits, labels, 1)
        # Average in float32 rather than float16: half precision carries only
        # an 11-bit significand, which makes the reported accuracy needlessly
        # coarse. Every float16 value is exactly representable in float32.
        accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
        # A tf.summary.scalar for this value can be registered here if
        # TensorBoard logging is wanted.
    return accuracy