In [1]:
import torch
import torch.nn as nn
from torchinfo import summary

In [2]:
class AlexNet(nn.Module):
    """AlexNet image classifier that can optionally trace per-layer output shapes.

    Args:
        num_classes: Number of logits produced by the final ``nn.Linear`` layer.
        dropout: Drop probability used by both ``nn.Dropout`` layers in the classifier.
        verbose: When True (the default, matching the original behaviour), ``forward``
            prints the input shape and the output shape after every layer. Set it to
            False to use the model in training/evaluation loops without console output.
    """

    def __init__(self, num_classes=1000, dropout=0.5, verbose=True):
        super().__init__()
        self.verbose = verbose
        # Convolutional feature extractor.
        self.features = nn.Sequential(
            # Conv + ReLU + max pool
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Conv + ReLU + max pool
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Conv + ReLU
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Conv + ReLU
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Conv + ReLU
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Max pool
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Resize the feature map to a fixed 6x6 grid so the classifier's
        # 256 * 6 * 6 input size also holds for inputs other than 224x224.
        # A 224x224 input already yields a 6x6 map, so this is an identity there.
        # The layer has no parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Fully connected classifier head.
        self.classifier = nn.Sequential(
            # Dropout + linear + ReLU
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            # Dropout + linear + ReLU
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            # Final linear layer producing the class logits
            nn.Linear(4096, num_classes),
        )

    def _log_shape(self, name, x):
        """Print ``x``'s shape labelled with ``name`` when verbose tracing is on."""
        if self.verbose:
            print(name, 'Output shape:\t', x.size())

    def forward(self, x):
        """Run the network on ``x`` and return the classifier output.

        Assumes ``x`` is a 4-D batch ``(N, 3, H, W)`` (the first conv expects
        3 input channels) — TODO confirm for other callers. ``H``/``W`` must be
        large enough to survive the feature extractor's downsampling.
        """
        if self.verbose:
            print("Input shape:\t", x.size())
        # Extract features layer by layer so each intermediate shape can be traced.
        for layer in self.features:
            x = layer(x)
            self._log_shape(layer.__class__.__name__, x)
        x = self.avgpool(x)
        self._log_shape(self.avgpool.__class__.__name__, x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        self._log_shape("Flatten", x)
        # Feed the flattened features through the fully connected head.
        for layer in self.classifier:
            x = layer(x)
            self._log_shape(layer.__class__.__name__, x)
        return x


# Fix the RNG seed so the randomly initialised weights and the input tensor
# are reproducible across Restart & Run All.
torch.manual_seed(0)

# Instantiate the AlexNet class
alexnet = AlexNet()

# Generate a dummy input: a batch of one 3-channel 224x224 image
x = torch.randn(1, 3, 224, 224)

# Run the forward pass. Gradients are not needed just to inspect shapes,
# so skip building the autograd graph.
with torch.no_grad():
    output = alexnet(x)

Input shape:	 torch.Size([1, 3, 224, 224])
Conv2d Output shape:	 torch.Size([1, 64, 55, 55])
ReLU Output shape:	 torch.Size([1, 64, 55, 55])
MaxPool2d Output shape:	 torch.Size([1, 64, 27, 27])
Conv2d Output shape:	 torch.Size([1, 192, 27, 27])
ReLU Output shape:	 torch.Size([1, 192, 27, 27])
MaxPool2d Output shape:	 torch.Size([1, 192, 13, 13])
Conv2d Output shape:	 torch.Size([1, 384, 13, 13])
ReLU Output shape:	 torch.Size([1, 384, 13, 13])
Conv2d Output shape:	 torch.Size([1, 256, 13, 13])
ReLU Output shape:	 torch.Size([1, 256, 13, 13])
Conv2d Output shape:	 torch.Size([1, 256, 13, 13])
ReLU Output shape:	 torch.Size([1, 256, 13, 13])
MaxPool2d Output shape:	 torch.Size([1, 256, 6, 6])
Flatten Output shape:	 torch.Size([1, 9216])
Dropout Output shape:	 torch.Size([1, 9216])
Linear Output shape:	 torch.Size([1, 4096])
ReLU Output shape:	 torch.Size([1, 4096])
Dropout Output shape:	 torch.Size([1, 4096])
Linear Output shape:	 torch.Size([1, 4096])
ReLU Output shape:	 torch.Size([1, 