Permalink
Browse files

add max_steps flag

1 parent a2f1af9 commit 76d0b5cc34f586a0d510f6f079add9fcfc5cce49 @ry committed May 29, 2016
Showing with 4 additions and 2 deletions.
  1. +2 −1 resnet_train.py
  2. +2 −1 train_cifar.py
View
@@ -9,6 +9,7 @@
"""and checkpoint.""")
tf.app.flags.DEFINE_float('learning_rate', 0.01, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
+tf.app.flags.DEFINE_integer('max_steps', 500000, "max steps")
tf.app.flags.DEFINE_boolean('resume', False,
'resume from latest saved state')
tf.app.flags.DEFINE_boolean('minimal_summaries', True,
@@ -86,7 +87,7 @@ def train(is_training, logits, images, labels):
print "resume", latest
saver.restore(sess, latest)
- while True:
+ for x in xrange(FLAGS.max_steps + 1):
start_time = time.time()
step = sess.run(global_step)
View
@@ -34,6 +34,7 @@
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar-data',
'where to store the dataset')
+tf.app.flags.DEFINE_boolean('use_bn', True, 'use batch normalization. otherwise use biases')
# Process images of this size. Note that this differs from the original CIFAR
# image size of 32 x 32. If one alters this number, then the entire model
@@ -301,7 +302,7 @@ def main(argv=None): # pylint: disable=unused-argument
logits = inference_small(images,
num_classes=10,
is_training=is_training,
- use_bias=False,
+ use_bias=(not FLAGS.use_bn),
num_blocks=3)
train(is_training, logits, images, labels)

0 comments on commit 76d0b5c

Please sign in to comment.