🐛 Describe the bug
I called `torchvision.models.detection.ssd300_vgg16` with different values for `trainable_backbone_layers`, and it seems to freeze more layers than expected.
To Reproduce
I took part of the script from issue #3905:
```python
import torch
import torchvision

n_trainable_params = []
for trainable_layers in range(6):
    model = torchvision.models.detection.ssd300_vgg16(
        pretrained=False,
        pretrained_backbone=True,
        trainable_backbone_layers=trainable_layers,
    )
    n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))

print(f"SSD number of trainable parameters: {n_trainable_params}")
print(f"SSD total number of parameters: {len([p for p in model.parameters()])}")
```
Output:
```
SSD number of trainable parameters: [45, 45, 51, 57, 63, 67]
SSD total number of parameters: 71
```
So we get the same number of trainable parameters when setting `trainable_backbone_layers` to 0 and to 1. And even when we set `trainable_backbone_layers` to 5, the number of trainable parameters is less than the total. It looks like the behavior is correct for `trainable_backbone_layers=0`, but in every other case one layer more than necessary is frozen.
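To see where the extra freezing happens, a quick diagnostic can print the `requires_grad` flag of each backbone parameter (a sketch; the `model.backbone` attribute name is taken from the torchvision source and may differ across versions):

```python
import torchvision

# Sketch of a diagnostic: with trainable_backbone_layers=1 one would expect
# the last VGG stage to stay trainable, but every `features.*` parameter
# comes out frozen; only the extra SSD layers report True.
model = torchvision.models.detection.ssd300_vgg16(
    pretrained=False,
    pretrained_backbone=True,
    trainable_backbone_layers=1,
)
for name, p in model.backbone.named_parameters():
    print(name, p.requires_grad)
```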
I think the problem is in these two lines of vision/torchvision/models/detection/ssd.py (at commit 4b6fc6b):

Line 535:
```python
stage_indices = [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)]
```

Line 540:
```python
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
```
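A minimal standalone sketch of this indexing (using VGG16's maxpool positions 4, 9, 16, 23, 30 instead of the real model) shows the off-by-one, together with one possible fix that treats a stage as starting right after the previous maxpool. The fix is my guess at the intended semantics, not an official patch:

```python
# Standalone sketch of the stage indexing; VGG16's `features` has 31 layers,
# with MaxPool2d at indices 4, 9, 16, 23 and 30.
maxpool_indices = [4, 9, 16, 23, 30]
num_layers = 31
num_stages = len(maxpool_indices)

def freeze_before_current(trainable_layers):
    # Mirrors line 540: everything in backbone[:freeze_before] gets frozen.
    return num_layers if trainable_layers == 0 else maxpool_indices[num_stages - trainable_layers]

def freeze_before_fixed(trainable_layers):
    # Hypothetical fix: a stage starts right after the previous maxpool,
    # so prepend 0 and drop the final (parameter-free) maxpool index.
    stage_starts = [0] + maxpool_indices[:-1]  # [0, 4, 9, 16, 23]
    return num_layers if trainable_layers == 0 else stage_starts[num_stages - trainable_layers]

for t in range(6):
    print(t, freeze_before_current(t), freeze_before_fixed(t))
# current: 0->31, 1->30, 2->23, 3->16, 4->9, 5->4
# fixed:   0->31, 1->23, 2->16, 3->9,  4->4, 5->0
```

With `trainable_layers=1` the current logic freezes `backbone[:30]`, i.e. every conv layer; only the parameter-free maxpool at index 30 escapes, which is why 0 and 1 give identical counts. And at 5 it still freezes `backbone[:4]`, i.e. the first two convs (4 parameter tensors), which matches 71 − 67 = 4 in the output above.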
Expected behavior
The number of trainable parameters should increase as `trainable_backbone_layers` increases; in particular, setting it to a value greater than 0 should unfreeze more parameters than 0 does.
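As a concrete check (my reading of the intended semantics, appended to the reproduction script above; not verified against a patched build):

```python
# Expected: strictly increasing counts, and with trainable_backbone_layers=5
# the whole model trainable (assuming the fix unfreezes the entire backbone).
assert all(a < b for a, b in zip(n_trainable_params, n_trainable_params[1:]))
assert n_trainable_params[-1] == 71
```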
Additional notes
In the function signature, `trainable_backbone_layers` defaults to `None` (vision/torchvision/models/detection/ssd.py, lines 549–556 in 4b6fc6b):
```python
def ssd300_vgg16(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 91,
    pretrained_backbone: bool = True,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
):
```
But a few lines below, it defaults to 5 implicitly:
vision/torchvision/models/detection/ssd.py, lines 604–606 in 4b6fc6b:
```python
trainable_backbone_layers = _validate_trainable_layers(
    pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5
)
```
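For reference, here is a rough paraphrase of what `_validate_trainable_layers` does (a sketch from memory of `backbone_utils`, not the verbatim source):

```python
from typing import Optional

def validate_trainable_layers_sketch(
    is_trained: bool, trainable_layers: Optional[int], max_value: int, default_value: int
) -> int:
    # Without any pretrained weights there is nothing worth freezing,
    # so the whole backbone is made trainable.
    if not is_trained:
        trainable_layers = max_value
    # This is where the implicit default kicks in: None becomes default_value.
    if trainable_layers is None:
        trainable_layers = default_value
    assert 0 <= trainable_layers <= max_value
    return trainable_layers

# ssd300_vgg16(pretrained_backbone=True) therefore resolves None to 5:
print(validate_trainable_layers_sketch(True, None, 5, 5))  # 5
```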
This confused me somewhat when I tried to train this model on my own dataset, because I assumed (for no particular reason) that the entire backbone was frozen by default. For more transparency, maybe it would be better to set `trainable_backbone_layers` to 5 explicitly in the signature?
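Concretely, the suggestion would look like this (a hypothetical signature change, not a submitted patch):

```python
def ssd300_vgg16(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 91,
    pretrained_backbone: bool = True,
    trainable_backbone_layers: Optional[int] = 5,  # explicit instead of None
    **kwargs: Any,
):
    ...
```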
Versions
Google Colab environment:
```
PyTorch version: 1.9.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26
Python version: 3.7.12 (default, Sep 10 2021, 00:21:48) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.9.0+cu102
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.10.0
[pip3] torchvision==0.10.0+cu102
[conda] Could not collect
```
cc @datumbox