In [1]:
import torch
import torchvision
from torchsummary import summary

In [2]:
from torchvision.models import alexnet

In [4]:
# Reference point: per-layer output shapes and parameter counts of
# torchvision's AlexNet (freshly constructed) for a 3x224x224 input.
summary(alexnet(), input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 55, 55]          23,296
              ReLU-2           [-1, 64, 55, 55]               0
         MaxPool2d-3           [-1, 64, 27, 27]               0
            Conv2d-4          [-1, 192, 27, 27]         307,392
              ReLU-5          [-1, 192, 27, 27]               0
         MaxPool2d-6          [-1, 192, 13, 13]               0
            Conv2d-7          [-1, 384, 13, 13]         663,936
              ReLU-8          [-1, 384, 13, 13]               0
            Conv2d-9          [-1, 256, 13, 13]         884,992
             ReLU-10          [-1, 256, 13, 13]               0
           Conv2d-11          [-1, 256, 13, 13]         590,080
             ReLU-12          [-1, 256, 13, 13]               0
        MaxPool2d-13            [-1, 256, 6, 6]               0
AdaptiveAvgPool2d-14            [-1, 25

In [5]:
import torch.nn as nn

In [17]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    The feature extractor (``self.net``) stacks five convolutions with widths
    96 -> 256 -> 384 -> 384 -> 256 and three 3x3/stride-2 max-pools. The
    classifier (``self.classifier``) is a three-layer MLP whose first layer
    expects ``256 * 6 * 6`` features, which is what the feature extractor
    produces for 227x227 inputs.

    Args:
        in_channels: Number of channels in the input images (3 for RGB).
        num_classes: Size of the output logit vector.
    """

    def __init__(self, in_channels=3, num_classes=1000):
        super().__init__()
        # Shape comments below assume a (N, in_channels, 227, 227) input.
        self.net = nn.Sequential(
            # The first convolution honours `in_channels` rather than a
            # hard-coded 3, so non-RGB inputs are supported.
            nn.Conv2d(in_channels, 96, kernel_size=11, stride=4),  # (N, 96, 55, 55)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),  # (N, 96, 27, 27)
            nn.Conv2d(96, 256, 5, padding=2),  # (N, 256, 27, 27)
            nn.ReLU(),
            nn.MaxPool2d(3, 2),  # (N, 256, 13, 13)
            nn.Conv2d(256, 384, 3, padding=1),  # (N, 384, 13, 13)
            nn.ReLU(),
            nn.Conv2d(384, 384, 3, padding=1),  # (N, 384, 13, 13)
            nn.ReLU(),
            nn.Conv2d(384, 256, 3, padding=1),  # (N, 256, 13, 13)
            nn.ReLU(),
            nn.MaxPool2d(3, 2),  # (N, 256, 6, 6)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a batch ``x``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield a 6x6
                feature map (the first linear layer rejects the mismatched
                feature count).
        """
        x = self.net(x)
        # Flatten per sample. Reshaping to (-1, 256*6*6) instead could silently
        # fold samples together when the feature map is not 6x6 but the total
        # element count happens to be divisible by 9216.
        x = x.reshape(x.size(0), -1)
        return self.classifier(x)

In [19]:
# Per-layer output shapes and parameter counts of the hand-written AlexNet
# for a 3x227x227 input.
summary(AlexNet(in_channels=3, num_classes=1000), input_size=(3, 227, 227))

torch.Size([2, 256, 6, 6])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 55, 55]          34,944
              ReLU-2           [-1, 96, 55, 55]               0
         MaxPool2d-3           [-1, 96, 27, 27]               0
            Conv2d-4          [-1, 256, 27, 27]         614,656
              ReLU-5          [-1, 256, 27, 27]               0
         MaxPool2d-6          [-1, 256, 13, 13]               0
            Conv2d-7          [-1, 384, 13, 13]         885,120
              ReLU-8          [-1, 384, 13, 13]               0
            Conv2d-9          [-1, 384, 13, 13]       1,327,488
             ReLU-10          [-1, 384, 13, 13]               0
           Conv2d-11          [-1, 256, 13, 13]         884,992
             ReLU-12          [-1, 256, 13, 13]               0
        MaxPool2d-13            [-1, 256, 6, 6]               0
          Dr