This repository has been archived by the owner on Oct 19, 2019. It is now read-only.

Commit 6b42dfa: make it work for cifar
ry committed May 11, 2016
1 parent 3144115
Showing 3 changed files with 28 additions and 27 deletions.
resnet.py: 40 changes (17 additions & 23 deletions)
@@ -28,21 +28,27 @@
 tf.app.flags.DEFINE_float('learning_rate', 0.1, "learning rate.")
 tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
 tf.app.flags.DEFINE_integer('input_size', 224, "input image size")
-tf.app.flags.DEFINE_boolean('continue', False,
+tf.app.flags.DEFINE_boolean('resume', False,
                             'resume from latest saved state')


-def train(images, labels):
+def train(images, labels, small=False):
     global_step = tf.get_variable('global_step', [],
                                   initializer=tf.constant_initializer(0),
                                   trainable=False)

-    logits = inference(images,
-                       num_classes=1000,
-                       is_training=True,
-                       preprocess=True,
-                       bottleneck=False,
-                       num_blocks=[2, 2, 2, 2])
+    if small:
+        logits = inference_small(images,
+                                 num_classes=10,
+                                 is_training=True,
+                                 num_blocks=3)
+    else:
+        logits = inference(images,
+                           num_classes=1000,
+                           is_training=True,
+                           preprocess=True,
+                           bottleneck=False,
+                           num_blocks=[2, 2, 2, 2])

     loss_ = loss(logits, labels)

@@ -80,12 +86,12 @@ def train(images, labels):

     summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

-    if FLAGS.__getattr__('continue'):
+    if FLAGS.resume:
         latest = tf.train.latest_checkpoint(FLAGS.train_dir)
         if not latest:
             print "No checkpoint to continue from in", FLAGS.train_dir
             sys.exit(1)
-        print "continue", latest
+        print "resume", latest
         saver.restore(sess, latest)

     while True:
@@ -127,13 +133,7 @@ def train(images, labels):
 def inference(x, is_training,
               num_classes=1000,
               num_blocks=[3, 4, 6, 3],  # defaults to 50-layer network
               preprocess=True,
               bottleneck=True):
-    # if preprocess is True, input should be RGB [0,1], otherwise BGR with mean
-    # subtracted
-    if preprocess:
-        x = _imagenet_preprocess(x)
-
     is_training = tf.convert_to_tensor(is_training,
                                        dtype='bool',
                                        name='is_training')
@@ -168,14 +168,8 @@ def inference(x, is_training,
 # See Section 4.2 in http://arxiv.org/abs/1512.03385
 def inference_small(x,
                     is_training,
-                    num_classes=10,
                     num_blocks=3,  # 6n+2 total weight layers will be used.
-                    preprocess=True):
-    # if preprocess is True, input should be RGB [0,1], otherwise BGR with mean
-    # subtracted
-    if preprocess:
-        x = _imagenet_preprocess(x)
-
+                    num_classes=10):
     bottleneck = False
     is_training = tf.convert_to_tensor(is_training,
                                        dtype='bool',
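Note on the flag rename in the first hunk: 'continue' is a reserved keyword in Python, so the flag value could never be read with ordinary attribute access, which is why the old code fell back to FLAGS.__getattr__('continue'). Renaming it to 'resume' allows the plain FLAGS.resume used above. A minimal, self-contained illustration follows; the Flags class is a hypothetical stand-in for the tf.app.flags FLAGS object, not code from this repository.

# Hypothetical stand-in for the tf.app.flags FLAGS object, to show why a
# boolean flag named 'continue' is awkward to use from Python code.
class Flags(object):
    def __init__(self):
        setattr(self, 'continue', False)  # "self.continue = ..." would not even parse
        self.resume = False

FLAGS = Flags()

# print FLAGS.continue            -> SyntaxError, because 'continue' is a keyword
print(getattr(FLAGS, 'continue'))  # the workaround the old code needed
print(FLAGS.resume)                # the plain access the rename makes possible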
train_cifar.py: 10 changes (6 additions & 4 deletions)
@@ -289,12 +289,14 @@ def _progress(count, block_size, total_size):

 def main(argv=None):  # pylint: disable=unused-argument
   maybe_download_and_extract()
-  if tf.gfile.Exists(FLAGS.train_dir):
-    tf.gfile.DeleteRecursively(FLAGS.train_dir)
-  tf.gfile.MakeDirs(FLAGS.train_dir)
+
+  if not FLAGS.resume:
+    if tf.gfile.Exists(FLAGS.train_dir):
+      tf.gfile.DeleteRecursively(FLAGS.train_dir)
+    tf.gfile.MakeDirs(FLAGS.train_dir)

   images, labels = distorted_inputs(FLAGS.data_dir, FLAGS.batch_size)
-  resnet.train(images, labels)
+  resnet.train(images, labels, small=True)


 if __name__ == '__main__':
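With small=True, train() now builds the CIFAR variant through inference_small, whose num_blocks=3 default corresponds to the 6n+2 weight-layer count mentioned in its comment (Section 4.2 of http://arxiv.org/abs/1512.03385), i.e. a 20-layer network. A small standalone sketch of that arithmetic; the helper name cifar_resnet_depth is made up for illustration:

def cifar_resnet_depth(num_blocks):
    # Section 4.2 of the ResNet paper: the CIFAR architecture uses 3 stages of
    # feature maps, each stage stacking num_blocks residual blocks with 2 conv
    # layers apiece (6n layers), plus the initial conv and the final FC layer.
    return 6 * num_blocks + 2

assert cifar_resnet_depth(3) == 20   # the default used by train(..., small=True)
assert cifar_resnet_depth(9) == 56   # n=9 would give the ResNet-56 variant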
train_imagenet.py: 5 changes (5 additions & 0 deletions)
@@ -124,6 +124,11 @@ def distorted_inputs():


 def main(_):
+  if not FLAGS.resume:
+    if tf.gfile.Exists(FLAGS.train_dir):
+      tf.gfile.DeleteRecursively(FLAGS.train_dir)
+    tf.gfile.MakeDirs(FLAGS.train_dir)
+
   dataset = DataSet(FLAGS.data_dir)
   images, labels = distorted_inputs()
   resnet.train(images, labels)
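The guard added to both training scripts is what makes the new resume flag usable: without it, every run deletes FLAGS.train_dir before resnet.train() gets a chance to call tf.train.latest_checkpoint(), so there would never be a checkpoint left to restore. A minimal sketch of the same pattern, using os and shutil in place of tf.gfile so it runs on its own; the function name and path are hypothetical:

import os
import shutil

def prepare_train_dir(train_dir, resume):
    # Mirror of the guard above: only a fresh (non-resumed) run clears the
    # training directory; a resumed run keeps its checkpoints so that
    # tf.train.latest_checkpoint() can still find something to restore.
    if not resume:
        if os.path.exists(train_dir):
            shutil.rmtree(train_dir)   # stands in for tf.gfile.DeleteRecursively
        os.makedirs(train_dir)         # stands in for tf.gfile.MakeDirs

prepare_train_dir('/tmp/resnet_train', resume=False)   # hypothetical example run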
