In [14]:
import torch
import torch.nn as nn


__all__ = ['AlexNet', 'alexnet']

model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


"""
class AlexNet(nn.Module):
    def __init__(self, image_channel, num_classes=10000):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=image_channel, out_channels=96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

        )

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifer = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),

            nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            nn.ReLU(),

            nn.Dropout(p=0.5, inplace=True),

            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),

            nn.Linear(in_features=4096, out_features=num_classes),

        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)

        # x = self.flatten(x)           # Why error ???
        x = x.reshape(x.size(0), -1)    # OR   x = x.view(-1, 256 * 6 * 6)
        
        x = self.classifer(x)

        return x
"""


class AlexNet(nn.Module):

    def __init__(self, image_channel, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(image_channel, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def alexnet(pretrained=False, progress=True, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = AlexNet(**kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(model_urls['alexnet'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def test():
  # alexNet = AlexNet(image_channel=3, num_classes=10)
  # input = torch.randn(5,3,227,227)
  # output = alexNet(input)

  alexNet = alexnet(pretrained=True, progress=True, image_channel=3, num_classes=1000)
  input = torch.randn(5,3,227,227)
  output = alexNet(input)

  
  print(output.shape)
  print(output)

test()


torch.Size([5, 1000])
tensor([[ 0.2028, -0.7026, -1.1720,  ..., -1.3214, -0.9990,  1.2281],
        [-0.3586, -0.2798, -1.6351,  ..., -0.2906, -1.1012,  2.0939],
        [ 0.5765, -1.6307, -1.5088,  ..., -0.3902, -2.0989,  2.3969],
        [-0.6943, -2.2502, -0.7982,  ..., -2.4075, -1.6611,  1.5419],
        [-0.0226, -2.0184, -1.3322,  ..., -0.9417, -0.4622,  1.1703]],
       grad_fn=<AddmmBackward>)
