In [None]:
import os

import torch
import torch.nn as nn

class Inception(nn.Module):
    """GoogLeNet Inception block: four parallel branches concatenated on the channel axis.

    Branches (every convolution is followed by ReLU):
      * path 1: 1x1 conv                         -> ``c1`` channels
      * path 2: 1x1 conv (``c2[0]``) -> 3x3 conv -> ``c2[1]`` channels
      * path 3: 1x1 conv (``c3[0]``) -> 5x5 conv -> ``c3[1]`` channels
      * path 4: 3x3 max-pool (stride 1) -> 1x1 conv -> ``c4`` channels

    Padding is chosen so every branch keeps the input's height and width, so the
    output has ``c1 + c2[1] + c3[1] + c4`` channels at the input's spatial size.

    Args:
        in_channels: number of input channels.
        c1: output channels of path 1.
        c2: ``(reduce_channels, out_channels)`` of path 2.
        c3: ``(reduce_channels, out_channels)`` of path 3.
        c4: output channels of path 4.
    """

    def __init__(self, in_channels, c1, c2, c3, c4):
        super().__init__()
        # ReLU holds no parameters, so one instance is shared by all branches.
        # Attribute names (RELU, p1_1, ...) determine state_dict keys; keep them stable.
        self.RELU = nn.ReLU()
        # path 1: 1x1 conv
        self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # path 2: 1x1 conv (channel reduction) followed by a 3x3 conv
        self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # path 3: 1x1 conv (channel reduction) followed by a 5x5 conv
        self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # path 4: 3x3 max-pool (stride 1 + padding 1 keeps H/W) followed by a 1x1 conv
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate their outputs.

        Args:
            x: batched image tensor laid out as (batch, channel, height, width).

        Returns:
            Tensor with ``c1 + c2[1] + c3[1] + c4`` channels and the same H/W as ``x``.
        """
        p1 = self.RELU(self.p1_1(x))
        p2 = self.RELU(self.p2_2(self.RELU(self.p2_1(x))))
        p3 = self.RELU(self.p3_2(self.RELU(self.p3_1(x))))
        p4 = self.RELU(self.p4_2(self.p4_1(x)))
        # dim=1 is the channel axis of a (batch, channel, height, width) tensor
        return torch.cat((p1, p2, p3, p4), dim=1)
        
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) image classifier.

    Args:
        Inception: the Inception block class used to build stages b3-b5. It is
            called as ``Inception(in_channels=..., c1=..., c2=(..), c3=(..), c4=...)``
            and must produce ``c1 + c2[1] + c3[1] + c4`` output channels, because
            each following block's ``in_channels`` is that sum.
        in_channels: number of channels of the input images (1 for grayscale).
        num_classes: length of the output logit vector.
    """

    def __init__(self, Inception, in_channels, num_classes=1000):
        super().__init__()
        # Stage 1: 7x7/2 conv + ReLU, then 3x3/2 max-pool.
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Stage 2: 1x1 conv and 3x3 conv, each followed by ReLU (without them the two
        # convs collapse into one linear map), then 3x3/2 max-pool.
        # NOTE: the ReLUs shift Sequential indices (3x3 conv is now b2.2, not b2.1), so
        # checkpoints saved from the ReLU-less variant will not load strictly.
        self.b2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Stage 3: two Inception blocks (192 -> 256 -> 480 channels), then 3x3/2 max-pool.
        self.b3 = nn.Sequential(
            Inception(in_channels=192, c1=64, c2=(96, 128), c3=(16, 32), c4=32),
            Inception(in_channels=256, c1=128, c2=(128, 192), c3=(32, 96), c4=64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Stage 4: five Inception blocks (480 -> 512 -> 512 -> 512 -> 528 -> 832), then max-pool.
        self.b4 = nn.Sequential(
            Inception(in_channels=480, c1=192, c2=(96, 208), c3=(16, 48), c4=64),
            Inception(in_channels=512, c1=160, c2=(112, 224), c3=(24, 64), c4=64),
            Inception(in_channels=512, c1=128, c2=(128, 256), c3=(24, 64), c4=64),
            Inception(in_channels=512, c1=112, c2=(144, 288), c3=(32, 64), c4=64),
            Inception(in_channels=528, c1=256, c2=(160, 320), c3=(32, 128), c4=128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Stage 5: two Inception blocks (832 -> 832 -> 1024), global average pool, flatten.
        self.b5 = nn.Sequential(
            Inception(in_channels=832, c1=256, c2=(160, 320), c3=(32, 128), c4=128),
            Inception(in_channels=832, c1=384, c2=(192, 384), c3=(48, 128), c4=128),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )
        # Classifier head on the 1024-dim pooled features.
        self.fc = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(1024, num_classes)
        )

        # Weight initialisation: Kaiming-normal for convs, small Gaussian for linear layers,
        # zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, num_classes)."""
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        x = self.fc(x)
        return x


In [None]:
# 画图
import matplotlib.pyplot as plt

def matplot_train_process_data(train_process_data):
    """Draw loss curves (left) and accuracy curves (right) for train vs. validation.

    Args:
        train_process_data: table such as the DataFrame returned by
            ``train_model_process``, with columns 'epoch', 'train_loss',
            'val_loss', 'train_acc' and 'val_acc'.
    """
    epochs = train_process_data['epoch']
    _fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: loss per epoch.
    ax_loss.plot(epochs, train_process_data['train_loss'], 'ro-', label='Train Loss')
    ax_loss.plot(epochs, train_process_data['val_loss'], 'bs-', label='Val Loss')
    ax_loss.legend()
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Train and Val Loss')

    # Right panel: accuracy per epoch.
    ax_acc.plot(epochs, train_process_data['train_acc'], 'ro-', label='Train Acc')
    ax_acc.plot(epochs, train_process_data['val_acc'], 'bs-', label='Val Acc')
    ax_acc.legend()
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title('Train and Val Accuracy')

    plt.show()

In [None]:
# 训练
import copy
import time
import torchvision.transforms as Transforms
import torchvision.datasets as Datasets
import torch.utils.data as Data
import pandas as pd


def train_val_data_process(root='./data', batch_size=32, num_workers=8,
                           val_ratio=0.2, seed=None):
    """Build train/validation DataLoaders from the FashionMNIST training split.

    Images are resized to 224x224 and converted to tensors. The official training
    split is randomly divided into a training subset and a validation subset.

    Args:
        root: directory where FashionMNIST is stored (downloaded if missing).
        batch_size: batch size used by both loaders.
        num_workers: number of worker processes per loader.
        val_ratio: fraction of samples assigned to the validation subset.
        seed: optional integer seed for the split; ``None`` draws from torch's
            global RNG (the previous behavior).

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    full_train = Datasets.FashionMNIST(
        root=root,
        train=True,
        transform=Transforms.Compose([
            Transforms.Resize((224, 224)),
            Transforms.ToTensor()
        ]),
        download=True)

    # Derive the train size by subtraction so both lengths always sum to len(dataset).
    n_total = len(full_train)
    n_val = round(n_total * val_ratio)
    n_train = n_total - n_val

    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_data, val_data = Data.random_split(full_train, [n_train, n_val], **split_kwargs)

    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers)

    # Validation metrics are order-independent; no shuffling keeps evaluation deterministic.
    val_loader = Data.DataLoader(dataset=val_data,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers)

    return train_loader, val_loader


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns:
        ``(mean_loss, accuracy)`` over all samples yielded by ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    # Switch dropout/batch-norm mode once per pass instead of once per batch.
    model.train(training)
    total_loss, total_correct, total_num = 0.0, 0, 0
    # Evaluation needs no autograd graph; disabling it saves memory and time.
    with torch.set_grad_enabled(training):
        for b_x, b_y in loader:
            b_x = b_x.to(device)
            b_y = b_y.to(device)
            outputs = model(b_x)
            loss = criterion(outputs, b_y)
            if training:
                optimizer.zero_grad()  # clear gradients accumulated by the previous step
                loss.backward()
                optimizer.step()
            batch_size = b_x.size(0)
            total_num += batch_size
            # loss is a batch mean; weight it by batch size to get a per-sample mean later
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == b_y).sum().item()
    if total_num == 0:
        raise ValueError("data loader yielded no samples")
    return total_loss / total_num, total_correct / total_num


def train_model_process(model, train_loader, val_loader, num_epochs=5,
                        save_path='model/vgg16_best_model.pth'):
    """Train ``model`` with Adam (lr=0.001) and cross-entropy, validating every epoch.

    After each epoch the validation accuracy is compared with the best so far. The
    weights of the best epoch are saved with ``torch.save`` to ``save_path`` (its
    parent directory is created if necessary). If no epoch beats 0 accuracy, the
    initial weights are saved.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        train_loader: iterable of ``(inputs, labels)`` batches used for training.
        val_loader: iterable of ``(inputs, labels)`` batches used for validation.
        num_epochs: number of passes over ``train_loader``.
        save_path: destination of the best ``state_dict``.
            NOTE(review): the default filename says "vgg16" although this notebook
            trains GoogLeNet; kept for backward compatibility, consider renaming.

    Returns:
        pandas.DataFrame with one row per epoch and columns ``epoch`` (0-based),
        ``train_loss``, ``train_acc``, ``val_loss``, ``val_acc``.

    Raises:
        ValueError: if either loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("当前设备：", device)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Deep copy: state_dict() aliases live parameters, which later steps would mutate.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss_all, train_acc_all = [], []
    val_loss_all, val_acc_all = [], []

    since = time.time()
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs} - LR: {optimizer.param_groups[0]['lr']}")
        print('-' * 10)

        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        train_loss_all.append(train_loss)
        train_acc_all.append(train_acc)
        val_loss_all.append(val_loss)
        val_acc_all.append(val_acc)

        print('Train loss: {:.4f} Train acc: {:.4f} | Val loss: {:.4f} Val acc: {:.4f}'
              .format(train_loss_all[-1], train_acc_all[-1], val_loss_all[-1], val_acc_all[-1]))

        # Must run inside the epoch loop so every epoch is a checkpoint candidate.
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f"训练完成，用时 {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(best_model_wts, save_path)

    train_process_data = pd.DataFrame({
        'epoch': range(num_epochs),
        'train_loss': train_loss_all,
        'train_acc': train_acc_all,
        'val_loss': val_loss_all,
        'val_acc': val_acc_all
    })
    return train_process_data
    

# Seed torch's global RNG (CPU and CUDA) so weight init, the random train/val split
# and DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = GoogLeNet(Inception, in_channels=1, num_classes=10)
train_loader, val_loader = train_val_data_process()
train_process_data = train_model_process(model, train_loader, val_loader, num_epochs=20)

matplot_train_process_data(train_process_data)
    