From 92580540d3d34e9eeee380958314b9923bc66612 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 7 Oct 2021 13:58:11 +0000 Subject: [PATCH 1/2] Use FX feature extractor for segm model --- test/test_models.py | 4 +++- .../models/segmentation/segmentation.py | 21 +++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 5e5b3429778..5ad46181a9c 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -470,12 +470,14 @@ def test_classification_model(model_name, dev): @pytest.mark.parametrize("model_name", get_available_segmentation_models()) @pytest.mark.parametrize("dev", cpu_and_gpu()) -def test_segmentation_model(model_name, dev): +@pytest.mark.parametrize("use_fe", [True, False]) +def test_segmentation_model(model_name, dev, use_fe): set_rng_seed(0) defaults = { "num_classes": 10, "pretrained_backbone": False, "input_shape": (1, 3, 32, 32), + "use_fe": use_fe, } kwargs = {**defaults, **_model_params.get(model_name, {})} input_shape = kwargs.pop("input_shape") diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index d5223842010..ac79b46698c 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -6,6 +6,7 @@ from .. import mobilenetv3 from .. import resnet from .._utils import IntermediateLayerGetter +from ..feature_extraction import create_feature_extractor from .deeplabv3 import DeepLabHead, DeepLabV3 from .fcn import FCN, FCNHead from .lraspp import LRASPP @@ -32,7 +33,8 @@ def _segm_model( - name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True + name: str, backbone_name: str, num_classes: int, aux: Optional[bool], + pretrained_backbone: bool = True, use_fe: bool = False, ) -> nn.Module: if "resnet" in backbone_name: backbone = resnet.__dict__[backbone_name]( @@ -60,7 +62,11 @@ def _segm_model( return_layers = {out_layer: "out"} if aux: return_layers[aux_layer] = "aux" - backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + if use_fe: + backbone = create_feature_extractor(backbone, return_layers) + else: + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) aux_classifier = None if aux: @@ -105,7 +111,10 @@ def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: boo model.load_state_dict(state_dict) -def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP: +def _segm_lraspp_mobilenetv3( + backbone_name: str, num_classes: int, + pretrained_backbone: bool = True, use_fe: bool = False +) -> LRASPP: backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. @@ -116,7 +125,11 @@ def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_ba low_channels = backbone[low_pos].out_channels high_channels = backbone[high_pos].out_channels - backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"}) + return_layers = {str(low_pos): "low", str(high_pos): "high"} + if use_fe: + backbone = create_feature_extractor(backbone, return_layers) + else: + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) model = LRASPP(backbone, low_channels, high_channels, num_classes) return model From 21b0ff8f9093d0c1c654299cc3108327fb8e4763 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 7 Oct 2021 15:16:41 +0000 Subject: [PATCH 2/2] Removed use_fe option --- test/test_models.py | 4 +--- .../models/segmentation/segmentation.py | 21 ++++--------------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 5ad46181a9c..5e5b3429778 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -470,14 +470,12 @@ def test_classification_model(model_name, dev): @pytest.mark.parametrize("model_name", get_available_segmentation_models()) @pytest.mark.parametrize("dev", cpu_and_gpu()) -@pytest.mark.parametrize("use_fe", [True, False]) -def test_segmentation_model(model_name, dev, use_fe): +def test_segmentation_model(model_name, dev): set_rng_seed(0) defaults = { "num_classes": 10, "pretrained_backbone": False, "input_shape": (1, 3, 32, 32), - "use_fe": use_fe, } kwargs = {**defaults, **_model_params.get(model_name, {})} input_shape = kwargs.pop("input_shape") diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index ac79b46698c..c19e36e4705 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -5,7 +5,6 @@ from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 from .. import resnet -from .._utils import IntermediateLayerGetter from ..feature_extraction import create_feature_extractor from .deeplabv3 import DeepLabHead, DeepLabV3 from .fcn import FCN, FCNHead @@ -33,8 +32,7 @@ def _segm_model( - name: str, backbone_name: str, num_classes: int, aux: Optional[bool], - pretrained_backbone: bool = True, use_fe: bool = False, + name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True ) -> nn.Module: if "resnet" in backbone_name: backbone = resnet.__dict__[backbone_name]( @@ -62,11 +60,7 @@ def _segm_model( return_layers = {out_layer: "out"} if aux: return_layers[aux_layer] = "aux" - - if use_fe: - backbone = create_feature_extractor(backbone, return_layers) - else: - backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + backbone = create_feature_extractor(backbone, return_layers) aux_classifier = None if aux: @@ -111,10 +105,7 @@ def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: boo model.load_state_dict(state_dict) -def _segm_lraspp_mobilenetv3( - backbone_name: str, num_classes: int, - pretrained_backbone: bool = True, use_fe: bool = False -) -> LRASPP: +def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP: backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. @@ -125,11 +116,7 @@ def _segm_lraspp_mobilenetv3( low_channels = backbone[low_pos].out_channels high_channels = backbone[high_pos].out_channels - return_layers = {str(low_pos): "low", str(high_pos): "high"} - if use_fe: - backbone = create_feature_extractor(backbone, return_layers) - else: - backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"}) model = LRASPP(backbone, low_channels, high_channels, num_classes) return model