# TensorFlow Ear Recognition

Evaluation of a ResNet-50 ear classifier on the WPUTEDB test set.

In [1]:
import tensorflow as tf
import numpy as np
import menpo.io as mio

from pathlib import Path
from scipy.io import loadmat

In [2]:
# Examples per batch; reused by the `batch_size` flag, the test-set provider
# and the evaluation cells below.
batch_size = 64

In [3]:
class Dataset(object):
    """Minimal dataset description rooted at a directory on disk.

    Args:
        name: human-readable dataset name.
        root: dataset root directory (anything ``pathlib.Path`` accepts).
        batch_size: number of examples per batch.
    """

    def __init__(self, name, root, batch_size=1):
        self.name = name
        self.root = Path(root)
        self.batch_size = batch_size

    def get_keys(self, path='images'):
        """List the file stems found directly under ``root / path``.

        Returns:
            A ``tf.string`` constant holding one stem per entry.

        Raises:
            RuntimeError: if the directory has no entries (or does not exist).
        """
        folder = self.root / path
        stems = [entry.stem for entry in folder.glob('*')]
        print('Found {} files.'.format(len(stems)))

        if not stems:
            raise RuntimeError('No images found in {}'.format(folder))
        return tf.constant(stems, tf.string)

class EarWPUTEDB(Dataset):
    """WPUTEDB ear dataset loaded from a pickle file.

    Every dataset entry is indexed as ``entry[0]`` (class label) and
    ``entry[1]`` (image object with ``resize``/``pixels``). Keys are the
    stringified positions ``'0'``, ``'1'``, ... of the entries.

    Args:
        batch_size: number of examples per batch produced by ``get``.
        db_name: pickle file stem; ``'<root>/<db_name>.pkl'`` is loaded.
        root: directory that contains the pickle files.
        num_classes: depth of the one-hot label encoding.
    """

    def __init__(self, batch_size=1, db_name='WPUTEDB-train',
                 root='/homes/yz4009/wd/PickleModel/EarRecognition/',
                 num_classes=500):
        self.name = 'EarWPUTEDB'
        self.batch_size = batch_size
        self.root = Path(root)
        # NOTE(review): this unpickles the file, which can execute arbitrary
        # code -- only point `root` at trusted data.
        self.dataset = mio.import_pickle(str(self.root / '{}.pkl'.format(db_name)))
        self.num_classes = num_classes
        # Spatial shape every image is resized to before batching.
        self.shape = (250, 190)

    def get_keys(self, path='images'):
        """Return a ``tf.string`` constant with one key per dataset entry.

        ``path`` is only used in the error message; keys come from the
        in-memory dataset, not from a directory listing.

        Raises:
            RuntimeError: if the dataset is empty.
        """
        path = self.root / path
        # A real list (not a lazy ``map``) so ``len`` also works on Python 3.
        keys = [str(i) for i in range(len(self.dataset))]
        print('Found {} files.'.format(len(keys)))

        if not keys:
            raise RuntimeError('No images found in {}'.format(path))
        return tf.constant(keys, tf.string)

    def get_images(self, key, shape=None):
        """Load the image for ``key`` as a float32 tensor of shape ``self.shape + (1,)``.

        ``shape`` is accepted for interface compatibility and ignored.
        """
        def wrapper(index):
            image = self.dataset[int(index)][1].resize(self.shape)
            # Reshaping to shape + (1,) only succeeds for single-channel images.
            return image.pixels.reshape(self.shape + (1,)).astype(np.float32)

        image = tf.py_func(wrapper, [key], [tf.float32])[0]
        # py_func output carries no static shape; batching needs one.
        image.set_shape(self.shape + (1,))
        return image

    def get_labels(self, key, shape=None):
        """Return ``(one_hot_label, None)`` for ``key``.

        The label is one-hot encoded over ``self.num_classes`` classes as
        int32; the second element is a placeholder where a mask would go.
        ``shape`` is accepted for interface compatibility and ignored.
        """
        def wrapper(index):
            return self.dataset[int(index)][0].astype(np.int32)

        label = tf.py_func(wrapper, [key], [tf.int32])[0]
        label = tf.one_hot(label, self.num_classes, dtype=tf.int32)
        # Tie the static shape to the one-hot depth instead of a literal 500.
        label.set_shape([self.num_classes])
        return label, None

    def get(self, *names):
        """Build a shuffled batch ``[images, <field tensors>...]``.

        Each name is ``'<field>'`` or ``'<field>/mask'`` and is resolved to
        ``self.get_<field>``; only the first element of the returned pair is
        batched, so a ``/mask`` suffix has no effect here.
        """
        producer = tf.train.string_input_producer(self.get_keys(),
                                                  shuffle=True)
        key = producer.dequeue()
        images = self.get_images(key)

        image_shape = tf.shape(images)
        tensors = [images]

        for name in names:
            getter = getattr(self, 'get_' + name.split('/')[0])
            # The second element (mask) is not batched.
            field, _ = getter(key, shape=image_shape)
            tensors.append(field)

        return tf.train.shuffle_batch(tensors,
                                      self.batch_size,
                                      capacity=2000, min_after_dequeue=200)

In [4]:
from tensorflow.contrib.slim import nets

In [5]:
from tensorflow.python.platform import tf_logging as logging
import tensorflow.contrib.slim as slim

FLAGS = tf.app.flags.FLAGS

# Training hyper-parameters registered as flags, in this order, as
# (definer, flag name, default value, help text).
# NOTE(review): none of these flags are read by the visible evaluation cells,
# and re-running this cell registers the same flags again, which may fail
# depending on the TensorFlow version.
_FLAG_SPECS = (
    (tf.app.flags.DEFINE_float, 'initial_learning_rate', 0.001,
     'Initial learning rate.'),
    (tf.app.flags.DEFINE_float, 'num_epochs_per_decay', 5.0,
     'Epochs after which learning rate decays.'),
    (tf.app.flags.DEFINE_float, 'learning_rate_decay_factor', 0.97,
     'Learning rate decay factor.'),
    (tf.app.flags.DEFINE_integer, 'batch_size', batch_size,
     'The batch size to use.'),
    (tf.app.flags.DEFINE_integer, 'num_preprocess_threads', 4,
     'How many preprocess threads to use.'),
    (tf.app.flags.DEFINE_string, 'train_dir', 'ckpt/ear_train',
     'Directory where to write event logs and checkpoint.'),
    # Disabled: 'pretrained_model_checkpoint_path', default
    # '/vol/atlas/homes/gt108/Projects/ibugface/pretrained_models/resnet_v1_50.ckpt',
    # 'If specified, restore this pretrained model before beginning any training.'
    (tf.app.flags.DEFINE_integer, 'max_steps', 100000,
     'Number of batches to run.'),
    (tf.app.flags.DEFINE_string, 'train_device', '/gpu:3',
     'Device to train with.'),
    (tf.app.flags.DEFINE_string, 'dataset_path', '',
     'Dataset directory'),
)
for _define_flag, _flag_name, _flag_default, _flag_help in _FLAG_SPECS:
    _define_flag(_flag_name, _flag_default, _flag_help)

# Decay used for the moving average of variables.
MOVING_AVERAGE_DECAY = 0.9999



In [6]:
# Start from an empty default graph.
# NOTE(review): the evaluation section resets the graph again before building
# anything, so this reset appears redundant.
tf.reset_default_graph()

In [7]:
def network(inputs, scale=1, output_classes=500):
    """Build a ResNet-50 trunk followed by a linear classification layer.

    Args:
        inputs: image batch passed to ``resnet_v1_50``.
        scale: unused; kept so existing calls keep working.
        output_classes: number of units in the final ``logits`` layer.

    Returns:
        Unnormalised class scores (no activation) from the ``logits`` layer.
    """
    trunk_output, _end_points = nets.resnet_v1.resnet_v1_50(inputs)
    flattened = slim.flatten(trunk_output)
    logits = slim.layers.fully_connected(
        flattened, output_classes, activation_fn=None, scope='logits')
    return logits

### Evaluation

In [8]:
# Clear the default graph so the evaluation cells below build the model and
# input pipeline from scratch.
tf.reset_default_graph()

In [9]:
test_provider = EarWPUTEDB(batch_size=batch_size, db_name='WPUTEDB-test')

# Shuffled test batches: images plus one-hot labels.
images, labels_onehot = test_provider.get('labels')

# Raw class scores keep their own name; the classifier width follows the
# provider's class count instead of relying on network()'s default.
logits = network(images, output_classes=test_provider.num_classes)

# Predicted and ground-truth class indices (int32) consumed by the metrics.
predictions = tf.to_int32(tf.argmax(logits, 1))
labels = tf.to_int32(tf.argmax(labels_onehot, 1))

# Per-batch (non-streaming) accuracy summary; assigned so the cell does not
# echo the summary tensor's repr.
accuracy_summary = tf.scalar_summary(
    'accuracy', slim.metrics.accuracy(predictions, labels))

Found 859 files.


<tf.Tensor 'ScalarSummary:0' shape=() dtype=string>

In [10]:
# Evaluation batches per pass: enough batches of `batch_size` to draw at least
# as many examples as the test set holds (ceiling division), derived from the
# loaded data rather than hard-coded sizes.
num_batches = (len(test_provider.dataset) + batch_size - 1) // batch_size

In [None]:
# NOTE(review): `sess` is not referenced by any visible cell -- the evaluation
# loop below is given a master address ('') rather than this session -- so it
# appears to only hold device resources. Confirm nothing else uses it before
# removing.
sess = tf.Session()

In [None]:
# These are streaming metrics which compute the "running" metric,
# e.g running accuracy
# Streaming metrics accumulate across all batches of one evaluation pass,
# e.g. the running accuracy.
metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map({
    "streaming_accuracy": slim.metrics.streaming_accuracy(predictions, labels),
})

# One scalar summary per streaming metric. `.items()` works on Python 2 and 3;
# `.iteritems()` is Python 2 only.
for metric_name, metric_value in metrics_to_values.items():
    tf.scalar_summary(metric_name, metric_value)

# Every 30 seconds, evaluate the latest checkpoint in 'ckpt/ear_train' and
# write summaries to 'ckpt/ear_eval'. `num_evals` is a number of batches, so
# pass `num_batches` (previously the batch size was passed here by mistake).
slim.evaluation.evaluation_loop(
    '',
    'ckpt/ear_train',
    'ckpt/ear_eval',
    num_evals=num_batches,
    # Explicit list: a Python 3 dict view is not a valid session fetch.
    eval_op=list(metrics_to_updates.values()),
    summary_op=tf.merge_all_summaries(),
    eval_interval_secs=30)