
Run validation periodically during training

1 parent aa97493 commit a2f1af991466e64c00b5d753c51370b39ddc32b8 @ry committed May 29, 2016
Showing with 64 additions and 31 deletions.
  1. +39 −21 resnet_train.py
  2. +18 −6 train_cifar.py
  3. +7 −4 train_imagenet.py
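The gist of the change: resnet_train.py now drives training and validation through a single graph, switched by an is_training placeholder, and fires one validation batch every 100 steps from inside the training loop. A minimal sketch of that loop shape, using trivial stand-in ops (train_op, val_op, and top1_error here are hypothetical placeholders for the tensors the diff below actually builds):

import tensorflow as tf

# Stand-ins so the loop runs on its own; in resnet_train.py these are
# the momentum update, the validation-EMA update, and the top-1 error.
is_training = tf.placeholder('bool', [], name='is_training')
step_counter = tf.Variable(0, trainable=False)
train_op = step_counter.assign_add(1)   # hypothetical
val_step = tf.Variable(0, trainable=False)
val_op = val_step.assign_add(1)         # hypothetical
top1_error = tf.constant(0.42)          # hypothetical

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  for step in range(1, 501):
    sess.run(train_op, {is_training: True})
    # Every 100 steps, reuse the same graph for one validation batch.
    if step > 1 and step % 100 == 0:
      _, err = sess.run([val_op, top1_error], {is_training: False})
      print('Validation top1 error %.2f' % err)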
resnet_train.py
@@ -7,50 +7,60 @@
tf.app.flags.DEFINE_string('train_dir', '/tmp/resnet_train',
"""Directory where to write event logs """
"""and checkpoint.""")
-tf.app.flags.DEFINE_float('learning_rate', 0.1, "learning rate.")
+tf.app.flags.DEFINE_float('learning_rate', 0.01, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
tf.app.flags.DEFINE_boolean('resume', False,
'resume from latest saved state')
+tf.app.flags.DEFINE_boolean('minimal_summaries', True,
+ 'produce fewer summaries to save HD space')
-def train(images, labels, small=False):
+def top_k_error(predictions, labels, k):
+ batch_size = float(FLAGS.batch_size) # alternatively: tf.shape(predictions)[0]
+ in_top_k = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k))
+ num_correct = tf.reduce_sum(in_top_k)
+ return (batch_size - num_correct) / batch_size
+
+def train(is_training, logits, images, labels):
global_step = tf.get_variable('global_step', [],
initializer=tf.constant_initializer(0),
trainable=False)
- is_training = tf.placeholder('bool', [], name="is_training")
- if small:
- # CIFAR
- logits = inference_small(images,
- num_classes=10,
- is_training=is_training,
- num_blocks=3)
- else:
- # ImageNet
- logits = inference(images,
- num_classes=1000,
- is_training=is_training,
- bottleneck=False,
- num_blocks=[2, 2, 2, 2])
+ val_step = tf.get_variable('val_step', [],
+ initializer=tf.constant_initializer(0),
+ trainable=False)
loss_ = loss(logits, labels)
+ predictions = tf.nn.softmax(logits)
+
+ top1_error = top_k_error(predictions, labels, 1)
+
# loss_avg
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
- loss_avg = ema.average(loss_)
- tf.scalar_summary('loss_avg', loss_avg)
+ tf.scalar_summary('loss_avg', ema.average(loss_))
+
+ # validation stats
+ ema = tf.train.ExponentialMovingAverage(0.9, val_step)
+ val_op = tf.group(val_step.assign_add(1), ema.apply([top1_error]))
+ top1_error_avg = ema.average(top1_error)
+ tf.scalar_summary('val_top1_error_avg', top1_error_avg)
tf.scalar_summary('learning_rate', FLAGS.learning_rate)
opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
grads = opt.compute_gradients(loss_)
for grad, var in grads:
- if grad is not None:
+ if grad is not None and not FLAGS.minimal_summaries:
tf.histogram_summary(var.op.name + '/gradients', grad)
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
- for var in tf.trainable_variables():
- tf.histogram_summary(var.op.name, var)
+ if not FLAGS.minimal_summaries:
+ # Display the training images in the visualizer.
+ tf.image_summary('images', images)
+
+ for var in tf.trainable_variables():
+ tf.histogram_summary(var.op.name, var)
batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
batchnorm_updates_op = tf.group(*batchnorm_updates)
@@ -108,3 +118,11 @@ def train(images, labels, small=False):
if step > 1 and step % 100 == 0:
checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=global_step)
+
+ # Run validation periodically
+ if step > 1 and step % 100 == 0:
+ _, top1_error_value = sess.run([val_op, top1_error], { is_training: False })
+ print('Validation top1 error %.2f' % top1_error_value)
+
+
+
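Two pieces of the new code above are worth a gloss: top_k_error turns tf.nn.in_top_k's per-example hits into a batch error rate, and a second ExponentialMovingAverage (decay 0.9, stepped by val_step) smooths that noisy per-batch number into the val_top1_error_avg summary. A self-contained sketch of the error computation with a tiny worked batch (all values illustrative):

import tensorflow as tf

def top_k_error(predictions, labels, k, batch_size):
  # in_top_k yields one bool per example: True if the true label is
  # among the k highest-scoring classes for that example.
  hits = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k))
  num_correct = tf.reduce_sum(hits)
  return (float(batch_size) - num_correct) / float(batch_size)

# Three examples, four classes; the second prediction misses its label.
predictions = tf.constant([[0.1, 0.7, 0.1, 0.1],
                           [0.6, 0.2, 0.1, 0.1],
                           [0.2, 0.2, 0.5, 0.1]])
labels = tf.constant([1, 3, 2])
err = top_k_error(predictions, labels, k=1, batch_size=3)

with tf.Session() as sess:
  print(sess.run(err))  # 1 wrong out of 3 -> 0.333...

Each time val_op runs, the average then moves roughly as avg = 0.9 * avg + 0.1 * err (TensorFlow lowers the effective decay while val_step is still small), which is why val_top1_error_avg plots smoothly while the raw per-batch error jumps around.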
train_cifar.py
@@ -26,6 +26,7 @@
from six.moves import urllib
from resnet_train import train
+from resnet import inference_small
import tensorflow as tf
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
@@ -141,9 +142,6 @@ def _generate_image_and_label_batch(image, label, min_queue_examples,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size)
- # Display the training images in the visualizer.
- tf.image_summary('images', images)
-
return images, tf.reshape(label_batch, [batch_size])
@@ -223,11 +221,12 @@ def inputs(eval_data, data_dir, batch_size):
labels: Labels. 1D tensor of [batch_size] size.
"""
if not eval_data:
+ assert False, "hack. shouldn't go here"
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in xrange(1, 6)]
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
else:
- filenames = [os.path.join(data_dir, 'test_batch.bin')]
+ filenames = [os.path.join(data_dir, 'cifar-10-batches-bin', 'test_batch.bin')]
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
for f in filenames:
@@ -290,8 +289,21 @@ def _progress(count, block_size, total_size):
def main(argv=None): # pylint: disable=unused-argument
maybe_download_and_extract()
- images, labels = distorted_inputs(FLAGS.data_dir, FLAGS.batch_size)
- train(images, labels, small=True)
+ images_train, labels_train = distorted_inputs(FLAGS.data_dir, FLAGS.batch_size)
+ images_val, labels_val = inputs(True, FLAGS.data_dir, FLAGS.batch_size)
+
+ is_training = tf.placeholder('bool', [], name='is_training')
+
+ images, labels = tf.cond(is_training,
+ lambda: (images_train, labels_train),
+ lambda: (images_val, labels_val))
+
+ logits = inference_small(images,
+ num_classes=10,
+ is_training=is_training,
+ use_bias=False,
+ num_blocks=3)
+ train(is_training, logits, images, labels)
if __name__ == '__main__':
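The tf.cond switch above is the heart of the cifar change: one graph serves both phases, routing either the distorted training batch or the clean test batch depending on the is_training placeholder. A minimal demonstration of the switch, with constants standing in for the two queue-fed pipelines (tensor values here are illustrative):

import tensorflow as tf

is_training = tf.placeholder('bool', [], name='is_training')

# Constants stand in for distorted_inputs() and inputs().
images_train = tf.zeros([4, 32, 32, 3])
labels_train = tf.constant([0, 1, 2, 3])
images_val = tf.ones([4, 32, 32, 3])
labels_val = tf.constant([4, 5, 6, 7])

images, labels = tf.cond(is_training,
                         lambda: (images_train, labels_train),
                         lambda: (images_val, labels_val))

with tf.Session() as sess:
  print(sess.run(labels, {is_training: True}))   # [0 1 2 3]
  print(sess.run(labels, {is_training: False}))  # [4 5 6 7]

One cost worth noting: because both batch tensors are built outside the cond, every session run evaluates both input pipelines and simply discards the untaken batch.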
train_imagenet.py
@@ -87,16 +87,19 @@ def distorted_inputs():
images = tf.cast(images, tf.float32)
images = tf.reshape(images, shape=[FLAGS.batch_size, height, width, depth])
- # Display the training images in the visualizer.
- tf.image_summary('images', images)
-
return images, tf.reshape(label_index_batch, [FLAGS.batch_size])

def main(_):
images, labels = distorted_inputs()
- train(images, labels)
+ is_training = tf.placeholder('bool', [], name='is_training')
+ logits = inference(images,
+ num_classes=1000,
+ is_training=is_training,
+ bottleneck=False,
+ num_blocks=[2, 2, 2, 2])
+ train(is_training, logits, images, labels)
if __name__ == '__main__':
