Permalink
Browse files

clean up

1 parent 6b42dfa commit b994ea996c4946d139010f6ed409757c851edbda @ry committed May 11, 2016
Showing with 133 additions and 341 deletions.
  1. +16 −8 README.md
  2. +4 −37 image_processing.py
  3. +0 −102 resnet.py
  4. +108 −0 resnet_train.py
  5. +0 −180 test_error_rate.py
  6. +2 −2 train_cifar.py
  7. +3 −12 train_imagenet.py
View
@@ -1,6 +1,10 @@
-# ResNet Model in TensorFlow
+# ResNet in TensorFlow
-This is the second version of my ResNet implementation.
+Implementation of [Deep Residual Learning for Image
+Recognition](http://arxiv.org/abs/1512.03385). Includes a tool to use He et
+al's published trained Caffe weights in TensorFlow.
+
+MIT license. Contributions welcome.
## Goals
@@ -13,16 +17,20 @@ This is the second version of my ResNet implementation.
not using any classes and making heavy use of variable scope. It should be
easily usable in other models.
-* Expierment with changes to ResNet like [stochastic
+* Foundation to experiment with changes to ResNet like [stochastic
depth](https://arxiv.org/abs/1603.09382), [shared weights at each
- scale](https://arxiv.org/abs/1604.03640), and 1D convolutions for audio.
+ scale](https://arxiv.org/abs/1604.03640), and 1D convolutions for audio. (Not yet implemented.)
+
+* ResNet is fully convolutional and the implementation should allow inputs to be any size.
+
+* Be able to train out of the box on CIFAR-10, 100, and ImageNet.
-* ResNet is fully convolutional and the implementation should allow inputs to
- be any size.
## Pretrained Model
-Instead of running `convert.py`, which depends on Caffe, you can download the converted model thru BitTorrent:
+To convert the published Caffe pretrained model, run `convert.py`. However,
+Caffe is annoying to install, so I'm providing a download of the output of
+`convert.py`:
[tensorflow-resnet-pretrained-20160509.tar.gz.torrent](https://raw.githubusercontent.com/ry/tensorflow-resnet/master/data/tensorflow-resnet-pretrained-20160509.tar.gz.torrent) 464M
@@ -39,4 +47,4 @@ Instead of running `convert.py`, which depends on Caffe, you can download the co
TF and Caffe handle padding. Also preprocessing is done with color-channel means
instead of pixel-wise means.
-* ResNet is full convolutional. You can resize the network input down to 65x65 images.
+
View
@@ -103,41 +103,6 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
return images, labels
-def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
- """Generate batches of distorted versions of ImageNet images.
-
- Use this function as the inputs for training a network.
-
- Distorting images provides a useful technique for augmenting the data
- set during training in order to make the network invariant to aspects
- of the image that do not effect the label.
-
- Args:
- dataset: instance of Dataset class specifying the dataset.
- batch_size: integer, number of examples in batch
- num_preprocess_threads: integer, total number of preprocessing threads but
- None defaults to FLAGS.num_preprocess_threads.
-
- Returns:
- images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
- FLAGS.image_size, 3].
- labels: 1-D integer Tensor of [batch_size].
- """
- if not batch_size:
- batch_size = FLAGS.batch_size
-
- # Force all input processing onto CPU in order to reserve the GPU for
- # the forward inference and back-propagation.
- with tf.device('/cpu:0'):
- images, labels = batch_inputs(
- dataset,
- batch_size,
- train=True,
- num_preprocess_threads=num_preprocess_threads,
- num_readers=FLAGS.num_readers)
- return images, labels
-
-
def decode_jpeg(image_buffer, scope=None):
"""Decode a JPEG string into one 3-D float image Tensor.
@@ -216,8 +181,10 @@ def distort_image(image, height, width, bbox, thread_id=0, scope=None):
"""
with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):
- # Crop the image to the specified bounding box.
- distorted_image = image #tf.slice(image, bbox_begin, bbox_size)
+ # NOTE(ry) I unceremoniously removed all the bounding box code.
+ # Original here: https://github.com/tensorflow/models/blob/148a15fb043dacdd1595eb4c5267705fbd362c6a/inception/inception/image_processing.py
+
+ distorted_image = image
# This resizing operation may distort the images because the aspect
# ratio is not respected. We select a resize method in a round robin
View
@@ -16,118 +16,16 @@
CONV_WEIGHT_STDDEV = 0.1
FC_WEIGHT_DECAY = 0.00004
FC_WEIGHT_STDDEV = 0.01
-MOMENTUM = 0.9
RESNET_VARIABLES = 'resnet_variables'
UPDATE_OPS_COLLECTION = 'resnet_update_ops' # must be grouped with training op
IMAGENET_MEAN_BGR = [103.062623801, 115.902882574, 123.151630838, ]
FLAGS = tf.app.flags.FLAGS
-tf.app.flags.DEFINE_string('train_dir', '/tmp/resnet_train',
- """Directory where to write event logs """
- """and checkpoint.""")
-tf.app.flags.DEFINE_float('learning_rate', 0.1, "learning rate.")
-tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
tf.app.flags.DEFINE_integer('input_size', 224, "input image size")
tf.app.flags.DEFINE_boolean('resume', False,
'resume from latest saved state')
-def train(images, labels, small=False):
- global_step = tf.get_variable('global_step', [],
- initializer=tf.constant_initializer(0),
- trainable=False)
-
- if small:
- logits = inference_small(images,
- num_classes=10,
- is_training=True,
- num_blocks=3)
- else:
- logits = inference(images,
- num_classes=1000,
- is_training=True,
- preprocess=True,
- bottleneck=False,
- num_blocks=[2, 2, 2, 2])
-
- loss_ = loss(logits, labels)
-
- # loss_avg
- ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
- tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
- loss_avg = ema.average(loss_)
- tf.scalar_summary('loss_avg', loss_avg)
-
- tf.scalar_summary('learning_rate', FLAGS.learning_rate)
-
- opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
- grads = opt.compute_gradients(loss_)
- for grad, var in grads:
- if grad is not None:
- tf.histogram_summary(var.op.name + '/gradients', grad)
- apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
-
- for var in tf.trainable_variables():
- tf.histogram_summary(var.op.name, var)
-
- batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
- batchnorm_updates_op = tf.group(*batchnorm_updates)
- train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
-
- saver = tf.train.Saver(tf.all_variables())
-
- summary_op = tf.merge_all_summaries()
-
- init = tf.initialize_all_variables()
-
- sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
- sess.run(init)
- tf.train.start_queue_runners(sess=sess)
-
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
-
- if FLAGS.resume:
- latest = tf.train.latest_checkpoint(FLAGS.train_dir)
- if not latest:
- print "No checkpoint to continue from in", FLAGS.train_dir
- sys.exit(1)
- print "resume", latest
- saver.restore(sess, latest)
-
- while True:
- start_time = time.time()
-
- #images_, labels_ = dataset.get_batch(FLAGS.batch_size, FLAGS.input_size)
-
- step = sess.run(global_step)
- i = [train_op, loss_]
-
- write_summary = step % 100 and step > 1
- if write_summary:
- i.append(summary_op)
-
- o = sess.run(i)
-
- loss_value = o[1]
-
- duration = time.time() - start_time
-
- assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
-
- if step % 5 == 0:
- examples_per_sec = FLAGS.batch_size / float(duration)
- format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
- 'sec/batch)')
- print(format_str % (step, loss_value, examples_per_sec, duration))
-
- if write_summary:
- summary_str = o[2]
- summary_writer.add_summary(summary_str, step)
-
- # Save the model checkpoint periodically.
- if step > 1 and step % 100 == 0:
- checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
- saver.save(sess, checkpoint_path, global_step=global_step)
def inference(x, is_training,
View
@@ -0,0 +1,108 @@
+from resnet import *
+import tensorflow as tf
+
# Momentum coefficient handed to tf.train.MomentumOptimizer in train().
MOMENTUM = 0.9

# Command-line flags used by the training loop. FLAGS is the shared
# tf.app.flags registry object, so flags registered by other modules are
# reachable through it as well.
FLAGS = tf.app.flags.FLAGS

# Directory used for summary event files and model checkpoints.
tf.app.flags.DEFINE_string(
    'train_dir',
    '/tmp/resnet_train',
    """Directory where to write event logs """
    """and checkpoint.""")

# Constant learning rate passed to the optimizer.
tf.app.flags.DEFINE_float(
    'learning_rate', 0.1, "learning rate.")

# Batch size; train() uses it only to report examples/sec.
tf.app.flags.DEFINE_integer(
    'batch_size', 16, "batch size")
+
+
def train(images, labels, small=False):
    """Build the ResNet training graph and run the training loop.

    Args:
        images: image batch tensor fed to the inference function.
        labels: label tensor passed to ``loss`` together with the logits.
        small: if True, build the CIFAR network with ``inference_small``
            (10 classes, ``num_blocks=3``); otherwise build the ImageNet
            network with ``inference`` (1000 classes, ``bottleneck=False``,
            ``num_blocks=[2, 2, 2, 2]``).

    This function does not return normally: it loops until the process is
    stopped. It exits with status 1 when ``FLAGS.resume`` is set but no
    checkpoint exists in ``FLAGS.train_dir``, and raises ``AssertionError``
    if the loss becomes NaN.
    """
    # Imported here so the function does not rely on `from resnet import *`
    # happening to re-export these modules.
    import os
    import sys
    import time

    import numpy as np

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    # Fed as True on every training step (see the sess.run call below).
    is_training = tf.placeholder('bool', [], name="is_training")

    if small:
        # CIFAR
        logits = inference_small(images,
                                 num_classes=10,
                                 is_training=is_training,
                                 num_blocks=3)
    else:
        # ImageNet
        logits = inference(images,
                           num_classes=1000,
                           is_training=is_training,
                           bottleneck=False,
                           num_blocks=[2, 2, 2, 2])

    loss_ = loss(logits, labels)

    # Track an exponential moving average of the loss for a smoother summary.
    # The EMA update op is queued in UPDATE_OPS_COLLECTION so it runs with
    # every training step.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    loss_avg = ema.average(loss_)
    tf.scalar_summary('loss_avg', loss_avg)

    tf.scalar_summary('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
    grads = opt.compute_gradients(loss_)
    for grad, var in grads:
        if grad is not None:
            tf.histogram_summary(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    for var in tf.trainable_variables():
        tf.histogram_summary(var.op.name, var)

    # Everything registered in UPDATE_OPS_COLLECTION must run together with
    # the gradient step, so group it all into a single train op.
    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.all_variables())

    summary_op = tf.merge_all_summaries()

    init = tf.initialize_all_variables()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    if FLAGS.resume:
        # Restoring after `init` overwrites the freshly initialised values.
        latest = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not latest:
            print("No checkpoint to continue from in %s" % FLAGS.train_dir)
            sys.exit(1)
        print("resume %s" % latest)
        saver.restore(sess, latest)

    while True:
        start_time = time.time()

        step = sess.run(global_step)
        fetches = [train_op, loss_]

        # Fetch summaries once every 100 steps, matching the checkpoint
        # cadence below. The previous expression `step % 100 and step > 1`
        # was truthy on every step that is NOT a multiple of 100.
        write_summary = step > 1 and step % 100 == 0
        if write_summary:
            fetches.append(summary_op)

        results = sess.run(fetches, {is_training: True})

        loss_value = results[1]

        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 5 == 0:
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (step, loss_value, examples_per_sec, duration))

        if write_summary:
            # summary_op was appended last, so its result sits at index 2.
            summary_str = results[2]
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step > 1 and step % 100 == 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)
Oops, something went wrong.

0 comments on commit b994ea9

Please sign in to comment.