diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index a160113cbbf..09895057a9a 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -30,8 +30,8 @@ def test_box_linear_coder(self): proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float() - rel_codes = box_coder.encode_single(boxes, proposals) - pred_boxes = box_coder.decode_single(rel_codes, boxes) + rel_codes = box_coder.encode(boxes, proposals) + pred_boxes = box_coder.decode(rel_codes, boxes) torch.allclose(proposals, pred_boxes) @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)]) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index b484bbaaf3e..10d31852856 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -237,42 +237,10 @@ def __init__(self, normalize_by_size: bool = True) -> None: """ self.normalize_by_size = normalize_by_size - def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: + def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: """ Encode a set of proposals with respect to some reference boxes - Args: - reference_boxes (Tensor): reference boxes - proposals (Tensor): boxes to be encoded - - Returns: - Tensor: the encoded relative box offsets that can be used to - decode the boxes. - """ - # get the center of reference_boxes - reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2]) - reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3]) - - # get box regression transformation deltas - target_l = reference_boxes_ctr_x - proposals[:, 0] - target_t = reference_boxes_ctr_y - proposals[:, 1] - target_r = proposals[:, 2] - reference_boxes_ctr_x - target_b = proposals[:, 3] - reference_boxes_ctr_y - - targets = torch.stack((target_l, target_t, target_r, target_b), dim=1) - if self.normalize_by_size: - reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0] - reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1] - reference_boxes_size = torch.stack( - (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1 - ) - targets = targets / reference_boxes_size - - return targets - - def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: - """ - vectorized version of `encode_single` Args: reference_boxes (Tensor): reference boxes proposals (Tensor): boxes to be encoded @@ -304,7 +272,8 @@ def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: targets = targets / reference_boxes_size return targets - def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: + def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: + """ From a set of original boxes and encoded relative box offsets, get the decoded boxes. @@ -313,35 +282,6 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: rel_codes (Tensor): encoded boxes boxes (Tensor): reference boxes. - Returns: - Tensor: the predicted boxes with the encoded relative box offsets. - """ - - boxes = boxes.to(rel_codes.dtype) - - ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2]) - ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3]) - if self.normalize_by_size: - boxes_w = boxes[:, 2] - boxes[:, 0] - boxes_h = boxes[:, 3] - boxes[:, 1] - boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1) - rel_codes = rel_codes * boxes_size - - pred_boxes1 = ctr_x - rel_codes[:, 0] - pred_boxes2 = ctr_y - rel_codes[:, 1] - pred_boxes3 = ctr_x + rel_codes[:, 2] - pred_boxes4 = ctr_y + rel_codes[:, 3] - pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1) - return pred_boxes - - def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: - """ - Vectorized version of `decode_single` method. - - Args: - rel_codes (Tensor): encoded boxes - boxes (List[Tensor]): List of reference boxes. - Returns: Tensor: the predicted boxes with the encoded relative box offsets. @@ -350,7 +290,7 @@ def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: """ - boxes = torch.stack(boxes).to(dtype=rel_codes.dtype) + boxes = boxes.to(dtype=rel_codes.dtype) ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2]) ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3]) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index f95ee5b763f..73c9a6e042d 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -74,7 +74,13 @@ def compute_loss( all_gt_classes_targets.append(gt_classes_targets) all_gt_boxes_targets.append(gt_boxes_targets) - all_gt_classes_targets = torch.stack(all_gt_classes_targets) + # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors` + all_gt_boxes_targets, all_gt_classes_targets, anchors = ( + torch.stack(all_gt_boxes_targets), + torch.stack(all_gt_classes_targets), + torch.stack(anchors), + ) + # compute foregroud foregroud_mask = all_gt_classes_targets >= 0 num_foreground = foregroud_mask.sum().item() @@ -84,14 +90,10 @@ def compute_loss( gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0 loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum") - # regression loss: GIoU loss - - pred_boxes = self.box_coder.decode_all(bbox_regression, anchors) - - # List[Tensor] to Tensor conversion of `all_gt_boxes_target` and `anchors` - all_gt_boxes_targets, anchors = torch.stack(all_gt_boxes_targets), torch.stack(anchors) - # amp issue: pred_boxes need to convert float + pred_boxes = self.box_coder.decode(bbox_regression, anchors) + + # regression loss: GIoU loss loss_bbox_reg = generalized_box_iou_loss( pred_boxes[foregroud_mask], all_gt_boxes_targets[foregroud_mask], @@ -100,7 +102,7 @@ def compute_loss( # ctrness loss - bbox_reg_targets = self.box_coder.encode_all(anchors, all_gt_boxes_targets) + bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets) if len(bbox_reg_targets) == 0: gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1]) @@ -522,7 +524,7 @@ def postprocess_detections( anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor") labels_per_level = topk_idxs % num_classes - boxes_per_level = self.box_coder.decode_single( + boxes_per_level = self.box_coder.decode( box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs] ) boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)