From d4a21c05d752ace66aef5488356c1e304feae52b Mon Sep 17 00:00:00 2001 From: Tal Regev Date: Thu, 3 Mar 2022 23:10:41 +0200 Subject: [PATCH] Add efficientnet_fpn_backbone --- test/test_backbone_utils.py | 30 +++++++++- .../models/detection/backbone_utils.py | 57 ++++++++++++++++++- 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index a55929d4b36..fcb58fd493e 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -6,8 +6,14 @@ import torch from common_utils import set_rng_seed from torchvision import models +from torchvision.models import efficientnet, mobilenet, resnet from torchvision.models._utils import IntermediateLayerGetter -from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone +from torchvision.models.detection.backbone_utils import ( + BackboneWithFPN, + efficientnet_fpn_backbone, + mobilenet_backbone, + resnet_fpn_backbone, +) from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names @@ -16,7 +22,7 @@ def get_available_models(): return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] -@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) +@pytest.mark.parametrize("backbone_name", resnet.__all__[1:]) def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) @@ -28,9 +34,11 @@ def test_resnet_fpn_backbone(backbone_name): resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5]) + model = resnet_fpn_backbone(backbone_name, False) + assert isinstance(model, BackboneWithFPN) -@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", 
@pytest.mark.parametrize("backbone_name", efficientnet.__all__[1:])
def test_efficientnet_fpn_backbone(backbone_name):
    """Check argument validation and the return type of ``efficientnet_fpn_backbone``."""
    # A negative number of trainable layers has to be rejected.
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        efficientnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=-1)

    # Both of these index lists are expected to fail the returned-layer range check.
    for bad_layers in ([-1, 0, 1, 2], [3, 4, 5, 6, 9]):
        with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
            efficientnet_fpn_backbone(backbone_name, False, returned_layers=bad_layers)

    # With default arguments the factory must hand back a BackboneWithFPN.
    fpn_backbone = efficientnet_fpn_backbone(backbone_name, False)
    assert isinstance(fpn_backbone, BackboneWithFPN)
def efficientnet_fpn_backbone(
    backbone_name: str,
    pretrained: bool,
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 2,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    """
    Build the named EfficientNet and wrap its feature extractor in a ``BackboneWithFPN``.

    Args:
        backbone_name (str): name of a builder in the ``efficientnet`` module,
            e.g. ``"efficientnet_b0"``.
        pretrained (bool): forwarded to the builder as ``pretrained``.
        norm_layer (Callable[..., nn.Module]): forwarded to the builder as ``norm_layer``,
            except for the variants listed in the function body, which are built
            without it.
        trainable_layers (int): number of stages, counted from the end of the feature
            extractor, whose parameters stay trainable. ``0`` freezes every block.
        returned_layers (list of int, optional): indices of the stages handed to
            ``BackboneWithFPN`` as return layers. Defaults to the last two stages.
        extra_blocks (ExtraFPNBlock, optional): forwarded to ``BackboneWithFPN``.
            When ``None``, a new ``LastLevelMaxPool`` is created for this call.

    Returns:
        nn.Module: the ``BackboneWithFPN`` built around the EfficientNet features.

    Raises:
        KeyError: if ``backbone_name`` is not a name in the ``efficientnet`` module.
        ValueError: if ``trainable_layers`` or ``returned_layers`` is out of range.
    """
    # These builders are called without ``norm_layer``.
    # NOTE(review): presumably they configure their own normalization layer and would
    # reject a second ``norm_layer`` argument -- confirm against the efficientnet module.
    builders_without_norm_layer = (
        "efficientnet_b5",
        "efficientnet_b6",
        "efficientnet_b7",
        "efficientnet_v2_s",
        "efficientnet_v2_m",
        "efficientnet_v2_l",
    )
    builder = efficientnet.__dict__[backbone_name]
    if backbone_name in builders_without_norm_layer:
        backbone = builder(pretrained=pretrained)
    else:
        backbone = builder(pretrained=pretrained, norm_layer=norm_layer)
    return _efficientnet_extractor(backbone, trainable_layers, returned_layers, extra_blocks)


def _efficientnet_extractor(
    backbone: efficientnet.EfficientNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    """
    Freeze the leading blocks of ``backbone.features`` and wrap it in a ``BackboneWithFPN``.

    Args:
        backbone (efficientnet.EfficientNet): model whose ``features`` container is used.
        trainable_layers (int): number of trailing stages left trainable; ``0`` freezes
            every block of ``backbone.features``.
        returned_layers (list of int, optional): stage indices used as FPN inputs.
            Defaults to the last two stages.
        extra_blocks (ExtraFPNBlock, optional): forwarded to ``BackboneWithFPN``.
            When ``None``, a new ``LastLevelMaxPool`` is created for this call.

    Returns:
        nn.Module: the resulting ``BackboneWithFPN``.

    Raises:
        ValueError: if ``trainable_layers`` is outside ``[0, num_stages]``, or if
            ``returned_layers`` is empty or holds an index outside ``[0, num_stages - 1]``.
    """
    features = backbone.features

    # A block counts as a stage when its first child exposes a truthy ``out_channels``;
    # that attribute is also what supplies the FPN input widths below.
    stage_indices = [i for i, b in enumerate(features) if getattr(b[0], "out_channels", False)]
    num_stages = len(stage_indices)

    # Validate every argument before touching any parameter, so that a rejected call
    # leaves the backbone unmodified.
    if trainable_layers < 0 or trainable_layers > num_stages:
        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")

    if returned_layers is None:
        returned_layers = [num_stages - 2, num_stages - 1]
    # The emptiness check keeps min()/max() from failing with an unrelated message.
    if not returned_layers or min(returned_layers) < 0 or max(returned_layers) >= num_stages:
        raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")

    # Index of the first block that stays trainable; everything before it is frozen.
    freeze_before = len(features) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
    for b in features[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    # Create the default extra block per call instead of sharing one module instance
    # through a default argument evaluated at definition time.
    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    out_channels = 256

    # Maps the name of each selected block inside ``features`` to its position in
    # ``returned_layers``.
    return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
    in_channels_list = [features[stage_indices[i]][0].out_channels for i in returned_layers]
    return BackboneWithFPN(features, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)