Description
🐛 Describe the bug
When I used resnet50 from torchvision.models, I replaced BatchNorm2d with FrozenBatchNorm2d. I then needed to extract the output of layer4, but I found that the model returned by create_feature_extractor does not show the FrozenBatchNorm2d layers when printed. The model returned by IntermediateLayerGetter shows them properly.
As shown below:
import torch
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models._utils import IntermediateLayerGetter

print(torch.__version__)  # 1.10.2+cu113


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Drop the counter that a regular BatchNorm2d checkpoint carries.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


model = models.resnet50(pretrained=True, norm_layer=FrozenBatchNorm2d)
model1 = create_feature_extractor(model, return_nodes={'layer4': '0'})
model2 = IntermediateLayerGetter(model, return_layers={'layer4': '0'})
# model1.load_state_dict(model2.state_dict())  # error !
print(model1)
print(model2)
The partial output of create_feature_extractor is as follows:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (downsample): Module(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
The partial output of IntermediateLayerGetter is as follows:
(layer4): Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): FrozenBatchNorm2d()
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): FrozenBatchNorm2d()
    (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): FrozenBatchNorm2d()
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): FrozenBatchNorm2d()
    )
  )
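The difference between the two printouts may come from how torch.fx tracing works: create_feature_extractor is built on torch.fx, and a tracer steps into any module it does not treat as a leaf, so that module's forward is recorded as plain graph operations instead of a submodule call. The following is a minimal sketch on a toy model, not a confirmed diagnosis of this issue. The names Scale, Net, LeafTracer and called_modules are hypothetical and exist only for illustration.

```python
import torch
from torch import fx, nn


class Scale(nn.Module):
    """Hypothetical stand-in for a custom norm layer that only holds buffers."""

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))

    def forward(self, x):
        return x * self.weight.reshape(1, -1, 1, 1)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=1, bias=False)
        self.norm = Scale(4)

    def forward(self, x):
        return self.norm(self.conv(x))


class LeafTracer(fx.Tracer):
    """Tracer that additionally treats Scale as an opaque leaf module."""

    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, Scale) or super().is_leaf_module(m, module_qualified_name)


def called_modules(tracer, model):
    """Return the targets of all submodule calls recorded in the traced graph."""
    graph = tracer.trace(model)
    return [node.target for node in graph.nodes if node.op == "call_module"]


# Default tracer: Scale is user-defined, so it is traced through.
default_calls = called_modules(fx.Tracer(), Net())
# Custom tracer: Scale is kept as a single call_module node.
leaf_calls = called_modules(LeafTracer(), Net())
print(default_calls)
print(leaf_calls)
```

With the default tracer only the built-in conv layer survives as a module call, while the custom tracer also keeps norm, which mirrors the missing bn entries in the first printout.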
So why does create_feature_extractor not work as expected? I also wonder whether there is a built-in layer with the same functionality as the FrozenBatchNorm2d above.
Versions
PyTorch version: 1.10.2+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Professional
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.8.5 (default, Sep 3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.1
[pip3] numpydoc==1.1.0
[pip3] torch==1.10.2+cu113
[pip3] torch-tb-profiler==0.2.0
[pip3] torchaudio==0.10.2+cu113
[pip3] torchfile==0.1.0
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.11.3+cu113
[pip3] torchviz==0.0.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.1.1 heb2d755_9 conda-forge
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38hb782905_0
[conda] mkl_fft 1.2.0 py38h45dec08_0
[conda] mkl_random 1.1.1 py38h47e9c7a_0
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.22.1 pypi_0 pypi
[conda] numpydoc 1.1.0 pyhd3eb1b0_1
[conda] torch 1.10.2+cu113 pypi_0 pypi
[conda] torch-tb-profiler 0.2.0 pypi_0 pypi
[conda] torchaudio 0.10.2+cu113 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.11.3+cu113 pypi_0 pypi
[conda] torchviz 0.0.2 pypi_0 pypi