diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index d5223842010..c19e36e4705 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -5,7 +5,7 @@ from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 from .. import resnet -from .._utils import IntermediateLayerGetter +from ..feature_extraction import create_feature_extractor from .deeplabv3 import DeepLabHead, DeepLabV3 from .fcn import FCN, FCNHead from .lraspp import LRASPP @@ -60,7 +60,7 @@ def _segm_model( return_layers = {out_layer: "out"} if aux: return_layers[aux_layer] = "aux" - backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + backbone = create_feature_extractor(backbone, return_layers) aux_classifier = None if aux: @@ -116,7 +116,7 @@ def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_ba low_channels = backbone[low_pos].out_channels high_channels = backbone[high_pos].out_channels - backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"}) + backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"}) model = LRASPP(backbone, low_channels, high_channels, num_classes) return model