Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions caffe2/operators/box_with_nms_limit_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,27 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {

// tscores: (num_boxes, num_classes), 0 for background
if (tscores.dim() == 4) {
CAFFE_ENFORCE_EQ(tscores.size(2), 1, tscores.size(2));
CAFFE_ENFORCE_EQ(tscores.size(3), 1, tscores.size(3));
CAFFE_ENFORCE_EQ(tscores.size(2), 1);
CAFFE_ENFORCE_EQ(tscores.size(3), 1);
} else {
CAFFE_ENFORCE_EQ(tscores.dim(), 2, tscores.dim());
CAFFE_ENFORCE_EQ(tscores.dim(), 2);
}
CAFFE_ENFORCE(tscores.template IsType<float>(), tscores.dtype().name());
// tboxes: (num_boxes, num_classes * box_dim)
if (tboxes.dim() == 4) {
CAFFE_ENFORCE_EQ(tboxes.size(2), 1, tboxes.size(2));
CAFFE_ENFORCE_EQ(tboxes.size(3), 1, tboxes.size(3));
CAFFE_ENFORCE_EQ(tboxes.size(2), 1);
CAFFE_ENFORCE_EQ(tboxes.size(3), 1);
} else {
CAFFE_ENFORCE_EQ(tboxes.dim(), 2, tboxes.dim());
CAFFE_ENFORCE_EQ(tboxes.dim(), 2);
}
CAFFE_ENFORCE(tboxes.template IsType<float>(), tboxes.dtype().name());

int N = tscores.size(0);
int num_classes = tscores.size(1);

CAFFE_ENFORCE_EQ(N, tboxes.size(0));
CAFFE_ENFORCE_EQ(num_classes * box_dim, tboxes.size(1));
int num_boxes_classes = get_box_cls_index(num_classes - 1) + 1;
CAFFE_ENFORCE_EQ(num_boxes_classes * box_dim, tboxes.size(1));

int batch_size = 1;
vector<float> batch_splits_default(1, tscores.size(0));
Expand Down Expand Up @@ -82,12 +83,13 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
// skip j = 0, because it's the background class
int total_keep_count = 0;
for (int j = 1; j < num_classes; j++) {
auto cur_scores = scores.col(j);
auto cur_scores = scores.col(get_score_cls_index(j));
auto inds = utils::GetArrayIndices(cur_scores > score_thres_);
auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
auto cur_boxes =
boxes.block(0, get_box_cls_index(j) * box_dim, boxes.rows(), box_dim);

if (soft_nms_enabled_) {
auto cur_soft_nms_scores = soft_nms_scores.col(j);
auto cur_soft_nms_scores = soft_nms_scores.col(get_score_cls_index(j));
keeps[j] = utils::soft_nms_cpu(
&cur_soft_nms_scores,
cur_boxes,
Expand Down Expand Up @@ -173,8 +175,9 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {

int cur_out_idx = 0;
for (int j = 1; j < num_classes; j++) {
auto cur_scores = scores.col(j);
auto cur_boxes = boxes.block(0, j * box_dim, boxes.rows(), box_dim);
auto cur_scores = scores.col(get_score_cls_index(j));
auto cur_boxes =
boxes.block(0, get_box_cls_index(j) * box_dim, boxes.rows(), box_dim);
auto& cur_keep = keeps[j];
Eigen::Map<EArrXf> cur_out_scores(
out_scores->template mutable_data<float>() + cur_start_idx +
Expand All @@ -195,7 +198,8 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
utils::GetSubArrayRows(
cur_boxes, utils::AsEArrXt(cur_keep), &cur_out_boxes);
for (int k = 0; k < cur_keep.size(); k++) {
cur_out_classes[k] = static_cast<float>(j);
cur_out_classes[k] =
static_cast<float>(j - !output_classes_include_bg_cls_);
}

cur_out_idx += cur_keep.size();
Expand Down Expand Up @@ -309,7 +313,10 @@ C10_REGISTER_CAFFE2_OPERATOR_CPU(
"str soft_nms_method, "
"float soft_nms_sigma, "
"float soft_nms_min_score_thres, "
"bool rotated"
"bool rotated, "
"bool cls_agnostic_bbox_reg, "
"bool input_boxes_include_bg_cls, "
"bool output_classes_include_bg_cls "
") -> ("
"Tensor scores, "
"Tensor boxes, "
Expand Down
46 changes: 45 additions & 1 deletion caffe2/operators/box_with_nms_limit_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,26 @@ class BoxWithNMSLimitOp final : public Operator<Context> {
soft_nms_min_score_thres_(this->template GetSingleArgument<float>(
"soft_nms_min_score_thres",
0.001)),
rotated_(this->template GetSingleArgument<bool>("rotated", false)) {
rotated_(this->template GetSingleArgument<bool>("rotated", false)),
cls_agnostic_bbox_reg_(this->template GetSingleArgument<bool>(
"cls_agnostic_bbox_reg",
false)),
input_boxes_include_bg_cls_(this->template GetSingleArgument<bool>(
"input_boxes_include_bg_cls",
true)),
output_classes_include_bg_cls_(this->template GetSingleArgument<bool>(
"output_classes_include_bg_cls",
true)) {
CAFFE_ENFORCE(
soft_nms_method_str_ == "linear" || soft_nms_method_str_ == "gaussian",
"Unexpected soft_nms_method");
soft_nms_method_ = (soft_nms_method_str_ == "linear") ? 1 : 2;

// When input `boxes` doesn't include background class, the score will skip
// background class and start with foreground classes directly, and put the
// background class in the end, i.e. score[:, 0:NUM_CLASSES-1] represents
// foreground classes and score[:,NUM_CLASSES] represents background class.
input_scores_fg_cls_starting_id_ = (int)input_boxes_include_bg_cls_;
}

~BoxWithNMSLimitOp() {}
Expand All @@ -65,6 +80,35 @@ class BoxWithNMSLimitOp final : public Operator<Context> {
// Set for RRPN case to handle rotated boxes. Inputs should be in format
// [ctr_x, ctr_y, width, height, angle (in degrees)].
bool rotated_{false};
// MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
bool cls_agnostic_bbox_reg_{false};
// Whether input `boxes` includes background class. If true, boxes will have
// shape of (N, (num_fg_class+1) * 4or5), otherwise (N, num_fg_class * 4or5)
bool input_boxes_include_bg_cls_{true};
// Whether output `classes` includes background class. If true, index 0 will
// represent background, and valid outputs start from 1.
bool output_classes_include_bg_cls_{true};
// The index where foreground starts in scores. E.g. if 0 represents
// background class then foreground class starts with 1.
int input_scores_fg_cls_starting_id_{1};

// Translate a class id given in (background, foreground...) order, i.e. one
// of (0, 1, ..., NUM_FG_CLASSES), into its matching class slot in the input
// `boxes`.
inline int get_box_cls_index(int bg_fg_cls_id) {
  // Class-agnostic box regression: every class id maps to slot 0.
  if (cls_agnostic_bbox_reg_) {
    return 0;
  }
  // When `boxes` carries no background slot, ids shift down by one.
  return input_boxes_include_bg_cls_ ? bg_fg_cls_id : bg_fg_cls_id - 1;
}

// Translate a class id given in (background, foreground...) order, i.e. one
// of (0, 1, ..., NUM_FG_CLASSES), into its matching column in the input
// scores.
inline int get_score_cls_index(int bg_fg_cls_id) {
  // Offset of this id relative to the first foreground class (id 1).
  const int fg_offset = bg_fg_cls_id - 1;
  // Shift by the column at which foreground classes start in the scores.
  return fg_offset + input_scores_fg_cls_starting_id_;
}
};

} // namespace caffe2
Expand Down
36 changes: 33 additions & 3 deletions caffe2/python/operator_test/box_with_nms_limit_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,55 @@ def ref(*args, **kwargs):

self.assertReferenceChecks(gc, op, [scores, boxes], ref)

@given(num_classes=st.integers(2, 10), **HU_CONFIG)
def test_multiclass(self, num_classes, gc):
@given(
num_classes=st.integers(2, 10),
cls_agnostic_bbox_reg=st.booleans(),
input_boxes_include_bg_cls=st.booleans(),
output_classes_include_bg_cls=st.booleans(),
**HU_CONFIG
)
def test_multiclass(
self,
num_classes,
cls_agnostic_bbox_reg,
input_boxes_include_bg_cls,
output_classes_include_bg_cls,
gc
):
in_centers = [(0, 0), (20, 20), (50, 50)]
in_scores = [0.7, 0.85, 0.6]
boxes, scores = gen_multiple_boxes(in_centers, in_scores, 10, num_classes)

if not input_boxes_include_bg_cls:
# remove background class
boxes = boxes[:, 4:]
if cls_agnostic_bbox_reg:
# only leave one class
boxes = boxes[:, :4]

gt_centers = [(20, 20), (0, 0), (50, 50)]
gt_scores = [0.85, 0.7, 0.6]
gt_boxes, gt_scores = gen_multiple_boxes(gt_centers, gt_scores, 1, 1)
# [1, 1, 1, 2, 2, 2, 3, 3, 3, ...]
gt_classes = np.tile(
np.array(range(1, num_classes), dtype=np.float32),
(gt_boxes.shape[0], 1)).T.flatten()
if not output_classes_include_bg_cls:
# remove background class
gt_classes -= 1
gt_boxes = np.tile(gt_boxes, (num_classes - 1, 1))
gt_scores = np.tile(gt_scores, (num_classes - 1, 1)).flatten()

op = get_op(
2, 3,
{"score_thresh": 0.5, "nms": 0.9, "detections_per_im": 100}
{
"score_thresh": 0.5,
"nms": 0.9,
"detections_per_im": 100,
"cls_agnostic_bbox_reg": cls_agnostic_bbox_reg,
"input_boxes_include_bg_cls": input_boxes_include_bg_cls,
"output_classes_include_bg_cls": output_classes_include_bg_cls
}
)

def ref(*args, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions caffe2/python/operator_test/torch_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ def box_with_nms_limit_ref():
soft_nms_sigma=0.5,
soft_nms_min_score_thres=0.001,
rotated=rotated,
cls_agnostic_bbox_reg=False,
input_boxes_include_bg_cls=True,
output_classes_include_bg_cls=True,
)

for o, o_ref in zip(outputs, output_refs):
Expand Down
3 changes: 3 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,9 @@ def forward(self, class_prob, pred_bbox, batch_splits):
soft_nms_sigma=0.5,
soft_nms_min_score_thres=0.001,
rotated=rotated,
cls_agnostic_bbox_reg=False,
input_boxes_include_bg_cls=True,
output_classes_include_bg_cls=True,
)
return a, b, c, d

Expand Down