create_feature_extractor V.S. _utils.IntermediateLayerGetter #5508

@Alokia

🐛 Describe the bug

When using resnet50 from torchvision.models, I replaced BatchNorm2d with FrozenBatchNorm2d. I then needed to extract the output of layer4, but the model returned by create_feature_extractor does not show any FrozenBatchNorm2d modules when printed, whereas the one returned by IntermediateLayerGetter shows them correctly.
As shown below:

import torch
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models._utils import IntermediateLayerGetter

print(torch.__version__)  # 1.10.2+cu113

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

model = models.resnet50(pretrained=True, norm_layer=FrozenBatchNorm2d)
model1 = create_feature_extractor(model, return_nodes={'layer4': '0'})
model2 = IntermediateLayerGetter(model, return_layers={'layer4': '0'})

# model1.load_state_dict(model2.state_dict())  # error !

print(model1)
print(model2)

The partial output of create_feature_extractor is as follows:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (downsample): Module(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )

The partial output of IntermediateLayerGetter is as follows:

(layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): FrozenBatchNorm2d()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): FrozenBatchNorm2d()
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): FrozenBatchNorm2d()
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): FrozenBatchNorm2d()
      )
    )

So why does create_feature_extractor not work as expected? Also, is there a built-in module with the same functionality as the FrozenBatchNorm2d class above?
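For reference, the forward pass of the FrozenBatchNorm2d class above folds the fixed statistics into a per-channel scale and bias. The following is a minimal pure-Python sketch of that arithmetic for a single channel, not the torchvision implementation. The function names and sample values are hypothetical and chosen only for illustration. It checks that the folded form agrees with the usual inference-mode batch-norm formula.

```python
import math


def frozen_bn_folded(x, w, b, rm, rv, eps=1e-5):
    """Folded form from the forward pass above: x * scale + bias."""
    scale = w / math.sqrt(rv + eps)  # scalar analogue of w * (rv + eps).rsqrt()
    bias = b - rm * scale
    return x * scale + bias


def batch_norm_reference(x, w, b, rm, rv, eps=1e-5):
    """Inference-mode batch norm with fixed statistics, written directly."""
    return (x - rm) / math.sqrt(rv + eps) * w + b


# Hypothetical single-channel values, for illustration only.
x, w, b, rm, rv = 2.0, 1.5, 0.25, 0.5, 4.0

folded = frozen_bn_folded(x, w, b, rm, rv)
reference = batch_norm_reference(x, w, b, rm, rv)

# The two forms should agree up to floating-point rounding.
print(math.isclose(folded, reference, rel_tol=1e-9))
```

With the buffer values set in `__init__` (weight 1, bias 0, running mean 0, running variance 1), the layer is close to an identity mapping until a state dict is loaded, which is why the loaded statistics matter for the extracted features.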

Versions

PyTorch version: 1.10.2+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.5 (default, Sep 3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.1
[pip3] numpydoc==1.1.0
[pip3] torch==1.10.2+cu113
[pip3] torch-tb-profiler==0.2.0
[pip3] torchaudio==0.10.2+cu113
[pip3] torchfile==0.1.0
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.11.3+cu113
[pip3] torchviz==0.0.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.1.1 heb2d755_9 conda-forge
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38hb782905_0
[conda] mkl_fft 1.2.0 py38h45dec08_0
[conda] mkl_random 1.1.1 py38h47e9c7a_0
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.22.1 pypi_0 pypi
[conda] numpydoc 1.1.0 pyhd3eb1b0_1
[conda] torch 1.10.2+cu113 pypi_0 pypi
[conda] torch-tb-profiler 0.2.0 pypi_0 pypi
[conda] torchaudio 0.10.2+cu113 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.11.3+cu113 pypi_0 pypi
[conda] torchviz 0.0.2 pypi_0 pypi
