This repository has been archived by the owner on Oct 19, 2019. It is now read-only.

Commit

add inference_small
ry committed May 9, 2016
1 parent 2a004d3 commit 51c7d0a
Showing 3 changed files with 63 additions and 28 deletions.
21 changes: 7 additions & 14 deletions dataset.py
@@ -18,13 +18,13 @@ def epoch_complete(self, batch_size):
self.index = 0
self.epochs_completed += 1

def get_batch(self, batch_size):
def get_batch(self, batch_size, input_size):
imgs = []
labels = []
while len(imgs) < batch_size:
try:
fn = self.data[self.index]['filename']
img = load_image(fn)
img = load_image(fn, input_size)
imgs.append(img)
labels.append(self.data[self.index]['label_index'])
self.index += 1
@@ -35,7 +35,7 @@ def get_batch(self, batch_size):
del self.data[self.index]

batch_images = np.stack(imgs)
assert batch_images.shape == (batch_size, 224, 224, 3)
assert batch_images.shape == (batch_size, input_size, input_size, 3)

batch_labels = np.asarray(labels).reshape((batch_size, 1))
assert batch_labels.shape == (batch_size, 1)
@@ -83,30 +83,23 @@ def load_data(data_dir):
return data


# Returns a numpy array of shape [height, width, 3]
def load_image(path):
# load image
# Returns a numpy array of shape [size, size, 3]
def load_image(path, size):
img = skimage.io.imread(path)

#print "Original Image Shape: ", img.shape
# we crop image from center
short_edge = min(img.shape[:2])
yy = int((img.shape[0] - short_edge) / 2)
xx = int((img.shape[1] - short_edge) / 2)
crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
# resize to 224, 224

img = skimage.transform.resize(crop_img, (224, 224))

#print img
#img = img / 255.0
img = skimage.transform.resize(crop_img, (size, size))

# if it's a black and white photo, we need to change it to 3 channel
# or raise an error if we're not allowing b&w (which we do during training)
if len(img.shape) == 2:
img = np.stack([img, img, img], axis=-1)

assert img.shape == (224, 224, 3)
assert img.shape == (size, size, 3)
assert (0 <= img).all() and (img <= 1.0).all()

return img
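
For reference, a minimal usage sketch of the reparameterized loader (the image path is hypothetical; the load_image signature is the one introduced in this commit):

from dataset import load_image

# hypothetical path; `size` is the new second argument
img = load_image("/tmp/example.jpg", 224)
assert img.shape == (224, 224, 3)          # center-cropped, resized, values in [0, 1]

img = load_image("/tmp/example.jpg", 32)   # e.g. a CIFAR-sized input
assert img.shape == (32, 32, 3)

get_batch(batch_size, input_size) stacks such images into a (batch_size, input_size, input_size, 3) array and the labels into a (batch_size, 1) array.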
45 changes: 41 additions & 4 deletions resnet.py
@@ -14,7 +14,7 @@
BN_EPSILON = 0.001
RESNET_VARIABLES = 'resnet_variables'
UPDATE_OPS_COLLECTION = 'resnet_update_ops' # must be grouped with training op
MEAN_BGR = [
IMAGENET_MEAN_BGR = [
103.062623801,
115.902882574,
123.151630838,
@@ -29,7 +29,7 @@ def inference(x, is_training,
# if preprocess is True, input should be RGB [0,1], otherwise BGR with mean
# subtracted
if preprocess:
x = _preprocess(x)
x = _imagenet_preprocess(x)

is_training = tf.convert_to_tensor(is_training,
dtype='bool',
@@ -60,11 +60,48 @@

return logits

def _preprocess(rgb):
# This is what they use for CIFAR-10 and 100.
# See Section 4.2 in http://arxiv.org/abs/1512.03385
def inference_small(x,
is_training,
num_classes=10,
num_blocks=3, # 6n+2 total weight layers will be used.
preprocess=True):
# if preprocess is True, input should be RGB [0,1], otherwise BGR with mean
# subtracted
if preprocess:
x = _imagenet_preprocess(x)

bottleneck = False
is_training = tf.convert_to_tensor(is_training,
dtype='bool',
name='is_training')

with tf.variable_scope('scale1'):
x = _conv(x, 16, ksize=3, stride=1)
x = _bn(x, is_training)
x = _relu(x)

x = stack(x, num_blocks, 16, bottleneck, is_training, stride=1)

with tf.variable_scope('scale2'):
x = stack(x, num_blocks, 32, bottleneck, is_training, stride=2)

with tf.variable_scope('scale3'):
x = stack(x, num_blocks, 64, bottleneck, is_training, stride=2)

# post-net
x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool")
with tf.variable_scope('fc'):
logits = _fc(x, num_units_out=num_classes)

return logits

def _imagenet_preprocess(rgb):
"""Changes RGB [0,1] valued image to BGR [0,255] with mean subtracted."""
red, green, blue = tf.split(3, 3, rgb * 255.0)
bgr = tf.concat(3, [blue, green, red])
bgr -= MEAN_BGR
bgr -= IMAGENET_MEAN_BGR
return bgr

def loss(logits, labels, batch_size=None, label_smoothing=0.1):
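
A minimal sketch of calling the new inference_small on CIFAR-sized input (the placeholder shape is illustrative; the function signature is the one added above). With num_blocks=3 and no bottleneck, each of the three scales stacks 3 blocks of 2 convolutions, giving 3*3*2 + 1 (first conv) + 1 (fc) = 20 weight layers, the 6n+2 count from Section 4.2 of the paper:

import tensorflow as tf
import resnet

# 32x32 RGB images in [0, 1]; with preprocess=True they are converted to
# mean-subtracted BGR by _imagenet_preprocess before the first convolution.
images = tf.placeholder("float", [None, 32, 32, 3], name="images")
logits = resnet.inference_small(images, is_training=True,
                                num_classes=10, num_blocks=3)
# logits has shape [None, 10]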
25 changes: 15 additions & 10 deletions train.py
@@ -18,18 +18,21 @@
"""and checkpoint.""")
tf.app.flags.DEFINE_float('learning_rate', 0.1, "learning rate.")
tf.app.flags.DEFINE_integer('batch_size', 16, "batch size")
tf.app.flags.DEFINE_integer('input_size', 224, "input image size")
tf.app.flags.DEFINE_boolean('continue', False, 'resume from latest saved state')


def train(dataset):
# Create a variable to count the number of train() calls. This equals the
# number of batches processed * FLAGS.num_gpus.
global_step = tf.get_variable('global_step', [],
initializer=tf.constant_initializer(0),
trainable=False)

images = tf.placeholder("float", [None, 224, 224, 3], name="images")
labels = tf.placeholder("int32", [None, 1], name="labels")
images = tf.placeholder("float",
[None, FLAGS.input_size, FLAGS.input_size, 3],
name="images")
tf.image_summary('images', images)


logits = resnet.inference(images,
num_classes=1000,
@@ -40,13 +43,15 @@ def train(dataset):

loss = resnet.loss(logits, labels, batch_size=FLAGS.batch_size)
tf.scalar_summary('loss', loss)
tf.scalar_summary('learning_rate', FLAGS.learning_rate)

# loss_avg
ema = tf.train.ExponentialMovingAverage(resnet.MOVING_AVERAGE_DECAY, global_step)
tf.add_to_collection(resnet.UPDATE_OPS_COLLECTION, ema.apply([loss]))
loss_avg = ema.average(loss)
tf.scalar_summary('loss_avg', loss_avg)

tf.scalar_summary('learning_rate', FLAGS.learning_rate)

opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, MOMENTUM)
grads = opt.compute_gradients(loss)
for grad, var in grads:
@@ -61,7 +66,7 @@
batchnorm_updates_op = tf.group(*batchnorm_updates)
train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

saver = tf.train.Saver(tf.all_variables())

summary_op = tf.merge_all_summaries()

@@ -80,10 +85,10 @@
print "continue", latest
saver.restore(sess, latest)

while True:
start_time = time.time()

images_, labels_ = dataset.get_batch(FLAGS.batch_size)
images_, labels_ = dataset.get_batch(FLAGS.batch_size, FLAGS.input_size)

step = sess.run(global_step)
i = [train_op, loss]
@@ -102,17 +107,17 @@
duration = time.time() - start_time

assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

if step % 5 == 0:
examples_per_sec = FLAGS.batch_size / float(duration)
format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print(format_str % (step, loss_value, examples_per_sec, duration))

if write_summary:
summary_str = o[2]
summary_writer.add_summary(summary_str, step)

# Save the model checkpoint periodically.
if step > 1 and step % 100 == 0:
checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
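
With these changes the input resolution is set in one place: the images placeholder and every batch returned by get_batch follow FLAGS.input_size. A minimal sketch of that shape contract (the dataset object is hypothetical; flag names are those defined above):

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS    # e.g. launched as: python train.py --input_size=224

images = tf.placeholder("float",
                        [None, FLAGS.input_size, FLAGS.input_size, 3],
                        name="images")
# `dataset` is a hypothetical loader built from dataset.py
images_, labels_ = dataset.get_batch(FLAGS.batch_size, FLAGS.input_size)
assert images_.shape == (FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size, 3)
assert labels_.shape == (FLAGS.batch_size, 1)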
