RetinaNet FPN with Low Resolution #3248

Closed
wants to merge 2 commits into from
4 changes: 2 additions & 2 deletions docs/source/models.rst
@@ -363,7 +363,7 @@ Network box AP mask AP keypoint AP
 ================================  =======  ========  ===========
 Faster R-CNN ResNet-50 FPN        37.0     -         -
 RetinaNet ResNet-50 FPN           36.4     -         -
-RetinaNet MobileNetV3-Large FPN   25.6     -         -
+RetinaNet MobileNetV3-Large FPN   ????     -         -
 Mask R-CNN ResNet-50 FPN          37.9     34.6      -
 ================================  =======  ========  ===========

@@ -420,7 +420,7 @@ Network train time (s / it) test time (s / it) memory
 ==============================  ===================  ==================  ===========
 Faster R-CNN ResNet-50 FPN      0.2288               0.0590              5.2
 RetinaNet ResNet-50 FPN         0.2514               0.0939              4.1
-RetinaNet MobileNetV3-Large FPN 0.0928               0.0547              1.4
+RetinaNet MobileNetV3-Large FPN ??????               ??????              ???
 Mask R-CNN ResNet-50 FPN        0.2728               0.0903              5.4
 Keypoint R-CNN ResNet-50 FPN    0.3789               0.1242              6.8
 ==============================  ===================  ==================  ===========
Binary file not shown.
19 changes: 10 additions & 9 deletions torchvision/models/detection/retinanet.py
@@ -559,8 +559,7 @@ def forward(self, images, targets=None):
 
 # TODO: replace with pytorch links
 model_urls = {
-    'retinanet_mobilenet_v3_large_fpn_coco':
-        'https://download.pytorch.org/models/retinanet_mobilenet_v3_large_fpn-41c847a4.pth',
+    'retinanet_mobilenet_v3_large_fpn_coco': None,
     'retinanet_resnet50_fpn_coco':
         'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth',
 }
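
With the MobileNetV3 entry set to `None`, calling the builder with `pretrained=True` would pass `None` straight to `load_state_dict_from_url`. A minimal sketch of a guard that fails early with a readable message is shown here. `resolve_checkpoint_url` is a hypothetical helper introduced only for illustration; it is not part of torchvision or of this PR.

```python
from typing import Dict, Optional

# Mirrors the dictionary in this hunk; a None entry marks weights that
# have not been published yet.
model_urls: Dict[str, Optional[str]] = {
    'retinanet_mobilenet_v3_large_fpn_coco': None,
    'retinanet_resnet50_fpn_coco':
        'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth',
}


def resolve_checkpoint_url(urls: Dict[str, Optional[str]], key: str) -> str:
    """Hypothetical helper: return the URL stored under `key`.

    Raises ValueError when the key is missing or its URL is None, so the
    caller never attempts a download without a real address.
    """
    url = urls.get(key)
    if url is None:
        raise ValueError("No pretrained weights are available for '{}'".format(key))
    return url
```

Under this assumption, the builder could call the helper before downloading, so that a missing checkpoint surfaces as a clear `ValueError` rather than an error from inside the download code.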
@@ -628,7 +627,7 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
 
 
 def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                     trainable_backbone_layers=None, **kwargs):
+                                     trainable_backbone_layers=None, min_size=320, max_size=640, **kwargs):
     """
     Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly
     to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details.
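
The new `min_size=320` and `max_size=640` defaults determine how far each input image is rescaled before it reaches the backbone. As a rough sketch of that rule, assuming the detection transform scales the shorter side to `min_size` unless doing so would push the longer side past `max_size`, the scale factor can be computed as follows. `resize_scale` is a hypothetical name used only for illustration.

```python
def resize_scale(height: int, width: int, min_size: int = 320, max_size: int = 640) -> float:
    """Return the factor by which an image of the given size would be rescaled.

    The shorter side is scaled towards `min_size`, but the factor is capped
    so that the longer side does not exceed `max_size`.
    """
    shorter = min(height, width)
    longer = max(height, width)
    return min(min_size / shorter, max_size / longer)


# A 480x640 image shrinks under the low-resolution defaults proposed here...
low_res = resize_scale(480, 640)
# ...and grows under the usual detection defaults of 800 and 1333.
high_res = resize_scale(480, 640, min_size=800, max_size=1333)
print(low_res, high_res)
```

For very elongated images the `max_size` cap dominates: a 200x1000 input is limited by `640 / 1000` rather than `320 / 200`.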
@@ -647,22 +646,24 @@ def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classe
         pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
     """
-    # check default parameters and by default set it to 3 if possible
+    # check default parameters and by default set it to 6 if possible
     trainable_backbone_layers = _validate_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
+        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)
 
     if pretrained:
         pretrained_backbone = False
     backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5],
                                       trainable_layers=trainable_backbone_layers)
 
-    anchor_sizes = ((128,), (256,), (512,))
+    anchor_sizes = ((16, 32, 64, 128, 256,), ) * 3
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
 
-    model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs)
+    model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
+                      min_size=min_size, max_size=max_size, **kwargs)
     if pretrained:
-        state_dict = load_state_dict_from_url(model_urls['retinanet_mobilenet_v3_large_fpn_coco'],
-                                              progress=progress)
+        state_dict = load_state_dict_from_url(model_urls['retinanet_mobilenet_v3_large_fpn_coco'], progress=progress)
         model.load_state_dict(state_dict)
     return model
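
The anchor change in this hunk replaces one size per feature level with the same five sizes on every level. A small sketch of what that does to the number of anchors per spatial location, assuming the count per level is the number of sizes times the number of aspect ratios, follows. `anchors_per_location` is a hypothetical helper written for this illustration and does not call torchvision.

```python
def anchors_per_location(anchor_sizes, aspect_ratios):
    """Count anchors per spatial location for each feature level."""
    return [len(sizes) * len(ratios) for sizes, ratios in zip(anchor_sizes, aspect_ratios)]


# Configuration before this PR: one anchor size per feature level.
old_sizes = ((128,), (256,), (512,))
old_ratios = ((0.5, 1.0, 2.0),) * len(old_sizes)

# Configuration proposed in this PR: five sizes shared by all three levels.
new_sizes = ((16, 32, 64, 128, 256,), ) * 3
new_ratios = ((0.5, 1.0, 2.0),) * len(new_sizes)

print(anchors_per_location(old_sizes, old_ratios))  # [3, 3, 3]
print(anchors_per_location(new_sizes, new_ratios))  # [15, 15, 15]
```

Each location therefore carries five times as many anchors as before, so the classification and regression outputs grow by the same factor while the input resolution drops.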