From 160e564a2a04efd06f218f2f3776aa2b332f95e9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 27 Jan 2022 15:49:25 +0000 Subject: [PATCH] Add IntermediateLayerGetter on segmentation. --- torchvision/models/segmentation/deeplabv3.py | 6 +++--- torchvision/models/segmentation/fcn.py | 4 ++-- torchvision/models/segmentation/lraspp.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index a9287f2f724..15ab5fffa5e 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -6,7 +6,7 @@ from .. import mobilenetv3 from .. import resnet -from ..feature_extraction import create_feature_extractor +from .._utils import IntermediateLayerGetter from ._utils import _SimpleSegmentationModel, _load_weights from .fcn import FCNHead @@ -121,7 +121,7 @@ def _deeplabv3_resnet( return_layers = {"layer4": "out"} if aux: return_layers["layer3"] = "aux" - backbone = create_feature_extractor(backbone, return_layers) + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) aux_classifier = FCNHead(1024, num_classes) if aux else None classifier = DeepLabHead(2048, num_classes) @@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3( return_layers = {str(out_pos): "out"} if aux: return_layers[str(aux_pos)] = "aux" - backbone = create_feature_extractor(backbone, return_layers) + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None classifier = DeepLabHead(out_inplanes, num_classes) diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index b1bf3c41c09..307781ebf00 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -3,7 +3,7 @@ from torch import nn from .. import resnet -from ..feature_extraction import create_feature_extractor +from .._utils import IntermediateLayerGetter from ._utils import _SimpleSegmentationModel, _load_weights @@ -57,7 +57,7 @@ def _fcn_resnet( return_layers = {"layer4": "out"} if aux: return_layers["layer3"] = "aux" - backbone = create_feature_extractor(backbone, return_layers) + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) aux_classifier = FCNHead(1024, num_classes) if aux else None classifier = FCNHead(2048, num_classes) diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index f6c2583cac1..ca73140661b 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -6,7 +6,7 @@ from ...utils import _log_api_usage_once from .. import mobilenetv3 -from ..feature_extraction import create_feature_extractor +from .._utils import IntermediateLayerGetter from ._utils import _load_weights @@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> high_pos = stage_indices[-1] # use C5 which has output_stride = 16 low_channels = backbone[low_pos].out_channels high_channels = backbone[high_pos].out_channels - backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"}) + backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"}) return LRASPP(backbone, low_channels, high_channels, num_classes)