In [1]:
import torch
import torch.nn as nn

In [2]:
class VGG(nn.Module):
    """VGG-style CNN with a BatchNorm + ReLU after every 3x3 convolution.

    The convolutional trunk is described by ``architecture``, a sequence whose
    entries are either an ``int`` (add a conv block with that many output
    channels) or the string ``"M"`` (add a 2x2 max-pool with stride 2).

    Args:
        architecture: sequence of ints and ``"M"`` markers, e.g.
            ``[64, "M", 128, "M"]``.
        in_channels: number of channels of the input images (default 3).
        num_classes: size of the final linear layer's output (default 1000).

    Raises:
        ValueError: if ``architecture`` contains an entry that is neither an
            int nor ``"M"``.
    """

    def __init__(self, architecture, in_channels=3, num_classes=1000):
        super().__init__()
        self.conv_layers = self.get_conv_layers(architecture, in_channels)

        # Pool to a fixed 7x7 grid so the classifier input size does not depend
        # on the image resolution. A feature map that is already 7x7 (e.g. a
        # 224x224 input after five 2x2 pools) passes through unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # The classifier width follows the last convolution actually built,
        # instead of assuming it always has 512 channels.
        conv_modules = [m for m in self.conv_layers if isinstance(m, nn.Conv2d)]
        feature_channels = conv_modules[-1].out_channels if conv_modules else in_channels
        self.fc_layers = self.get_fc_layers(feature_channels * 7 * 7, num_classes)

    def get_conv_layers(self, architecture, in_channels=3):
        """Build the convolutional trunk described by ``architecture``.

        Args:
            architecture: sequence of ints (conv output channels) and ``"M"``
                (max-pool) markers.
            in_channels: channel count expected by the first convolution.

        Returns:
            ``nn.Sequential`` of Conv2d/BatchNorm2d/ReLU triples and MaxPool2d
            layers, in the order given by ``architecture``.

        Raises:
            ValueError: on an entry that is neither an int nor ``"M"``.
        """
        layers = []

        for a in architecture:
            # bool is a subclass of int; reject it explicitly so True/False are
            # reported as invalid rather than treated as a channel count.
            if isinstance(a, int) and not isinstance(a, bool):
                layers += [nn.Conv2d(in_channels=in_channels, out_channels=a, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                           nn.BatchNorm2d(a),
                           nn.ReLU()]

                # The next convolution consumes what this one produced.
                in_channels = a

            elif a == "M":
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))]

            else:
                # Previously such entries were skipped silently, yielding a
                # network that did not match the requested configuration.
                raise ValueError(f"Unsupported architecture entry {a!r}; expected an int or 'M'")

        return nn.Sequential(*layers)

    def get_fc_layers(self, in_features=512 * 7 * 7, num_classes=1000):
        """Build the three-layer fully connected classifier head.

        Args:
            in_features: length of the flattened feature vector fed to the
                first linear layer (default ``512 * 7 * 7``).
            num_classes: number of output logits (default 1000).

        Returns:
            ``nn.Sequential`` of Linear -> ReLU -> Dropout(0.5) twice, followed
            by the final Linear producing ``num_classes`` outputs.
        """
        layers = [nn.Linear(in_features, 4096),
                  nn.ReLU(),
                  nn.Dropout(p=0.5),
                  nn.Linear(4096, 4096),
                  nn.ReLU(),
                  nn.Dropout(p=0.5),
                  nn.Linear(4096, num_classes)]

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the trunk, pool to 7x7, flatten per sample and classify.

        Args:
            x: input tensor; the first dimension is kept as the batch
                dimension when flattening.

        Returns:
            Tensor with one row of ``num_classes`` values per batch element.
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        # Keep the batch dimension, collapse everything else into one vector.
        x = x.reshape(x.shape[0], -1)
        x = self.fc_layers(x)

        return x

In [3]:
# Layer recipes for the VGG class: an int is the output-channel count of a 3x3
# conv block, "M" is a 2x2 max-pool. Every variant has five stages with widths
# 64, 128, 256, 512, 512; the variants differ only in how many convolutions
# each stage repeats before its pooling layer.
VGG_architectures = {
    variant: [entry
              for width, repeats in zip((64, 128, 256, 512, 512), convs_per_stage)
              for entry in [width] * repeats + ["M"]]
    for variant, convs_per_stage in (
        ("VGG11", (1, 1, 2, 2, 2)),
        ("VGG13", (2, 2, 2, 2, 2)),
        ("VGG16", (2, 2, 3, 3, 3)),
        ("VGG19", (2, 2, 4, 4, 4)),
    )
}

In [4]:
# Instantiate the 19-layer configuration; the bare name on the last line makes
# the notebook display the module tree.
net = VGG(architecture=VGG_architectures["VGG19"])
net

VGG(
  (conv_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), str

In [5]:
# Smoke-test the forward pass on a random batch of three 3-channel 224x224 inputs.
data = torch.randn(3, 3, 224, 224)

# Run the check in eval mode and without autograd: in training mode a forward
# pass would update the BatchNorm running statistics with random noise, apply
# dropout, and build a gradient graph that is never used. The previous
# train/eval mode is restored afterwards so later cells see `net` unchanged.
was_training = net.training
net.eval()
with torch.no_grad():
    output = net(data)
net.train(was_training)

output.shape

torch.Size([3, 1000])