In [3]:
import torch
import torch.nn as nn
from torchsummary import summary

## AlexNet input image size: 3 x 227 x 227 (channels x height x width)

In [4]:
class Alexnet(nn.Module):
    """AlexNet-style image classifier: five convolutional stages and an MLP head.

    The convolutional stages use ``LeakyReLU(0.1)``; the first two stages end
    with ``BatchNorm2d`` applied after max-pooling. The head uses ``ReLU`` and
    dropout and emits raw logits (no softmax).

    The head's first linear layer expects a flattened ``256 x 6 x 6`` feature
    map, which is what a ``3 x 227 x 227`` input produces. An adaptive average
    pool in front of the head resizes any other feature-map size to ``6 x 6``,
    so other resolutions (for example 224 x 224) also work. For 227 x 227
    inputs the pool is an exact identity, so outputs for that size are
    unchanged.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes

        # Stage 1: 11x11 conv with stride 4, then 3x3/2 max-pool and batch norm.
        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=(11, 11), padding=0, stride=4),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            nn.BatchNorm2d(96),
        )

        # Stage 2: 5x5 "same" conv, then 3x3/2 max-pool and batch norm.
        self.conv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=(5, 5), padding=2, stride=1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            nn.BatchNorm2d(256),
        )

        # Stages 3 and 4: 3x3 "same" convs that keep the spatial size.
        self.conv_layer3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=(3, 3), stride=1, padding=1),
            nn.LeakyReLU(0.1),
        )

        self.conv_layer4 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 3), padding=1, stride=1),
            nn.LeakyReLU(0.1),
        )

        # Stage 5: 3x3 "same" conv followed by the last 3x3/2 max-pool.
        self.conv_layer5 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=(3, 3), padding=1, stride=1),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
        )

        # Pins the feature map to the 6x6 grid the head was sized for. Kept as
        # its own attribute (not inserted into ``FC``) so the indices of the
        # existing ``FC`` layers, and therefore the state_dict keys, stay the
        # same. The layer has no parameters or buffers.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.FC = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=self.num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Image batch of shape ``(N, 3, H, W)``. ``227 x 227`` is the
                design size; the input must still be large enough to pass
                through the strided conv and the three max-pool layers.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding unnormalised logits.
        """
        x = self.conv_layer1(x)
        x = self.conv_layer2(x)
        x = self.conv_layer3(x)
        x = self.conv_layer4(x)
        x = self.conv_layer5(x)
        x = self.avgpool(x)  # identity when the map is already 6x6
        x = self.FC(x)
        return x
    
# Instantiate the network with a 10-way output layer.
model = Alexnet(num_classes=10)
# Bare last expression: Jupyter displays the module's repr (the layer tree).
model

Alexnet(
  (conv_layer1): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): LeakyReLU(negative_slope=0.1)
    (2): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_layer2): Sequential(
    (0): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.1)
    (2): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_layer3): Sequential(
    (0): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1)
  )
  (conv_layer4): Sequential(
    (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1)
  )
  (conv_layer5): Sequential(
    (0): Conv2d(384, 256, k

In [5]:
# Smoke test: push one random 3 x 227 x 227 image through the network and
# check that the output holds one logit per class.
x = torch.randn([1,3,227,227])
# Run the check in eval mode with autograd disabled so it does not apply
# dropout, does not update the BatchNorm running statistics with random noise,
# and does not build a gradient graph. The previous mode is restored afterwards
# so later cells see the model exactly as it was.
was_training = model.training
model.eval()
with torch.no_grad():
    print(model(x).shape)
model.train(was_training)

torch.Size([1, 10])


In [6]:
# Per-layer report (output shapes and parameter counts) for one
# 3 x 227 x 227 input.
# NOTE: this import rebinds the name `summary`, replacing the
# `torchsummary.summary` imported in the first cell.
from torchinfo import summary
summary(model, input_size=(1, 3, 227, 227))

Layer (type:depth-idx)                   Output Shape              Param #
Alexnet                                  [1, 10]                   --
├─Sequential: 1-1                        [1, 96, 27, 27]           --
│    └─Conv2d: 2-1                       [1, 96, 55, 55]           34,944
│    └─LeakyReLU: 2-2                    [1, 96, 55, 55]           --
│    └─MaxPool2d: 2-3                    [1, 96, 27, 27]           --
│    └─BatchNorm2d: 2-4                  [1, 96, 27, 27]           192
├─Sequential: 1-2                        [1, 256, 13, 13]          --
│    └─Conv2d: 2-5                       [1, 256, 27, 27]          614,656
│    └─LeakyReLU: 2-6                    [1, 256, 27, 27]          --
│    └─MaxPool2d: 2-7                    [1, 256, 13, 13]          --
│    └─BatchNorm2d: 2-8                  [1, 256, 13, 13]          512
├─Sequential: 1-3                        [1, 384, 13, 13]          --
│    └─Conv2d: 2-9                       [1, 384, 13, 13]          885,120