Permalink
Browse files

add preprocessing

  • Loading branch information...
1 parent c1e027f commit 58c046dc7f9ee3c6f6966c57e2e178801170e1a6 @ry committed May 6, 2016
Showing with 29 additions and 30 deletions.
  1. +1 −1 .gitignore
  2. +7 −6 convert.py
  3. +2 −3 forward.py
  4. +19 −20 resnet.py
View
@@ -5,4 +5,4 @@
*.tfmodel
checkpoint
ResNet-L*.ckpt
-ResNet-L*.ckpt.meta
+ResNet-L*.meta
View
@@ -16,10 +16,6 @@ class CaffeParamProvider():
def __init__(self, caffe_net):
self.caffe_net = caffe_net
- def mean_bgr(self):
- mean_bgr = load_mean_bgr()
- return mean_bgr.reshape((1, 224, 224, 3))
-
def conv_kernel(self, name):
k = self.caffe_net.params[name][0].data
# caffe [out_channels, in_channels, filter_height, filter_width]
@@ -55,6 +51,9 @@ def fc_biases(self, name):
def preprocess(img):
"""Changes RGB [0,1] valued image to BGR [0,255] with mean subtracted."""
mean_bgr = load_mean_bgr()
+ print 'mean blue', np.mean(mean_bgr[:,:,0])
+ print 'mean green', np.mean(mean_bgr[:,:,1])
+ print 'mean red', np.mean(mean_bgr[:,:,2])
out = np.copy(img) * 255.0
out = out[:, :, [2,1,0]] # swap channel from RGB to BGR
out -= mean_bgr
@@ -217,7 +216,7 @@ def checkpoint_fn(layers):
return 'ResNet-L%d.ckpt' % layers
def meta_fn(layers):
- return checkpoint_fn(layers) + '.meta'
+ return 'ResNet-L%d.meta' % layers
def convert(graph, img, img_p, layers):
caffe_model = load_caffe(img_p, layers)
@@ -239,6 +238,7 @@ def convert(graph, img, img_p, layers):
logits = resnet.inference(images,
is_training=False,
num_blocks=num_blocks,
+ preprocess=True,
bottleneck=True)
prob = tf.nn.softmax(logits, name='prob')
@@ -279,7 +279,7 @@ def convert(graph, img, img_p, layers):
]
o = sess.run(i, {
- images: img_p[np.newaxis,:]
+ images: img[np.newaxis,:]
})
assert_almost_equal(caffe_model.blobs['conv1'].data, o[0])
@@ -315,6 +315,7 @@ def save_graph(save_path):
def main(_):
img = load_image("data/cat.jpg")
+ print img
img_p = preprocess(img)
for layers in [50, 101, 152]:
View
@@ -1,10 +1,9 @@
-from convert import print_prob, load_image, preprocess, checkpoint_fn, meta_fn
+from convert import print_prob, load_image, checkpoint_fn, meta_fn
import tensorflow as tf
layers = 50
img = load_image("data/cat.jpg")
-img_p = preprocess(img)
sess = tf.Session()
@@ -21,7 +20,7 @@
#sess.run(init)
print "graph restored"
-batch = img_p.reshape((1, 224, 224, 3))
+batch = img.reshape((1, 224, 224, 3))
feed_dict = { images: batch }
View
@@ -13,11 +13,23 @@
BN_EPSILON = 0.001
RESNET_VARIABLES = 'resnet_variables'
UPDATE_OPS_COLLECTION = 'resnet_update_ops' # must be grouped with training op
+MEAN_BGR = [
+ 103.062623801,
+ 115.902882574,
+ 123.151630838,
+]
+
def inference(x, is_training,
num_classes=1000,
num_blocks=[2, 2, 2, 2], # defaults to 18-layer network
+ preprocess=True,
bottleneck=True):
+ # If preprocess is True, the input should be RGB in [0,1]; otherwise it must
+ # already be BGR in [0,255] with the mean subtracted.
+ if preprocess:
+ x = _preprocess(x)
+
is_training = tf.convert_to_tensor(is_training,
dtype='bool',
name='is_training')
@@ -47,6 +59,13 @@ def inference(x, is_training,
return logits
+def _preprocess(rgb):
+ """Changes RGB [0,1] valued image to BGR [0,255] with mean subtracted."""
+ red, green, blue = tf.split(3, 3, rgb * 255.0)
+ bgr = tf.concat(3, [blue, green, red])
+ bgr -= MEAN_BGR
+ return bgr
+
def loss(logits, labels, batch_size=None, label_smoothing=0.1):
if not batch_size:
batch_size = FLAGS.batch_size
@@ -204,24 +223,4 @@ def _max_pool(x, ksize=3, stride=2):
return tf.nn.max_pool(x, ksize=[1, ksize, ksize, 1],
strides=[ 1, stride, stride, 1], padding='SAME')
-def preprocess(self, rgb):
- rgb_scaled = rgb * 255.0
-
- red, green, blue = tf.split(3, 3, rgb_scaled)
-
- mean_bgr = self.param_provider.mean_bgr()
-
- # resize mean_bgr to match input
- input_width = rgb.get_shape().as_list()[2]
- mean_bgr = tf.image.resize_bilinear(mean_bgr, [input_width, input_width])
-
- mean_blue, mean_green, mean_red = tf.split(3, 3, mean_bgr)
-
- bgr = tf.concat(3, [
- blue - mean_blue,
- green - mean_green,
- red - mean_red,
- ], name="centered_bgr")
-
- return bgr

0 comments on commit 58c046d

Please sign in to comment.