In [None]:
#CNN version
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

# Building blocks of the DenseNet model.
class DenseBlock(nn.Module):
    """Dense block: every layer sees the concatenation of all earlier feature maps.

    Each of the ``num_layers`` layers is BatchNorm -> ReLU -> 3x3 conv
    (stride 1, padding 1, so height/width are preserved) and produces
    ``growth_rate`` new channels. The input of layer ``i`` is the block input
    concatenated with the outputs of layers ``0..i-1``, i.e.
    ``in_channels + i * growth_rate`` channels.

    Args:
        in_channels: number of channels of the block input.
        growth_rate: number of channels produced by each layer.
        num_layers: number of layers in the block (0 makes the block an identity).

    ``forward`` returns a tensor with ``in_channels + num_layers * growth_rate``
    channels: the input followed by every layer's output, concatenated on dim 1.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        # One BN-ReLU-Conv unit per layer. They must be separate modules rather
        # than one chained Sequential, because each layer consumes the
        # concatenation of ALL previous outputs, not just the previous output.
        self.dense_block = nn.ModuleList([
            nn.Sequential(
                nn.BatchNorm2d(in_channels + i * growth_rate),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels + i * growth_rate, growth_rate,
                          kernel_size=3, stride=1, padding=1, bias=False),
            )
            for i in range(num_layers)
        ])

    def forward(self, x):
        features = x
        for layer in self.dense_block:
            new_features = layer(features)
            # Append this layer's output so every later layer (and the block
            # output) includes it.
            features = torch.cat([features, new_features], 1)
        return features

class TransitionBlock(nn.Module):
    """Transition layer placed between two dense blocks.

    Applies BatchNorm -> ReLU -> 1x1 convolution (mapping ``in_channels`` to
    ``out_channels``), then 2x2 average pooling with stride 2, which downsamples
    height and width by a factor of 2 (rounding down).

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            # 1x1 conv changes the channel count without touching spatial size.
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.transition_block = nn.Sequential(*stages)

    def forward(self, x):
        downsampled = self.transition_block(x)
        return downsampled

# DenseNet model definition.
class DenseNet(nn.Module):
    """DenseNet classifier assembled from ``DenseBlock`` and ``TransitionBlock``.

    Layout of ``self.features``:
      * stem: 7x7 stride-2 conv (3 input channels) -> BN -> ReLU -> 3x3 stride-2 max-pool,
        producing ``2 * growth_rate`` channels;
      * one ``DenseBlock_<k>`` per entry of ``block_config``, with a
        ``TransitionBlock_<k>`` after every block except the last; each
        transition halves the channel count (integer division);
      * head: ``Norm`` (BN) -> ``ReLU`` -> ``AvgPool`` (adaptive pool to 1x1).
    ``self.classifier`` is a linear layer mapping the pooled channels to
    ``num_classes`` logits.

    Args:
        growth_rate: channels added by each dense layer; also sets the stem width.
        block_config: number of layers in each dense block.
        num_classes: number of output logits.
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=10):
        super().__init__()
        stem_channels = 2 * growth_rate
        self.features = nn.Sequential(
            nn.Conv2d(3, stem_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        channels = stem_channels
        n_blocks = len(block_config)
        for idx, n_layers in enumerate(block_config, start=1):
            self.features.add_module(f"DenseBlock_{idx}", DenseBlock(channels, growth_rate, n_layers))
            # Assumes each DenseBlock returns its input plus growth_rate
            # channels per layer.
            channels += n_layers * growth_rate
            if idx < n_blocks:
                halved = channels // 2
                self.features.add_module(f"TransitionBlock_{idx}", TransitionBlock(channels, halved))
                channels = halved

        head = (
            ('Norm', nn.BatchNorm2d(channels)),
            ('ReLU', nn.ReLU(inplace=True)),
            ('AvgPool', nn.AdaptiveAvgPool2d(1)),
        )
        for layer_name, layer in head:
            self.features.add_module(layer_name, layer)

        self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x):
        pooled = self.features(x)
        # Collapse the (channels, 1, 1) pooled map into one vector per sample.
        flat = pooled.view(len(pooled), -1)
        return self.classifier(flat)

# Hyperparameters.
num_classes = 10
learning_rate = 0.001
batch_size = 64

# Seed torch's global RNG so weight initialisation, the random flip augmentation
# and DataLoader shuffling repeat across runs (GPU kernels may still add
# some nondeterminism).
SEED = 0
torch.manual_seed(SEED)

# Load the dataset; CIFAR-10 is used as the example here.
# Normalize with mean=std=0.5 per channel maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialise the model, loss function and optimizer.
# Use the GPU when one is available; training DenseNet on CPU is very slow.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DenseNet(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_samples = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch mean; weight by batch size so a
        # smaller final batch does not skew the epoch average.
        running_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)

    # Report the mean loss over the whole epoch, not just the last mini-batch.
    epoch_loss = running_loss / max(num_samples, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

# Save the trained weights. Tensors stay on `device`; when loading on a
# CPU-only machine pass map_location='cpu' to torch.load.
torch.save(model.state_dict(), 'densenet_model.pth')


In [None]:
#Spherical version