diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 4781c912d52..d662de8078a 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -96,7 +96,8 @@ def resnet_fpn_backbone( # select layers that wont be frozen assert 0 <= trainable_layers <= 5 layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] - # freeze layers only if pretrained backbone is used + if trainable_layers == 5: + layers_to_train.append('bn1') for name, parameter in backbone.named_parameters(): if all([not name.startswith(layer) for layer in layers_to_train]): parameter.requires_grad_(False) @@ -152,7 +153,6 @@ def mobilenet_backbone( assert 0 <= trainable_layers <= num_stages freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] - # freeze layers only if pretrained backbone is used for b in backbone[:freeze_before]: for parameter in b.parameters(): parameter.requires_grad_(False)