test/test_models_detection_utils.py (4 changes: 2 additions & 2 deletions)
@@ -30,8 +30,8 @@ def test_box_linear_coder(self):
 
         proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()
 
-        rel_codes = box_coder.encode_single(boxes, proposals)
-        pred_boxes = box_coder.decode_single(rel_codes, boxes)
+        rel_codes = box_coder.encode(boxes, proposals)
+        pred_boxes = box_coder.decode(rel_codes, boxes)
         torch.allclose(proposals, pred_boxes)
 
     @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])
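Note: the renamed encode/decode pair keeps the round-trip contract this test exercises. A minimal standalone sketch of that contract, assuming the private torchvision.models.detection._utils import path used by the test suite (the boxes fixture here is illustrative, not the test's exact values):

    import torch
    from torchvision.models.detection._utils import BoxLinearCoder

    box_coder = BoxLinearCoder(normalize_by_size=True)
    boxes = torch.arange(40).reshape(10, 4).float()  # illustrative reference boxes (x1, y1, x2, y2)
    proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()

    rel_codes = box_coder.encode(boxes, proposals)   # ltrb offsets from the reference box centers
    pred_boxes = box_coder.decode(rel_codes, boxes)  # inverts the encoding
    assert torch.allclose(proposals, pred_boxes, atol=1e-4)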
torchvision/models/detection/_utils.py (68 changes: 4 additions & 64 deletions)
@@ -237,42 +237,10 @@ def __init__(self, normalize_by_size: bool = True) -> None:
         """
         self.normalize_by_size = normalize_by_size
 
-    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
         """
         Encode a set of proposals with respect to some reference boxes
 
-        Args:
-            reference_boxes (Tensor): reference boxes
-            proposals (Tensor): boxes to be encoded
-
-        Returns:
-            Tensor: the encoded relative box offsets that can be used to
-            decode the boxes.
-        """
-        # get the center of reference_boxes
-        reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2])
-        reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3])
-
-        # get box regression transformation deltas
-        target_l = reference_boxes_ctr_x - proposals[:, 0]
-        target_t = reference_boxes_ctr_y - proposals[:, 1]
-        target_r = proposals[:, 2] - reference_boxes_ctr_x
-        target_b = proposals[:, 3] - reference_boxes_ctr_y
-
-        targets = torch.stack((target_l, target_t, target_r, target_b), dim=1)
-        if self.normalize_by_size:
-            reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0]
-            reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1]
-            reference_boxes_size = torch.stack(
-                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1
-            )
-            targets = targets / reference_boxes_size
-
-        return targets
-
-    def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
-        """
-        vectorized version of `encode_single`
         Args:
             reference_boxes (Tensor): reference boxes
             proposals (Tensor): boxes to be encoded
@@ -304,7 +272,8 @@ def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
             targets = targets / reference_boxes_size
         return targets
 
-    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
         """
         From a set of original boxes and encoded relative box offsets,
         get the decoded boxes.
@@ -313,35 +282,6 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
             rel_codes (Tensor): encoded boxes
             boxes (Tensor): reference boxes.
 
-        Returns:
-            Tensor: the predicted boxes with the encoded relative box offsets.
-        """
-
-        boxes = boxes.to(rel_codes.dtype)
-
-        ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
-        ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
-        if self.normalize_by_size:
-            boxes_w = boxes[:, 2] - boxes[:, 0]
-            boxes_h = boxes[:, 3] - boxes[:, 1]
-            boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1)
-            rel_codes = rel_codes * boxes_size
-
-        pred_boxes1 = ctr_x - rel_codes[:, 0]
-        pred_boxes2 = ctr_y - rel_codes[:, 1]
-        pred_boxes3 = ctr_x + rel_codes[:, 2]
-        pred_boxes4 = ctr_y + rel_codes[:, 3]
-        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
-        return pred_boxes
-
-    def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
-        """
-        Vectorized version of `decode_single` method.
-
-        Args:
-            rel_codes (Tensor): encoded boxes
-            boxes (List[Tensor]): List of reference boxes.
-
         Returns:
             Tensor: the predicted boxes with the encoded relative box offsets.
 
@@ -350,7 +290,7 @@ def decode_all(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
 
         """
 
-        boxes = torch.stack(boxes).to(dtype=rel_codes.dtype)
+        boxes = boxes.to(dtype=rel_codes.dtype)
 
         ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
         ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
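The substantive change in this file, beyond the renames, is that the surviving encode/decode implementations index with [..., i] rather than [:, i], so a single call now accepts an arbitrary leading batch dimension instead of a per-image loop over a List[Tensor]. A shape-level sketch of the decode math above (a standalone reimplementation for illustration, not the library API; normalize_by_size is omitted for brevity):

    import torch

    def decode_ltrb(rel_codes: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
        # Centers of the reference boxes; broadcasting handles any batch shape.
        boxes = boxes.to(rel_codes.dtype)
        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
        # Offsets are (left, top, right, bottom) distances from the center.
        return torch.stack(
            (ctr_x - rel_codes[..., 0], ctr_y - rel_codes[..., 1],
             ctr_x + rel_codes[..., 2], ctr_y + rel_codes[..., 3]),
            dim=-1,
        )

    anchors = torch.rand(2, 100, 4)  # [batch, num_anchors, 4]
    codes = torch.rand(2, 100, 4)
    print(decode_ltrb(codes, anchors).shape)  # torch.Size([2, 100, 4])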
torchvision/models/detection/fcos.py (22 changes: 12 additions & 10 deletions)
@@ -74,7 +74,13 @@ def compute_loss(
             all_gt_classes_targets.append(gt_classes_targets)
             all_gt_boxes_targets.append(gt_boxes_targets)
 
-        all_gt_classes_targets = torch.stack(all_gt_classes_targets)
+        # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
+        all_gt_boxes_targets, all_gt_classes_targets, anchors = (
+            torch.stack(all_gt_boxes_targets),
+            torch.stack(all_gt_classes_targets),
+            torch.stack(anchors),
+        )
+
         # compute foregroud
         foregroud_mask = all_gt_classes_targets >= 0
         num_foreground = foregroud_mask.sum().item()
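(Aside: the stacking above works only because every per-image tensor has one row per anchor, so the shapes agree; torch.stack then yields [batch, num_anchors, ...] tensors that the vectorized coder and the boolean foreground mask can consume directly. An illustrative sketch with made-up shapes:)

    import torch

    # Three images, five anchors each; shapes agree, so stacking is safe.
    per_image_boxes = [torch.rand(5, 4) for _ in range(3)]
    per_image_classes = [torch.randint(-1, 80, (5,)) for _ in range(3)]

    all_boxes = torch.stack(per_image_boxes)      # [3, 5, 4]
    all_classes = torch.stack(per_image_classes)  # [3, 5]

    foreground = all_classes >= 0                 # mask over (image, anchor) pairs
    print(all_boxes[foreground].shape)            # [num_foreground, 4]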
@@ -84,14 +90,10 @@ def compute_loss(
         gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
         loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
 
-        # regression loss: GIoU loss
-
-        pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)
-
-        # List[Tensor] to Tensor conversion of `all_gt_boxes_target` and `anchors`
-        all_gt_boxes_targets, anchors = torch.stack(all_gt_boxes_targets), torch.stack(anchors)
-
         # amp issue: pred_boxes need to convert float
+        pred_boxes = self.box_coder.decode(bbox_regression, anchors)
+
+        # regression loss: GIoU loss
         loss_bbox_reg = generalized_box_iou_loss(
             pred_boxes[foregroud_mask],
             all_gt_boxes_targets[foregroud_mask],
@@ -100,7 +102,7 @@ def compute_loss(
 
         # ctrness loss
 
-        bbox_reg_targets = self.box_coder.encode_all(anchors, all_gt_boxes_targets)
+        bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)
 
         if len(bbox_reg_targets) == 0:
             gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
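For context on what bbox_reg_targets feeds: FCOS derives its centerness target from the encoded (l, t, r, b) offsets as sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))). A hedged sketch of that computation on the stacked targets (illustrative names, not the exact library code):

    import torch

    bbox_reg_targets = torch.rand(2, 100, 4)  # [batch, anchors, (l, t, r, b)]
    left_right = bbox_reg_targets[..., [0, 2]]
    top_bottom = bbox_reg_targets[..., [1, 3]]
    gt_ctrness_targets = torch.sqrt(
        (left_right.min(dim=-1).values / left_right.max(dim=-1).values)
        * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values)
    )
    print(gt_ctrness_targets.shape)  # torch.Size([2, 100])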
@@ -522,7 +524,7 @@ def postprocess_detections(
             anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
             labels_per_level = topk_idxs % num_classes
 
-            boxes_per_level = self.box_coder.decode_single(
+            boxes_per_level = self.box_coder.decode(
                 box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
             )
             boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
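The index arithmetic in this last hunk assumes topk_idxs was taken over scores flattened to length num_anchors * num_classes: integer division by num_classes recovers the anchor index, and the remainder recovers the class label. A small sketch of that decomposition:

    import torch

    num_anchors, num_classes = 6, 4
    scores = torch.rand(num_anchors, num_classes).flatten()  # [num_anchors * num_classes]

    topk_vals, topk_idxs = scores.topk(5)
    anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
    labels = topk_idxs % num_classes
    print(anchor_idxs.tolist(), labels.tolist())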