Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 13 additions & 22 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@


@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int]
def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
from torch.onnx import operators

num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
Expand All @@ -35,16 +34,15 @@ class RPNHead(nn.Module):

def __init__(self, in_channels, num_anchors):
super(RPNHead, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.conv = nn.Conv2d(in_channels, num_anchors, kernel_size=3, stride=1, padding=1)
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

for layer in self.children():
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)

def forward(self, x):
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
logits = []
bbox_reg = []
for feature in x:
Expand All @@ -54,16 +52,14 @@ def forward(self, x):
return logits, bbox_reg


def permute_and_flatten(layer, N, A, C, H, W):
# type: (Tensor, int, int, int, int, int) -> Tensor
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
layer = layer.view(N, -1, C, H, W)
layer = layer.permute(0, 3, 4, 1, 2)
layer = layer.reshape(N, -1, C)
return layer


def concat_box_prediction_layers(box_cls, box_regression):
# type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
box_cls_flattened = []
box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the
Expand Down Expand Up @@ -169,8 +165,7 @@ def post_nms_top_n(self):
return self._post_nms_top_n["training"]
return self._post_nms_top_n["testing"]

def assign_targets_to_anchors(self, anchors, targets):
# type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
def assign_targets_to_anchors(self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]:
labels = []
matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
Expand Down Expand Up @@ -205,8 +200,7 @@ def assign_targets_to_anchors(self, anchors, targets):
matched_gt_boxes.append(matched_gt_boxes_per_image)
return labels, matched_gt_boxes

def _get_top_n_idx(self, objectness, num_anchors_per_level):
# type: (Tensor, List[int]) -> Tensor
def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
Expand All @@ -220,8 +214,7 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
offset += num_anchors
return torch.cat(r, dim=1)

def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
def filter_proposals(self, proposals: Tensor, objectness: Tensor, image_shapes: List[Tuple[int, int]], num_anchors_per_level: List[int]) -> Tuple[List[Tensor], List[Tensor]]:
num_images = proposals.shape[0]
device = proposals.device
# do not backprop through objectness
Expand Down Expand Up @@ -271,8 +264,7 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
final_scores.append(scores)
return final_boxes, final_scores

def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
def compute_loss(self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]) -> Tuple[Tensor, Tensor]:
"""
Args:
objectness (Tensor)
Expand Down Expand Up @@ -312,11 +304,10 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)

def forward(
self,
images, # type: ImageList
features, # type: Dict[str, Tensor]
targets=None, # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
images: ImageList,
features: Dict[str, Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None,
) -> Tuple[List[Tensor], Dict[str, Tensor]]:
"""
Args:
images (ImageList): images for which we want to compute the predictions
Expand Down