# Sample code for Separate VGG16 Model

In [1]:
from typing import Any, cast, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torchvision.models as models

In [2]:
# Load VGG16 Model
vgg_model = models.vgg16(pretrained=True)

# Get state_dict of top 5 layers
top_layers_state_dict = {}
for i, (layer_name, param) in enumerate(vgg_model.named_parameters()):
    print(f'param #{i}: {layer_name}')
    if 'features' in layer_name and i < 5*2:
        top_layers_state_dict[layer_name] = param

# save file path
save_path = 'vgg_top5_layers.pth'

# save state_dict
torch.save(top_layers_state_dict, save_path)



param #0: features.0.weight
param #1: features.0.bias
param #2: features.2.weight
param #3: features.2.bias
param #4: features.5.weight
param #5: features.5.bias
param #6: features.7.weight
param #7: features.7.bias
param #8: features.10.weight
param #9: features.10.bias
param #10: features.12.weight
param #11: features.12.bias
param #12: features.14.weight
param #13: features.14.bias
param #14: features.17.weight
param #15: features.17.bias
param #16: features.19.weight
param #17: features.19.bias
param #18: features.21.weight
param #19: features.21.bias
param #20: features.24.weight
param #21: features.24.bias
param #22: features.26.weight
param #23: features.26.bias
param #24: features.28.weight
param #25: features.28.bias
param #26: classifier.0.weight
param #27: classifier.0.bias
param #28: classifier.3.weight
param #29: classifier.3.bias
param #30: classifier.6.weight
param #31: classifier.6.bias


In [3]:
# Copy and modify from https://pytorch.org/vision/main/_modules/torchvision/models/vgg.html#vgg16

__all__ = [
    "VGG",
    "VGG11_Weights",
    "VGG11_BN_Weights",
    "VGG13_Weights",
    "VGG13_BN_Weights",
    "VGG16_Weights",
    "VGG16_BN_Weights",
    "VGG19_Weights",
    "VGG19_BN_Weights",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]


class VGG(nn.Module):
    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        self.features = features
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return x


def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False, n_layers: int=100) -> nn.Sequential:
    layers: List[nn.Module] = []
    in_channels = 3
    feature_count = 0
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v

            feature_count += 1

        if (feature_count >= n_layers):
            break
    return nn.Sequential(*layers)


cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def vgg_features(cfg: str, batch_norm: bool, progress: bool, n_layers: int, **kwargs: Any) -> VGG:
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm, n_layers=n_layers), **kwargs)
    return model

In [4]:
vgg_model_features = vgg_features("D", False, True, 5)
vgg_model_features

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
  )
)

In [5]:
vgg_model_features.load_state_dict(torch.load(save_path))

<All keys matched successfully>

In [6]:
vgg_model_features.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
  )
)

In [7]:
input_tensor = torch.randn(1, 3, 224, 224)

In [8]:
feature_tensor = vgg_model_features(input_tensor)
print(feature_tensor.shape)

torch.Size([1, 256, 56, 56])


In [9]:
feature_tensor_hook = None
def hook_fn(module, input, output):
    global feature_tensor_hook
    feature_tensor_hook = output

hook_layer = vgg_model.features[11]
hook_layer.register_forward_hook(hook_fn)

<torch.utils.hooks.RemovableHandle at 0x7f97929d23b0>

In [10]:
vgg_model.eval()
output_tensor = vgg_model(input_tensor)

In [11]:
feature_tensor_hook

tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 7.2432,  0.0000,  2.7085,  ...,  0.0000,  0.0000,  0.0000],
          [ 8.9935,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 9.2600,  0.0000,  1.5271,  ...,  9.7830,  0.8683,  0.0000],
          [16.1120,  5.3304,  0.0000,  ...,  0.0000,  2.3761,  0.0000],
          [14.4671, 14.3048, 11.8593,  ...,  3.2519,  5.1902,  1.0948]],

         [[ 5.7436,  9.6808,  3.4098,  ...,  4.9111, 12.2151, 13.2478],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  4.6748],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  8.2324],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  6.7248],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  9.6312,  8.9141,  ..., 12.8768, 16.8608, 17.6182]],

         [[ 4.7295,  0.0000,  9.6665,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  4.0611, 1

In [12]:
feature_tensor

tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 7.2432,  0.0000,  2.7085,  ...,  0.0000,  0.0000,  0.0000],
          [ 8.9935,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 9.2600,  0.0000,  1.5271,  ...,  9.7830,  0.8683,  0.0000],
          [16.1120,  5.3304,  0.0000,  ...,  0.0000,  2.3761,  0.0000],
          [14.4671, 14.3048, 11.8593,  ...,  3.2519,  5.1902,  1.0948]],

         [[ 5.7436,  9.6808,  3.4098,  ...,  4.9111, 12.2151, 13.2478],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  4.6748],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  8.2324],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  6.7248],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  9.6312,  8.9141,  ..., 12.8768, 16.8608, 17.6182]],

         [[ 4.7295,  0.0000,  9.6665,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  4.0611, 1

In [13]:
(feature_tensor_hook == feature_tensor).all()

tensor(True)