Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def resnet_fpn_backbone(
# select layers that wont be frozen
assert 0 <= trainable_layers <= 5
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
# freeze layers only if pretrained backbone is used
if trainable_layers == 5:
layers_to_train.append('bn1')
for name, parameter in backbone.named_parameters():
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)
Expand Down Expand Up @@ -152,7 +153,6 @@ def mobilenet_backbone(
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

# freeze layers only if pretrained backbone is used
for b in backbone[:freeze_before]:
for parameter in b.parameters():
parameter.requires_grad_(False)
Expand Down