Skip to content

Commit

Permalink
Fix crop_and_resize usage
Browse files Browse the repository at this point in the history
  • Loading branch information
vierja committed Jun 19, 2017
1 parent 9214617 commit 13de22d
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions frcnn/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import tensorflow as tf
import numpy as np

slim = tf.contrib.slim

CROP = 'crop'
ROI_POOLING = 'roi_pooling'
Expand All @@ -18,45 +19,46 @@ def __init__(self, pooling_mode=CROP, pooled_width=7, pooled_height=7,
self._spatial_scale = spatial_scale
self._feat_stride = feat_stride

def _get_bboxes(self, rois, pretrained):
def _get_bboxes(self, roi_proposals, pretrained):
"""
Get normalized coordinates for RoIs (betweetn 0 and 1 for easy cropping)
"""
pretrained_shape = tf.shape(pretrained)
height = (tf.to_float(pretrained_shape[1]) - 1.) * np.float32(self._feat_stride[0])
width = (tf.to_float(pretrained_shape[2]) - 1.) * np.float32(self._feat_stride[0])

x1 = tf.slice(rois, [0, 1], [-1, 1], name="x1") / width
y1 = tf.slice(rois, [0, 2], [-1, 1], name="y1") / height
x2 = tf.slice(rois, [0, 3], [-1, 1], name="x2") / width
y2 = tf.slice(rois, [0, 4], [-1, 1], name="y2") / height
x1 = tf.slice(roi_proposals, [0, 1], [-1, 1], name="x1") / width
y1 = tf.slice(roi_proposals, [0, 2], [-1, 1], name="y1") / height
x2 = tf.slice(roi_proposals, [0, 3], [-1, 1], name="x2") / width
y2 = tf.slice(roi_proposals, [0, 4], [-1, 1], name="y2") / height

# Won't be backpropagated to rois anyway, but to save time TODO: What time is saved?
# Won't be backpropagated to rois anyway, but to save time TODO: Remove?
bboxes = tf.stop_gradient(tf.concat([y1, x1, y2, x2], axis=1))

return bboxes

def _roi_crop(self, rois, pretrained):
def _roi_crop(self, roi_proposals, pretrained):

bboxes = self._get_bboxes(rois, pretrained)
bboxes = self._get_bboxes(roi_proposals, pretrained)
# TODO: Why?!!?
batch_ids = tf.squeeze(tf.slice(rois, [0, 0], [-1, 1], name="batch_id"), [1])

# batch_ids = tf.squeeze(tf.slice(roi_proposals, [0, 0], [-1, 1], name="batch_id"), [1])
bboxes_shape = tf.shape(bboxes)
batch_ids = tf.zeros((bboxes_shape[0], ), dtype=tf.int32)
crops = tf.image.crop_and_resize(
pretrained, bboxes, tf.to_int32(batch_ids),
pretrained, bboxes, batch_ids,
[self._pooled_width * 2, self._pooled_height * 2], name="crops"
)

return tf.nn.max_pool(crops, [1, 1, 2, 2], [2] * 4, padding='SAME')
return slim.max_pool2d(crops, [2, 2], stride=2)


def _roi_pooling(self, roi, pretrained):
def _roi_pooling(self, roi_proposals, pretrained):
raise NotImplemented()

def _build(self, roi, pretrained):
def _build(self, roi_proposals, pretrained):
if self._pooling_mode == CROP:
return self._roi_crop(roi, pretrained)
return self._roi_crop(roi_proposals, pretrained)
elif self._pooling_mode == ROI_POOLING:
return self._roi_pooling(roi, pretrained)
return self._roi_pooling(roi_proposals, pretrained)
else:
raise NotImplemented('Pooling mode {} does not exist.'.format(self._pooling_mode))

0 comments on commit 13de22d

Please sign in to comment.