Skip to content

Commit

Permalink
implement inference code in box_predictor
Browse files Browse the repository at this point in the history
Summary: keep some methods in FastRCNNOutputs since they are still used

Reviewed By: rbgirshick

Differential Revision: D21073561

fbshipit-source-id: dd55f485eb576f775453227775cde27332d1c189
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Apr 22, 2020
1 parent 273e3b9 commit 5d967ed
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 62 deletions.
26 changes: 26 additions & 0 deletions detectron2/modeling/box_regression.py
Expand Up @@ -12,6 +12,32 @@
__all__ = ["Box2BoxTransform", "Box2BoxTransformRotated"]


def apply_deltas_broadcast(box2box_transform, deltas, boxes):
"""
Apply transform deltas to boxes. Similar to `box2box_transform.apply_deltas`,
but allow broadcasting boxes when the second dimension of deltas is a multiple
of box dimension.
Args:
box2box_transform (Box2BoxTransform or Box2BoxTransformRotated): the transform to apply
deltas (Tensor): NxB or Nx(KxB)
boxes (Tensor): NxB
Returns:
Tensor: same shape as deltas.
"""
assert deltas.dim() == boxes.dim() == 2, f"{deltas.shape}, {boxes.shape}"
N, B = boxes.shape
assert (
deltas.shape[1] % B == 0
), f"Second dim of deltas should be a multiple of {B}. Got {deltas.shape}"
K = deltas.shape[1] // B
ret = box2box_transform.apply_deltas(
deltas.view(N * K, B), boxes.unsqueeze(1).expand(N, K, B).reshape(N * K, B)
)
return ret.view(N, K * B)


@torch.jit.script
class Box2BoxTransform(object):
"""
Expand Down
144 changes: 82 additions & 62 deletions detectron2/modeling/roi_heads/fast_rcnn.py
Expand Up @@ -7,10 +7,13 @@

from detectron2.config import configurable
from detectron2.layers import Linear, ShapeSpec, batched_nms, cat
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.box_regression import Box2BoxTransform, apply_deltas_broadcast
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage

__all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"]


logger = logging.getLogger(__name__)

"""
Expand Down Expand Up @@ -48,10 +51,10 @@ def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, t
boxes for each image. Element i has shape (Ri, K * 4) if doing
class-specific regression, or (Ri, 4) if doing class-agnostic
regression, where Ri is the number of predicted objects for image i.
This is compatible with the output of :meth:`FastRCNNOutputs.predict_boxes`.
This is compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`.
scores (list[Tensor]): A list of Tensors of predicted class scores for each image.
Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
for image i. Compatible with the output of :meth:`FastRCNNOutputs.predict_probs`.
for image i. Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`.
image_shapes (list[tuple]): A list of (width, height) tuples for each image in the batch.
score_thresh (float): Only return detections with a confidence score exceeding this
threshold.
Expand Down Expand Up @@ -216,6 +219,7 @@ def softmax_cross_entropy_loss(self):
scalar Tensor
"""
if self._no_instances:
# TODO 0.0 * pred.sum() is enough since PT1.6
return 0.0 * F.cross_entropy(
self.pred_class_logits,
torch.zeros(0, dtype=torch.long, device=self.pred_class_logits.device),
Expand All @@ -233,6 +237,7 @@ def smooth_l1_loss(self):
scalar Tensor
"""
if self._no_instances:
# TODO 0.0 * pred.sum() is enough since PT1.6
return 0.0 * smooth_l1_loss(
self.pred_proposal_deltas,
torch.zeros_like(self.pred_proposal_deltas),
Expand Down Expand Up @@ -295,14 +300,9 @@ def _predict_boxes(self):
for all images in a batch. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
the number of predicted objects for image i and B is the box dimension (4 or 5)
"""
num_pred = len(self.proposals)
B = self.proposals.tensor.shape[1]
K = self.pred_proposal_deltas.shape[1] // B
boxes = self.box2box_transform.apply_deltas(
self.pred_proposal_deltas.view(num_pred * K, B),
self.proposals.tensor.unsqueeze(1).expand(num_pred, K, B).reshape(-1, B),
return apply_deltas_broadcast(
self.box2box_transform, self.pred_proposal_deltas, self.proposals.tensor
)
return boxes.view(num_pred, K * B)

"""
A subclass is expected to have the following methods because
Expand All @@ -324,58 +324,24 @@ def losses(self):

def predict_boxes(self):
"""
Returns:
list[Tensor]: A list of Tensors of predicted class-specific or class-agnostic boxes
for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
the number of predicted objects for image i and B is the box dimension (4 or 5)
Deprecated
"""
return self._predict_boxes().split(self.num_preds_per_image, dim=0)

def predict_boxes_for_gt_classes(self):
"""
Returns:
list[Tensor]: A list of Tensors of predicted boxes for GT classes in case of
class-specific box head. Element i of the list has shape (Ri, B), where Ri is
the number of predicted objects for image i and B is the box dimension (4 or 5)
"""
predicted_boxes = self._predict_boxes()
B = self.proposals.tensor.shape[1]
# If the box head is class-agnostic, then the method is equivalent to `predicted_boxes`.
if predicted_boxes.shape[1] > B:
num_pred = len(self.proposals)
num_classes = predicted_boxes.shape[1] // B
# Some proposals are ignored or have a background class. Their gt_classes
# cannot be used as index.
gt_classes = torch.clamp(self.gt_classes, 0, num_classes - 1)
predicted_boxes = predicted_boxes.view(num_pred, num_classes, B)[
torch.arange(num_pred, dtype=torch.long, device=predicted_boxes.device), gt_classes
]
return predicted_boxes.split(self.num_preds_per_image, dim=0)

def predict_probs(self):
"""
Returns:
list[Tensor]: A list of Tensors of predicted class probabilities for each image.
Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
for image i.
Deprecated
"""
probs = F.softmax(self.pred_class_logits, dim=-1)
return probs.split(self.num_preds_per_image, dim=0)

def inference(self, score_thresh, nms_thresh, topk_per_image):
"""
Args:
score_thresh (float): same as fast_rcnn_inference.
nms_thresh (float): same as fast_rcnn_inference.
topk_per_image (int): same as fast_rcnn_inference.
Returns:
list[Instances]: same as fast_rcnn_inference.
list[Tensor]: same as fast_rcnn_inference.
Deprecated
"""
boxes = self.predict_boxes()
scores = self.predict_probs()
image_shapes = self.image_shapes

return fast_rcnn_inference(
boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image
)
Expand Down Expand Up @@ -477,25 +443,79 @@ def losses(self, predictions, proposals):
).losses()

def inference(self, predictions, proposals):
scores, proposal_deltas = predictions
return FastRCNNOutputs(
self.box2box_transform, scores, proposal_deltas, proposals, self.smooth_l1_beta
).inference(self.test_score_thresh, self.test_nms_thresh, self.test_topk_per_image)
"""
Returns:
list[Instances]: same as `fast_rcnn_inference`.
list[Tensor]: same as `fast_rcnn_inference`.
"""
boxes = self.predict_boxes(predictions, proposals)
scores = self.predict_probs(predictions, proposals)
image_shapes = [x.image_size for x in proposals]
return fast_rcnn_inference(
boxes,
scores,
image_shapes,
self.test_score_thresh,
self.test_nms_thresh,
self.test_topk_per_image,
)

def predict_boxes_for_gt_classes(self, predictions, proposals):
"""
Returns:
list[Tensor]: A list of Tensors of predicted boxes for GT classes in case of
class-specific box head. Element i of the list has shape (Ri, B), where Ri is
the number of predicted objects for image i and B is the box dimension (4 or 5)
"""
if not len(proposals):
return []
scores, proposal_deltas = predictions
return FastRCNNOutputs(
self.box2box_transform, scores, proposal_deltas, proposals, self.smooth_l1_beta
).predict_boxes_for_gt_classes()
proposal_boxes = [p.proposal_boxes for p in proposals]
proposal_boxes = proposal_boxes[0].cat(proposal_boxes).tensor
N, B = proposal_boxes.shape
predict_boxes = apply_deltas_broadcast(
self.box2box_transform, proposal_deltas, proposal_boxes
) # Nx(KxB)

K = predict_boxes.shape[1] // B
if K > 1:
gt_classes = torch.cat([p.gt_classes for p in proposals], dim=0)
# Some proposals are ignored or have a background class. Their gt_classes
# cannot be used as index.
gt_classes = gt_classes.clamp_(0, K - 1)

predict_boxes = predict_boxes.view(N, K, B)[
torch.arange(N, dtype=torch.long, device=predict_boxes.device), gt_classes
]
num_prop_per_image = [len(p) for p in proposals]
return predict_boxes.split(num_prop_per_image)

def predict_boxes(self, predictions, proposals):
scores, proposal_deltas = predictions
return FastRCNNOutputs(
self.box2box_transform, scores, proposal_deltas, proposals, self.smooth_l1_beta
).predict_boxes()
"""
Returns:
list[Tensor]: A list of Tensors of predicted class-specific or class-agnostic boxes
for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
the number of predicted objects for image i and B is the box dimension (4 or 5)
"""
if not len(proposals):
return []
_, proposal_deltas = predictions
num_prop_per_image = [len(p) for p in proposals]
proposal_boxes = [p.proposal_boxes for p in proposals]
proposal_boxes = proposal_boxes[0].cat(proposal_boxes).tensor
predict_boxes = apply_deltas_broadcast(
self.box2box_transform, proposal_deltas, proposal_boxes
) # Nx(KxB)
return predict_boxes.split(num_prop_per_image)

def predict_probs(self, predictions, proposals):
scores, proposal_deltas = predictions
return FastRCNNOutputs(
self.box2box_transform, scores, proposal_deltas, proposals, self.smooth_l1_beta
).predict_probs()
"""
Returns:
list[Tensor]: A list of Tensors of predicted class probabilities for each image.
Element i has shape (Ri, K + 1), where Ri is the number of predicted objects
for image i.
"""
scores, _ = predictions
num_inst_per_image = [len(p) for p in proposals]
probs = F.softmax(scores, dim=-1)
return probs.split(num_inst_per_image, dim=0)

0 comments on commit 5d967ed

Please sign in to comment.