In [58]:
import torch
import torch.nn as nn
from torchsummary import summary

In [26]:
class AlexA(nn.Module):
    def __init__(self, ni, nf, kernel, stride=1, padding=2,num_classes=1000):
        super(AlexA, self).__init__()
        self.conv = nn.Conv2d(ni, nf, kernel, stride, padding)
        self.pool = nn.MaxPool2d(3, 2)
    def forward(self, x):
        x = self.conv(x)
        x = nn.ReLU()(x)
        x = self.pool(x)
        return x

In [27]:
class AlexB(nn.Module):
    def __init__(self):
        super(AlexB, self).__init__()
        
        self.conv1 = nn.Conv2d(256, 384, (3, 3), 1, 1)
        self.conv2 = nn.Conv2d(384, 384, (3, 3), 1, 1)
        self.conv3 = nn.Conv2d(384, 256, (3, 3), 2, 1)
        self.pool = nn.MaxPool2d(3, 2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        x = self.conv3(x)
        x = nn.ReLU()(x)
        x = self.pool(x)

In [45]:
class AlexC(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexC, self).__init__()
        self.fc1 = nn.Linear(256*6*6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.out = nn.Linear(4096, num_classes)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.out(x)
        return x

In [52]:
class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.block1 = AlexA(3, 96, (11, 11), 4, 2)
        self.block2 = AlexA(96, 256, (5, 5), 2, 2)
        self.block3 = AlexB()
        self.dense = AlexC(num_classes=num_classes)
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = torch.flatten(x, 1)
        x = self.dense(x)
        return x

In [71]:
alex = AlexNet(10)
alex.cuda()

AlexNet(
  (block1): AlexA(
    (conv): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block2): AlexA(
    (conv): Conv2d(96, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block3): AlexB(
    (conv1): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(384, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dense): AlexC(
    (fc1): Linear(in_features=9216, out_features=4096, bias=True)
    (fc2): Linear(in_features=4096, out_features=4096, bias=True)
    (out): Linear(in_features=4096, out_features=10, bias=True)
  )
)

In [74]:
# summary(alex, (3, 224, 224), device='cuda')