In [None]:
#CNN version
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

class MBConvBlock(nn.Module):
    """Inverted-residual convolution block (expand -> depthwise -> project).

    The block applies, in order:

    1. a 1x1 convolution expanding ``in_channels`` to
       ``in_channels * expand_ratio`` channels, BatchNorm and ReLU6;
    2. a depthwise ``kernel_size`` x ``kernel_size`` convolution
       (``groups == hidden_dim``) with the given ``stride``, BatchNorm and
       ReLU6;
    3. a 1x1 projection to ``out_channels`` followed by BatchNorm (no
       activation).

    A residual connection is added only when ``stride == 1``, because a
    strided depthwise convolution changes the spatial size and the input
    could no longer be summed with the output. When the stride is 1 but the
    channel counts differ, the input goes through a 1x1 conv + BatchNorm
    projection before being added; otherwise it is added unchanged.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Depthwise kernel size. Padding is ``kernel_size // 2``,
            which preserves the spatial size at stride 1 only for odd values.
        stride: Stride of the depthwise convolution.
        expand_ratio: Channel multiplier of the expansion convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int, expand_ratio: int = 6) -> None:
        super(MBConvBlock, self).__init__()
        self.expand_ratio = expand_ratio
        hidden_dim = in_channels * expand_ratio

        # Pointwise expansion.
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.act1 = nn.ReLU6(inplace=True)

        # Depthwise convolution: one filter per channel (groups == channels).
        self.conv2 = nn.Conv2d(
            hidden_dim,
            hidden_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=hidden_dim,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(hidden_dim)
        self.act2 = nn.ReLU6(inplace=True)

        # Pointwise projection; deliberately not followed by an activation.
        self.conv3 = nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # The skip path is only shape-compatible when the block does not
        # downsample; with stride != 1 the block runs without a residual.
        self.use_residual = stride == 1

        # Empty Sequential acts as identity; replaced by a 1x1 projection
        # when the residual is used and the channel counts differ.
        self.shortcut = nn.Sequential()
        if self.use_residual and in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` and return the (optionally residual) output."""
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.use_residual:
            # Out-of-place add keeps the BatchNorm output intact for autograd.
            out = out + self.shortcut(x)
        return out

class EfficientNet(nn.Module):
    """EfficientNet-style image classifier built from ``MBConvBlock`` stages.

    The network consists of a strided 3x3 stem convolution, seven stages of
    ``MBConvBlock`` modules, a 1x1 head convolution, global average pooling
    and a dropout + linear classifier.

    Args:
        num_classes: Number of output logits.
        width_coefficient: Multiplier applied to every channel count
            (the result is truncated with ``int``).
        depth_coefficient: Multiplier applied to the per-stage block count
            (truncated with ``int`` and clamped to at least one block so no
            stage disappears).
        dropout_rate: Dropout probability used before the final linear layer.
    """

    def __init__(self, num_classes: int = 10, width_coefficient: float = 1.0,
                 depth_coefficient: float = 1.0, dropout_rate: float = 0.2) -> None:
        super(EfficientNet, self).__init__()

        # Network layout: stem width, seven MBConv stage widths, head width.
        channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280]
        # Block count of each MBConv stage; pairs with channels[1:-1].
        repeats = [1, 2, 2, 3, 3, 4, 1]

        # Scale the channel counts by the width coefficient.
        channels = [int(c * width_coefficient) for c in channels]

        # Stem of the model.
        self.features = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
            nn.ReLU6(inplace=True)
        )

        # Add the MBConv stages. `stage` is 1-based so module names keep the
        # `MBConvBlock_{stage}_{block}` scheme; `repeats` is paired by
        # position with the stage widths instead of being indexed by `stage`,
        # which previously ran one past the end of the list.
        in_ch = channels[0]
        stage_specs = zip(channels[1:-1], repeats)
        for stage, (out_ch, base_repeats) in enumerate(stage_specs, start=1):
            num_repeats = max(1, int(base_repeats * depth_coefficient))
            for block_idx in range(num_repeats):
                # Only the first block of a stage downsamples.
                # NOTE(review): the original guard was `i > 0 and ...`, which
                # is always true for 1-based stages, so the first stage also
                # uses stride 2 here. Confirm whether stage 1 was meant to
                # keep stride 1.
                stride = 2 if block_idx == 0 else 1
                self.features.add_module(
                    f'MBConvBlock_{stage}_{block_idx}',
                    MBConvBlock(in_ch, out_ch, kernel_size=3, stride=stride),
                )
                # Later blocks of the stage consume this block's output, so
                # their input width is the stage width, not the previous one.
                in_ch = out_ch

        # Final 1x1 convolution and global average pooling.
        self.features.add_module('Conv', nn.Conv2d(in_ch, channels[-1], kernel_size=1, bias=False))
        self.features.add_module('BatchNorm', nn.BatchNorm2d(channels[-1]))
        self.features.add_module('ReLU', nn.ReLU6(inplace=True))
        self.features.add_module('GlobalAvgPool', nn.AdaptiveAvgPool2d(1))

        # Classifier head.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(channels[-1], num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(x.size(0), num_classes)``."""
        x = self.features(x)
        # Collapse the pooled (N, C, 1, 1) feature map to (N, C).
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Hyperparameters.
num_classes = 10
learning_rate = 0.001
batch_size = 64

# Load the dataset; CIFAR-10 is used here as an example.
# Each channel is normalised with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialise the model, loss function and optimizer.
model = EfficientNet(num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model.
num_epochs = 5
model.train()  # make sure Dropout/BatchNorm run in training mode
for epoch in range(num_epochs):
    # Accumulate a sample-weighted loss so the log reflects the whole epoch
    # rather than only the last (possibly smaller) mini-batch.
    epoch_loss_sum = 0.0
    epoch_samples = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_count = labels.size(0)
        epoch_loss_sum += loss.item() * batch_count
        epoch_samples += batch_count

    # max(..., 1) avoids a division by zero if the loader yields no batches.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss_sum / max(epoch_samples, 1):.4f}')

# Save the trained weights.
torch.save(model.state_dict(), 'efficientnet_model.pth')
