In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [6]:
# Layer specifications for the supported VGG variants, keyed by model name.
# An integer is the output-channel count of one convolution layer; the string
# 'M' marks a max-pooling step. Each row below is one stage, ending at its pool.
config = {
    'VGG11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}


class VGG(nn.Module):
    """VGG-style convolutional classifier with batch normalisation.

    The convolutional stack is built from an entry of the module-level
    ``config`` table and is followed by one linear layer that produces the
    class scores.

    The linear layer's input width equals the channel count of the last
    convolution, so the feature map must be 1x1 spatially when it reaches the
    classifier. Every 'M' entry halves height and width; the bundled
    specifications contain five of them, which reduces a 32x32 input to 1x1.

    Args:
        vgg_name: Key into ``config`` selecting the layer specification.
        num_class: Number of scores produced per sample.

    Raises:
        KeyError: If ``vgg_name`` is not a key of ``config``.
        ValueError: If the selected specification has no convolution entry.
    """

    def __init__(self, vgg_name, num_class):
        super().__init__()
        try:
            layer_spec = config[vgg_name]
        except KeyError:
            # Same exception type as a plain lookup, but say what is accepted.
            raise KeyError(
                f"unknown vgg_name {vgg_name!r}; expected one of {list(config)}"
            ) from None

        self.features = self._make_layers(layer_spec)

        # Derive the classifier width from the specification instead of
        # repeating it as a literal (it is 512 for every bundled entry).
        conv_widths = [entry for entry in layer_spec if entry != 'M']
        if not conv_widths:
            raise ValueError(
                f"layer specification for {vgg_name!r} has no convolution entry"
            )
        self.classifier = nn.Linear(conv_widths[-1], num_class)

    def forward(self, x):
        """Return class scores for a batch of inputs.

        Args:
            x: Input batch whose channel dimension matches the 3 input
                channels of the first convolution.

        Returns:
            Tensor with one row per sample and ``num_class`` columns.
        """
        out = self.features(x)
        # Keep the batch dimension and merge the rest. Unlike ``view``,
        # ``flatten`` also accepts feature maps that are not contiguous.
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out

    def _make_layers(self, config):
        """Build the convolutional stack described by ``config``.

        Note that the parameter shadows the module-level ``config`` table
        inside this method; here it is a single layer specification.

        Args:
            config: Sequence whose items are either the string 'M' (append a
                2x2, stride-2 max pool) or an integer channel count (append a
                3x3 convolution with padding 1, batch norm and ReLU).

        Returns:
            An ``nn.Sequential`` holding the layers in specification order,
            followed by a trailing average pool.
        """
        layers = []
        in_channels = 3  # the first convolution takes 3-channel input
        for x in config:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.extend([
                    nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True),
                ])
                in_channels = x

        # A 1x1, stride-1 average pool leaves its input unchanged; it is kept
        # so the module layout (and any printed summary) stays the same.
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))

        return nn.Sequential(*layers)

In [7]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Ten-class VGG11, moved onto the selected device.
net = VGG('VGG11', 10)
net = net.to(device)

# Per-layer report for a 3x32x32 input. The call is the cell's last
# expression, so the notebook displays whatever it returns.
summary(net, input_data=(3, 32, 32), verbose=0)

------------------------------------------------------------------------------------------
Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 512, 1, 1]           --
|    └─Conv2d: 2-1                       [-1, 64, 32, 32]          1,792
|    └─BatchNorm2d: 2-2                  [-1, 64, 32, 32]          128
|    └─ReLU: 2-3                         [-1, 64, 32, 32]          --
|    └─MaxPool2d: 2-4                    [-1, 64, 16, 16]          --
|    └─Conv2d: 2-5                       [-1, 128, 16, 16]         73,856
|    └─BatchNorm2d: 2-6                  [-1, 128, 16, 16]         256
|    └─ReLU: 2-7                         [-1, 128, 16, 16]         --
|    └─MaxPool2d: 2-8                    [-1, 128, 8, 8]           --
|    └─Conv2d: 2-9                       [-1, 256, 8, 8]           295,168
|    └─BatchNorm2d: 2-10                 [-1, 256, 8, 8]           512
|    └─ReLU: 2-11                        [-1, 256