/
roi_sampler.py
175 lines (155 loc) · 8.08 KB
/
roi_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of ROI sampler."""
# Import libraries
import tensorflow as tf
from official.vision.modeling.layers import box_sampler
from official.vision.ops import box_matcher
from official.vision.ops import iou_similarity
from official.vision.ops import target_gather
@tf.keras.utils.register_keras_serializable(package='Vision')
class ROISampler(tf.keras.layers.Layer):
  """Samples ROIs and assigns targets to the sampled ROIs."""

  def __init__(self,
               mix_gt_boxes: bool = True,
               num_sampled_rois: int = 512,
               foreground_fraction: float = 0.25,
               foreground_iou_threshold: float = 0.5,
               background_iou_high_threshold: float = 0.5,
               background_iou_low_threshold: float = 0,
               skip_subsampling: bool = False,
               **kwargs):
    """Initializes a ROI sampler.

    Args:
      mix_gt_boxes: A `bool` of whether to mix the groundtruth boxes with
        proposed ROIs.
      num_sampled_rois: An `int` of the number of sampled ROIs per image.
      foreground_fraction: A `float` in [0, 1], what percentage of proposed ROIs
        should be sampled from the foreground boxes.
      foreground_iou_threshold: A `float` that represents the IoU threshold for
        a box to be considered as positive (if >= `foreground_iou_threshold`).
      background_iou_high_threshold: A `float` that represents the IoU threshold
        for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      background_iou_low_threshold: A `float` that represents the IoU threshold
        for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      skip_subsampling: A `bool` that determines whether to skip the sampling
        procedure that balances the fg/bg classes. Used for upper frcnn layers
        in cascade RCNN.
      **kwargs: Additional keyword arguments passed to Layer.
    """
    # Only the sampler-specific arguments are recorded; they are what
    # `get_config` reports and what `from_config` feeds back to `__init__`.
    self._config_dict = {
        'mix_gt_boxes': mix_gt_boxes,
        'num_sampled_rois': num_sampled_rois,
        'foreground_fraction': foreground_fraction,
        'foreground_iou_threshold': foreground_iou_threshold,
        'background_iou_high_threshold': background_iou_high_threshold,
        'background_iou_low_threshold': background_iou_low_threshold,
        'skip_subsampling': skip_subsampling,
    }

    self._sim_calc = iou_similarity.IouSimilarity()
    # The indicator values are interpreted in `call`: -3 is invalid, -1 is
    # negative, -2 is ignored and 1 is positive.
    # NOTE(review): presumably each indicator labels the IoU bucket delimited
    # by the thresholds in order -- confirm against BoxMatcher.
    self._box_matcher = box_matcher.BoxMatcher(
        thresholds=[
            background_iou_low_threshold, background_iou_high_threshold,
            foreground_iou_threshold
        ],
        indicators=[-3, -1, -2, 1])
    self._target_gather = target_gather.TargetGather()

    self._sampler = box_sampler.BoxSampler(
        num_sampled_rois, foreground_fraction)
    super().__init__(**kwargs)

  def call(self, boxes: tf.Tensor, gt_boxes: tf.Tensor, gt_classes: tf.Tensor):
    """Assigns the proposals with groundtruth classes and performs subsampling.

    Given `boxes`, `gt_boxes`, and `gt_classes`, the function uses the
    following algorithm to generate the final `num_sampled_rois` RoIs.
    1. Calculates the IoU between each proposal box and each gt_boxes.
    2. Assigns each proposed box with a groundtruth class and box by choosing
       the largest IoU overlap.
    3. Samples `num_sampled_rois` boxes from all proposed boxes, and
       returns box_targets, class_targets, and RoIs.

    Args:
      boxes: A `tf.Tensor` of shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment. The last dimension is the
        box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
        format.
      gt_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES, 4].
        The coordinates of gt_boxes are in the pixel coordinates of the scaled
        image. This tensor might have padding of values -1 indicating the
        invalid box coordinates.
      gt_classes: A `tf.Tensor` with a shape of [batch_size, MAX_NUM_INSTANCES].
        This tensor might have paddings with values of -1 indicating the invalid
        classes.

    Returns:
      sampled_rois: A `tf.Tensor` of shape of [batch_size, K, 4], representing
        the coordinates of the sampled RoIs, where K is the number of the
        sampled RoIs, i.e. K = `num_sampled_rois`.
      sampled_gt_boxes: A `tf.Tensor` of shape of [batch_size, K, 4], storing
        the box coordinates of the matched groundtruth boxes of the sampled
        RoIs. Entries for background RoIs are zeros.
      sampled_gt_classes: A `tf.Tensor` of shape of [batch_size, K], storing the
        classes of the matched groundtruth boxes of the sampled RoIs. Entries
        for background RoIs are zeros.
      sampled_gt_indices: A `tf.Tensor` of shape of [batch_size, K], storing the
        indices of the sampled groundtruth boxes in the original `gt_boxes`
        tensor, i.e.,
        gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
        Entries for background RoIs are -1.

      When `skip_subsampling` is set, no sampling is performed: the same four
      tensors are returned for every candidate box (the proposals, plus the
      groundtruth boxes when `mix_gt_boxes` is set) instead of K samples.
    """
    gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
    if self._config_dict['mix_gt_boxes']:
      # Append the groundtruth boxes to the proposals along the box axis so
      # they are candidates for matching and sampling as well.
      boxes = tf.concat([boxes, gt_boxes], axis=1)

    # A box whose largest coordinate is negative is treated as padding.
    boxes_invalid_mask = tf.less(
        tf.reduce_max(boxes, axis=-1, keepdims=True), 0.0)
    gt_invalid_mask = tf.less(
        tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
    similarity_matrix = self._sim_calc(boxes, gt_boxes, boxes_invalid_mask,
                                       gt_invalid_mask)
    matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)

    # Decode the matcher's indicator codes into one boolean mask per category.
    positive_matches = tf.greater_equal(match_indicators, 0)
    negative_matches = tf.equal(match_indicators, -1)
    ignored_matches = tf.equal(match_indicators, -2)
    invalid_matches = tf.equal(match_indicators, -3)

    # Negative and invalid matches are both treated as background; ignored
    # matches are not, so they keep their matched groundtruth targets.
    background_mask = tf.expand_dims(
        tf.logical_or(negative_matches, invalid_matches), -1)
    gt_classes = tf.expand_dims(gt_classes, axis=-1)
    matched_gt_classes = self._target_gather(gt_classes, matched_gt_indices,
                                             background_mask)
    # Background boxes get class 0.
    matched_gt_classes = tf.where(background_mask,
                                  tf.zeros_like(matched_gt_classes),
                                  matched_gt_classes)
    matched_gt_boxes = self._target_gather(gt_boxes, matched_gt_indices,
                                           tf.tile(background_mask, [1, 1, 4]))
    # Background boxes get an all-zero target box.
    matched_gt_boxes = tf.where(background_mask,
                                tf.zeros_like(matched_gt_boxes),
                                matched_gt_boxes)
    # Background boxes get a matched groundtruth index of -1.
    matched_gt_indices = tf.where(
        tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
        matched_gt_indices)

    if self._config_dict['skip_subsampling']:
      # Return targets for every candidate box without fg/bg balancing.
      return (boxes, matched_gt_boxes, tf.squeeze(matched_gt_classes,
                                                  axis=-1), matched_gt_indices)

    sampled_indices = self._sampler(
        positive_matches, negative_matches, ignored_matches)

    # Gather the RoIs and their targets at the sampled positions. Classes and
    # indices are gathered with a trailing unit axis, which is squeezed away.
    sampled_rois = self._target_gather(boxes, sampled_indices)
    sampled_gt_boxes = self._target_gather(matched_gt_boxes, sampled_indices)
    sampled_gt_classes = tf.squeeze(self._target_gather(
        matched_gt_classes, sampled_indices), axis=-1)
    sampled_gt_indices = tf.squeeze(self._target_gather(
        tf.expand_dims(matched_gt_indices, -1), sampled_indices), axis=-1)
    return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
            sampled_gt_indices)

  def get_config(self):
    """Returns the sampler-specific constructor arguments of this layer.

    Only the arguments named explicitly in `__init__` are included; keyword
    arguments forwarded to the base layer through `**kwargs` are not recorded.

    Returns:
      A new `dict` mapping constructor argument names to their values.
    """
    # Return a shallow copy so that callers mutating the result cannot alter
    # the configuration stored on this layer.
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Creates a `ROISampler` from a config produced by `get_config`.

    Args:
      config: A `dict` whose items are passed to `__init__` as keyword
        arguments.

    Returns:
      A new `ROISampler` instance.
    """
    return cls(**config)