In [None]:
import torch
from torch import nn

<img src="./vgg.png" height = 720>

In [None]:
         # (N, img_channel, 224, 224)
# VGG-16 feature-extractor layout: an int is the output channel count of a
# 3x3 conv block, 'M' is a 2x2 max-pool that halves the spatial size.
# The trailing shapes assume a 224x224 input image.
vgg16 = [
    64, 64, 'M',           # -> (N, 64, 112, 112)
    128, 128, 'M',         # -> (N, 128, 56, 56)
    256, 256, 256, 'M',    # -> (N, 256, 28, 28)
    512, 512, 512, 'M',    # -> (N, 512, 14, 14)
    512, 512, 512, 'M',    # -> (N, 512, 7, 7)
]
# After the Linear layers, the tensor is mapped to (N, num_class).

In [None]:
class Vgg(nn.Module):
    """VGG-style network: conv/BN/ReLU blocks and max-pools, then a 3-layer MLP.

    Args:
        layer_arch: iterable describing the feature extractor. An ``int`` adds
            a 3x3 conv block (stride 1, padding 1) with that many output
            channels; a ``str`` (e.g. ``'M'``) adds a 2x2 max-pool with
            stride 2.
        img_channel: number of channels of the input images.
        num_class: number of output classes (width of the last Linear layer).

    Raises:
        TypeError: if ``layer_arch`` contains an entry that is neither an int
            nor a str (bools are rejected as well).

    The classifier input width is derived from the last conv layer in
    ``layer_arch`` instead of being fixed to 512 channels, and the feature
    map is adaptively average-pooled to 7x7 before flattening. When the
    feature map is already 7x7 (e.g. five max-pools on a 224x224 input) the
    pooling leaves the values unchanged.
    """

    def __init__(self, layer_arch, img_channel, num_class) -> None:
        super().__init__()
        self.arch = layer_arch
        self.img_channel = img_channel
        self.num_class = num_class

        # create conv layers: feature
        self.feature = self._creat_conv()
        # Fix the spatial size fed to the classifier so inputs other than
        # 224x224 do not break the first Linear layer. Has no parameters,
        # so state_dict keys are unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # create linear layers: classifier
        self.classifier = nn.Sequential(
            nn.Linear(self._feature_channels() * 7 * 7, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_class)
        )

    def _conv_block(self, inchannel, outchannel, kernel_size, stride, padding):
        """Return a Conv2d -> BatchNorm2d -> ReLU block as an nn.Sequential."""
        return nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size, stride, padding),
            nn.BatchNorm2d(outchannel),
            nn.ReLU()
        )

    def _creat_conv(self):
        """Build the feature extractor described by ``self.arch``.

        Returns:
            nn.Sequential of conv blocks and max-pool layers, in arch order.

        Raises:
            TypeError: on an arch entry that is neither int nor str.
        """
        conv_net = []
        inchannel = self.img_channel
        for each in self.arch:
            # bool is a subclass of int; it is never a valid channel count.
            if isinstance(each, int) and not isinstance(each, bool):
                conv_net.append(self._conv_block(inchannel, each, (3, 3), 1, 1))
                inchannel = each
            elif isinstance(each, str):
                conv_net.append(nn.MaxPool2d(kernel_size=(2, 2), stride=2))
            else:
                # Previously such entries were silently dropped, yielding a
                # network that did not match the requested architecture.
                raise TypeError(
                    f"unsupported layer_arch entry {each!r}: "
                    "expected an int (conv channels) or a str (max-pool)"
                )
        return nn.Sequential(*conv_net)

    def _feature_channels(self):
        """Return the channel count produced by ``self.feature``.

        This is the ``out_channels`` of the last Conv2d, or ``img_channel``
        when the feature extractor contains no conv layer.
        """
        for module in reversed(list(self.feature.modules())):
            if isinstance(module, nn.Conv2d):
                return module.out_channels
        return self.img_channel

    def forward(self, x):
        """Map a batch of images (N, img_channel, H, W) to logits (N, num_class)."""
        x = self.feature(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension.
        x = x.flatten(1)
        return self.classifier(x)