Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
cbeaafd
replace assert with valueerror
abhi-glitchhg Jan 23, 2022
84d9ceb
pytest should raise ValueError not AssertionError
piyush01123 Jan 23, 2022
b8cfa28
minor edit
abhi-glitchhg Jan 24, 2022
830e005
raise assert changed to raise valueerror in test
abhi-glitchhg Jan 24, 2022
5d78882
Merge branch 'main' into backbon_edit
abhi-glitchhg Jan 24, 2022
1a79a05
Update torchvision/models/detection/backbone_utils.py
abhi-glitchhg Jan 24, 2022
b8b22c8
Update torchvision/models/detection/backbone_utils.py
abhi-glitchhg Jan 24, 2022
1bd39a9
minor edits
abhi-glitchhg Jan 24, 2022
6ff20cc
minor edits
abhi-glitchhg Jan 24, 2022
1431ac3
added one test
abhi-glitchhg Jan 24, 2022
ebc3689
added another test
abhi-glitchhg Jan 24, 2022
93c796d
added another test
abhi-glitchhg Jan 24, 2022
14332cb
Merge branch 'main' into backbon_edit
abhi-glitchhg Jan 24, 2022
9ba1eae
test for mobilenet
piyush01123 Jan 25, 2022
1792d4b
Merge branch 'pytorch:main' into backbon_edit
piyush01123 Jan 25, 2022
47569fd
ufmt formatting
piyush01123 Jan 25, 2022
5a6cd57
Merge branch 'backbon_edit' of github.com:piyush01123/vision into bac…
piyush01123 Jan 25, 2022
5c5b0ec
Merge branch 'main' into backbon_edit
piyush01123 Jan 25, 2022
fd94a0a
cant have unused variables
piyush01123 Jan 25, 2022
40becbf
suggested changes
piyush01123 Jan 25, 2022
1083d89
minor edit
piyush01123 Jan 25, 2022
3745a32
corrected bug pointed out by datumbox
piyush01123 Jan 25, 2022
4804a8c
corrected bug pointed out by datumbox
piyush01123 Jan 25, 2022
1272b24
bug correction and shorten msg
piyush01123 Jan 25, 2022
c92c619
ufmt stuff
piyush01123 Jan 25, 2022
922c710
resolved last comment
piyush01123 Jan 25, 2022
d79b512
Merge branch 'main' into backbon_edit
datumbox Jan 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names


def get_available_models():
Expand All @@ -23,6 +22,23 @@ def test_resnet_fpn_backbone(backbone_name):
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
assert list(y.keys()) == ["0", "1", "2", "3", "pool"]

with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=6)
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3])
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5])


@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1)
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
Expand Down
5 changes: 2 additions & 3 deletions test/test_models_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import pytest
import torch
from common_utils import assert_equal
from torchvision.models.detection import _utils
from torchvision.models.detection import backbone_utils
from torchvision.models.detection import _utils, backbone_utils
from torchvision.models.detection.transform import GeneralizedRCNNTransform


Expand Down Expand Up @@ -54,7 +53,7 @@ def test_validate_resnet_inputs_detection(self):
)
assert ret == 3
# can't go beyond 5
with pytest.raises(AssertionError):
with pytest.raises(ValueError, match=r"Trainable backbone layers should be in the range"):
ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3
)
Expand Down
26 changes: 16 additions & 10 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import warnings
from typing import Callable, Dict, Optional, List, Union
from typing import Callable, Dict, List, Optional, Union

from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool

from .. import mobilenet
from .. import resnet
from .. import mobilenet, resnet
from .._utils import IntermediateLayerGetter


Expand Down Expand Up @@ -111,7 +110,8 @@ def _resnet_fpn_extractor(
) -> BackboneWithFPN:

# select layers that wont be frozen
assert 0 <= trainable_layers <= 5
if trainable_layers < 0 or trainable_layers > 5:
raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
if trainable_layers == 5:
layers_to_train.append("bn1")
Expand All @@ -124,7 +124,8 @@ def _resnet_fpn_extractor(

if returned_layers is None:
returned_layers = [1, 2, 3, 4]
assert min(returned_layers) > 0 and max(returned_layers) < 5
if min(returned_layers) <= 0 or max(returned_layers) >= 5:
raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

in_channels_stage2 = backbone.inplanes // 8
Expand Down Expand Up @@ -152,7 +153,10 @@ def _validate_trainable_layers(
# by default freeze first blocks
if trainable_backbone_layers is None:
trainable_backbone_layers = default_value
assert 0 <= trainable_backbone_layers <= max_value
if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
raise ValueError(
f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
)
return trainable_backbone_layers


Expand All @@ -172,7 +176,7 @@ def mobilenet_backbone(
def _mobilenet_extractor(
backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
fpn: bool,
trainable_layers,
trainable_layers: int,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
Expand All @@ -183,7 +187,8 @@ def _mobilenet_extractor(
num_stages = len(stage_indices)

# find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages
if trainable_layers < 0 or trainable_layers > num_stages:
raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

for b in backbone[:freeze_before]:
Expand All @@ -197,7 +202,8 @@ def _mobilenet_extractor(

if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
Expand Down