This repository has been archived by the owner on Oct 19, 2019. It is now read-only.

Commit b994ea9: clean up

ry committed May 11, 2016 (1 parent 6b42dfa)
Showing 7 changed files with 133 additions and 341 deletions.
24 changes: 16 additions & 8 deletions README.md
@@ -1,6 +1,10 @@
# ResNet Model in TensorFlow
# ResNet in TensorFlow

This is the second version of my ResNet implementation.
Implementation of [Deep Residual Learning for Image
Recognition](http://arxiv.org/abs/1512.03385). Includes a tool to use He et
al's published trained Caffe weights in TensorFlow.
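
For reference, the residual block at the heart of the paper learns a residual
function on top of an identity shortcut:

    y = F(x, {W_i}) + x

where F is a small stack of convolution / batch norm / ReLU layers and x is
passed through unchanged (or through a 1x1 projection when the dimensions
change).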

MIT license. Contributions welcome.

## Goals

@@ -13,16 +17,20 @@ This is the second version of my ResNet implementation.
not using any classes and making heavy use of variable scope. It should be
easily usable in other models; a short sketch follows this list.

* Expierment with changes to ResNet like [stochastic
* Foundation to experiment with changes to ResNet like [stochastic
depth](https://arxiv.org/abs/1603.09382), [shared weights at each
scale](https://arxiv.org/abs/1604.03640), and 1D convolutions for audio.
scale](https://arxiv.org/abs/1604.03640), and 1D convolutions for audio. (Not yet implemented.)

* ResNet is fully convolutional and the implementation should allow inputs to be any size.

* Be able to train out of the box on CIFAR-10, 100, and ImageNet.

* ResNet is fully convolutional and the implementation should allow inputs to
be any size.
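
A minimal sketch of the reuse goal above, assuming `inference` keeps the
keyword arguments used in `resnet_train.py` and that the classifier head
tolerates the smaller spatial size; the 65x65 input is only an illustration of
the fully convolutional point:

```python
import tensorflow as tf
from resnet import inference

# Any spatial size should work because the network is fully convolutional;
# 65x65 is just an example, not a requirement.
images = tf.placeholder(tf.float32, [None, 65, 65, 3])
is_training = tf.placeholder('bool', [], name='is_training')

# No classes involved: a single function call builds the graph under variable
# scopes, so it can be dropped into a larger model.
logits = inference(images,
                   num_classes=1000,
                   is_training=is_training,
                   bottleneck=False,
                   num_blocks=[2, 2, 2, 2])
```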

## Pretrained Model

Instead of running `convert.py`, which depends on Caffe, you can download the converted model thru BitTorrent:
To convert the published Caffe pretrained model, run `convert.py`. However,
Caffe is annoying to install, so I'm providing a download of the output of
`convert.py`:

[tensorflow-resnet-pretrained-20160509.tar.gz.torrent](https://raw.githubusercontent.com/ry/tensorflow-resnet/master/data/tensorflow-resnet-pretrained-20160509.tar.gz.torrent) 464M
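
Once the archive is unpacked (or once `convert.py` has produced a checkpoint),
restoring it should look roughly like the sketch below. The ResNet-50 block
layout `[3, 4, 6, 3]` and the checkpoint file name are assumptions, not taken
from this repository, and the variable names built by `inference` are assumed
to match the converted checkpoint:

```python
import tensorflow as tf
from resnet import inference

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
is_training = tf.placeholder('bool', [], name='is_training')

# Assumed ResNet-50 configuration: bottleneck blocks, [3, 4, 6, 3] units per scale.
logits = inference(images,
                   num_classes=1000,
                   is_training=is_training,
                   bottleneck=True,
                   num_blocks=[3, 4, 6, 3])

saver = tf.train.Saver()
with tf.Session() as sess:
    # Hypothetical file name inside the downloaded archive.
    saver.restore(sess, 'tensorflow-resnet-pretrained/ResNet-L50.ckpt')
    # scores = sess.run(logits, {images: batch, is_training: False})
```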

@@ -39,4 +47,4 @@ Instead of running `convert.py`, which depends on Caffe, you can download the co
TF and Caffe handle padding. Also preprocessing is done with color-channel means
instead of pixel-wise means.
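
That channel-mean preprocessing amounts to subtracting one mean value per BGR
channel rather than a full mean image. A sketch, reusing the
`IMAGENET_MEAN_BGR` constant from `resnet.py` and assuming a float input in
BGR order with values in 0-255 (the helper name is made up for illustration):

```python
import tensorflow as tf

IMAGENET_MEAN_BGR = [103.062623801, 115.902882574, 123.151630838]

def subtract_channel_means(images):
    # images: float32 tensor [batch, height, width, 3] in BGR order, 0-255.
    # Broadcasting subtracts a single mean per channel, not per pixel.
    return images - tf.constant(IMAGENET_MEAN_BGR, dtype=tf.float32)
```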

* ResNet is full convolutional. You can resize the network input down to 65x65 images.

41 changes: 4 additions & 37 deletions image_processing.py
@@ -103,41 +103,6 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
return images, labels


def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
  """Generate batches of distorted versions of ImageNet images.
  Use this function as the inputs for training a network.
  Distorting images provides a useful technique for augmenting the data
  set during training in order to make the network invariant to aspects
  of the image that do not affect the label.
  Args:
    dataset: instance of Dataset class specifying the dataset.
    batch_size: integer, number of examples in batch
    num_preprocess_threads: integer, total number of preprocessing threads but
      None defaults to FLAGS.num_preprocess_threads.
  Returns:
    images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
      FLAGS.image_size, 3].
    labels: 1-D integer Tensor of [batch_size].
  """
  if not batch_size:
    batch_size = FLAGS.batch_size

  # Force all input processing onto CPU in order to reserve the GPU for
  # the forward inference and back-propagation.
  with tf.device('/cpu:0'):
    images, labels = batch_inputs(
        dataset,
        batch_size,
        train=True,
        num_preprocess_threads=num_preprocess_threads,
        num_readers=FLAGS.num_readers)
  return images, labels


def decode_jpeg(image_buffer, scope=None):
  """Decode a JPEG string into one 3-D float image Tensor.
Expand Down Expand Up @@ -216,8 +181,10 @@ def distort_image(image, height, width, bbox, thread_id=0, scope=None):
"""
with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):

# Crop the image to the specified bounding box.
distorted_image = image #tf.slice(image, bbox_begin, bbox_size)
# NOTE(ry) I unceremoniously removed all the bounding box code.
# Original here: https://github.com/tensorflow/models/blob/148a15fb043dacdd1595eb4c5267705fbd362c6a/inception/inception/image_processing.py

distorted_image = image

# This resizing operation may distort the images because the aspect
# ratio is not respected. We select a resize method in a round robin
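
For context, the round-robin resize that the truncated comment above refers to
looks roughly like this in the upstream Inception pipeline linked in the NOTE.
This is a sketch of that step under the old
`tf.image.resize_images(images, height, width, method)` signature, not the
elided lines of this file:

```python
# Each preprocessing thread uses a different interpolation method, so the
# training data sees a mix of resize kernels.
resize_method = thread_id % 4
distorted_image = tf.image.resize_images(distorted_image, height, width,
                                         method=resize_method)
# The aspect ratio is not preserved; the output is exactly height x width.
distorted_image.set_shape([height, width, 3])
```
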
102 changes: 0 additions & 102 deletions resnet.py
@@ -16,118 +16,16 @@
CONV_WEIGHT_STDDEV = 0.1
FC_WEIGHT_DECAY = 0.00004
FC_WEIGHT_STDDEV = 0.01
MOMENTUM = 0.9
RESNET_VARIABLES = 'resnet_variables'
UPDATE_OPS_COLLECTION = 'resnet_update_ops' # must be grouped with training op
IMAGENET_MEAN_BGR = [103.062623801, 115.902882574, 123.151630838, ]

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/resnet_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_float('learning_rate', 0.1, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
tf.app.flags.DEFINE_integer('input_size', 224, "input image size")
tf.app.flags.DEFINE_boolean('resume', False,
                            'resume from latest saved state')


def train(images, labels, small=False):
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    if small:
        logits = inference_small(images,
                                 num_classes=10,
                                 is_training=True,
                                 num_blocks=3)
    else:
        logits = inference(images,
                           num_classes=1000,
                           is_training=True,
                           preprocess=True,
                           bottleneck=False,
                           num_blocks=[2, 2, 2, 2])

    loss_ = loss(logits, labels)

    # loss_avg
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    loss_avg = ema.average(loss_)
    tf.scalar_summary('loss_avg', loss_avg)

    tf.scalar_summary('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
    grads = opt.compute_gradients(loss_)
    for grad, var in grads:
        if grad is not None:
            tf.histogram_summary(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    for var in tf.trainable_variables():
        tf.histogram_summary(var.op.name, var)

    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.all_variables())

    summary_op = tf.merge_all_summaries()

    init = tf.initialize_all_variables()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    if FLAGS.resume:
        latest = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not latest:
            print "No checkpoint to continue from in", FLAGS.train_dir
            sys.exit(1)
        print "resume", latest
        saver.restore(sess, latest)

    while True:
        start_time = time.time()

        #images_, labels_ = dataset.get_batch(FLAGS.batch_size, FLAGS.input_size)

        step = sess.run(global_step)
        i = [train_op, loss_]

        write_summary = step % 100 and step > 1
        if write_summary:
            i.append(summary_op)

        o = sess.run(i)

        loss_value = o[1]

        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 5 == 0:
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (step, loss_value, examples_per_sec, duration))

        if write_summary:
            summary_str = o[2]
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step > 1 and step % 100 == 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)


def inference(x, is_training,
108 changes: 108 additions & 0 deletions resnet_train.py
@@ -0,0 +1,108 @@
from resnet import *
import tensorflow as tf

import numpy as np
import os
import sys
import time

MOMENTUM = 0.9

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/resnet_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_float('learning_rate', 0.1, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
# 'resume' is read in train() below; the definition is carried over from the
# flags that used to live in resnet.py.
tf.app.flags.DEFINE_boolean('resume', False,
                            'resume from latest saved state')


def train(images, labels, small=False):
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    is_training = tf.placeholder('bool', [], name="is_training")
    if small:
        # CIFAR
        logits = inference_small(images,
                                 num_classes=10,
                                 is_training=is_training,
                                 num_blocks=3)
    else:
        # ImageNet
        logits = inference(images,
                           num_classes=1000,
                           is_training=is_training,
                           bottleneck=False,
                           num_blocks=[2, 2, 2, 2])

    loss_ = loss(logits, labels)

    # loss_avg
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, ema.apply([loss_]))
    loss_avg = ema.average(loss_)
    tf.scalar_summary('loss_avg', loss_avg)

    tf.scalar_summary('learning_rate', FLAGS.learning_rate)

    opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
    grads = opt.compute_gradients(loss_)
    for grad, var in grads:
        if grad is not None:
            tf.histogram_summary(var.op.name + '/gradients', grad)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    for var in tf.trainable_variables():
        tf.histogram_summary(var.op.name, var)

    batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

    saver = tf.train.Saver(tf.all_variables())

    summary_op = tf.merge_all_summaries()

    init = tf.initialize_all_variables()

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    sess.run(init)
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    if FLAGS.resume:
        latest = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not latest:
            print "No checkpoint to continue from in", FLAGS.train_dir
            sys.exit(1)
        print "resume", latest
        saver.restore(sess, latest)

    while True:
        start_time = time.time()

        step = sess.run(global_step)
        i = [train_op, loss_]

        write_summary = step % 100 and step > 1
        if write_summary:
            i.append(summary_op)

        o = sess.run(i, { is_training: True })

        loss_value = o[1]

        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 5 == 0:
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (step, loss_value, examples_per_sec, duration))

        if write_summary:
            summary_str = o[2]
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step > 1 and step % 100 == 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)
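
As a usage note, `train()` expects image and label tensors that are already
part of the graph and then loops until interrupted. A minimal smoke-test
sketch on random CIFAR-sized data, assuming `loss()` takes integer class
labels; the random inputs are purely illustrative and a real run would use an
input pipeline such as `image_processing.inputs()`:

```python
import tensorflow as tf
from resnet_train import train, FLAGS

# Random stand-in batch, just to exercise the graph end to end.
images = tf.random_uniform([FLAGS.batch_size, 32, 32, 3])
labels = tf.random_uniform([FLAGS.batch_size], minval=0, maxval=10,
                           dtype=tf.int32)

# small=True selects the CIFAR-sized network; train() runs indefinitely.
train(images, labels, small=True)
```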
