<a href="https://colab.research.google.com/github/shuyaguan/0826/blob/main/Shuya_Guan_HW6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_residual=True):
        super(ConvBlock, self).__init__()
        self.use_residual = use_residual
        self.same_channels = (in_channels == out_channels)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if use_residual and not self.same_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.use_residual:
            if not self.same_channels:
                identity = self.skip(identity)
            out += identity

        out = self.relu2(out)
        return out

class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels//8, kernel_size=1),
            nn.BatchNorm2d(in_channels//8),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels//8, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention = self.conv(x)
        return x * attention

class AdvancedCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(AdvancedCNN, self).__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.block1 = nn.Sequential(
            ConvBlock(64, 64),
            ConvBlock(64, 64),
            SpatialAttention(64)
        )

        self.block2 = nn.Sequential(
            ConvBlock(64, 128, stride=2),
            ConvBlock(128, 128),
            SpatialAttention(128)
        )

        self.block3 = nn.Sequential(
            ConvBlock(128, 256, stride=2),
            ConvBlock(256, 256),
            ConvBlock(256, 256),
            SpatialAttention(256)
        )

        self.block4 = nn.Sequential(
            ConvBlock(256, 512, stride=2),
            ConvBlock(512, 512),
            ConvBlock(512, 512),
            SpatialAttention(512)
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.initial(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from tqdm.notebook import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.double() / len(dataloader.dataset)

    return epoch_loss, epoch_acc

def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.double() / len(dataloader.dataset)

    return epoch_loss, epoch_acc

In [4]:
# 下载数据集
!wget http://madm.dfki.de/files/sentinel/EuroSAT.zip -O EuroSAT.zip

# 解压数据集
!unzip -q EuroSAT.zip -d 'EuroSAT/'

# 删除压缩文件
!rm EuroSAT.zip

# 确认数据集文件夹存在
!ls EuroSAT/

--2025-03-24 17:36:56--  http://madm.dfki.de/files/sentinel/EuroSAT.zip
Resolving madm.dfki.de (madm.dfki.de)... 131.246.195.183
Connecting to madm.dfki.de (madm.dfki.de)|131.246.195.183|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94280567 (90M) [application/zip]
Saving to: ‘EuroSAT.zip’


2025-03-24 17:37:01 (20.0 MB/s) - ‘EuroSAT.zip’ saved [94280567/94280567]

2750


In [5]:
import torch
from torch.utils import data
from torchvision import datasets, transforms
import os

# 检查数据集路径
data_dir = './EuroSAT/2750/'
if not os.path.exists(data_dir):
    print(f"警告：路径 {data_dir} 不存在，尝试查找正确路径...")
    # 查找实际路径
    !find ./EuroSAT -type d -name "[0-9]*" | head -1
    # 检查EuroSAT目录结构
    !ls -la ./EuroSAT

# 定义数据转换
input_size = 224
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

val_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

test_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# 修改这里的路径，可能需要根据实际情况调整
data_dir = './EuroSAT/'
dataset = datasets.ImageFolder(data_dir)

# 获取类别名称
class_names = dataset.classes
print("类别名称:", class_names)
print("类别总数:", len(class_names))

# 定义数据集类
class EuroSAT(data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        if self.transform:
            x = self.transform(self.dataset[index][0])
        else:
            x = self.dataset[index][0]
        y = self.dataset[index][1]
        return x, y

    def __len__(self):
        return len(self.dataset)

# 应用数据转换
train_data = EuroSAT(dataset, train_transform)
val_data = EuroSAT(dataset, val_transform)
test_data = EuroSAT(dataset, test_transform)

# 随机分割数据集
import numpy as np
train_size = 0.70
val_size = 0.15
indices = list(range(int(len(dataset))))
train_split = int(train_size * len(dataset))
val_split = int(val_size * len(dataset))
np.random.seed(42)
np.random.shuffle(indices)

train_data = data.Subset(train_data, indices=indices[:train_split])
val_data = data.Subset(val_data, indices=indices[train_split: train_split+val_split])
test_data = data.Subset(test_data, indices=indices[train_split+val_split:])
print("训练/验证/测试集大小: {}/{}/{}".format(len(train_data), len(val_data), len(test_data)))

# 创建数据加载器
batch_size = 32
num_workers = 2

train_loader = data.DataLoader(
    train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
val_loader = data.DataLoader(
    val_data, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
test_loader = data.DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

类别名称: ['2750']
类别总数: 1
训练/验证/测试集大小: 18900/4050/4050


In [6]:
import torch
from torch.utils import data
from torchvision import datasets, transforms

# 定义数据转换
input_size = 224
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

val_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

test_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# 定义数据集类
class EuroSAT(data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        if self.transform:
            x = self.transform(self.dataset[index][0])
        else:
            x = self.dataset[index][0]
        y = self.dataset[index][1]
        return x, y

    def __len__(self):
        return len(self.dataset)

# 加载数据集
data_dir = './EuroSAT/2750/'
dataset = datasets.ImageFolder(data_dir)

# 获取类别名称
class_names = dataset.classes
print("类别名称:", class_names)
print("类别总数:", len(class_names))

# 应用数据转换
train_data = EuroSAT(dataset, train_transform)
val_data = EuroSAT(dataset, val_transform)
test_data = EuroSAT(dataset, test_transform)

# 随机分割数据集
import numpy as np
train_size = 0.70
val_size = 0.15
indices = list(range(int(len(dataset))))
train_split = int(train_size * len(dataset))
val_split = int(val_size * len(dataset))
np.random.seed(42)
np.random.shuffle(indices)

train_data = data.Subset(train_data, indices=indices[:train_split])
val_data = data.Subset(val_data, indices=indices[train_split: train_split+val_split])
test_data = data.Subset(test_data, indices=indices[train_split+val_split:])
print("训练/验证/测试集大小: {}/{}/{}".format(len(train_data), len(val_data), len(test_data)))

# 创建数据加载器
batch_size = 32
num_workers = 2

train_loader = data.DataLoader(
    train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
val_loader = data.DataLoader(
    val_data, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
test_loader = data.DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

类别名称: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
类别总数: 10
训练/验证/测试集大小: 18900/4050/4050


In [7]:
# 初始化模型
model = AdvancedCNN(num_classes=10)
model = model.to(device)

# 设置损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

# 训练参数
num_epochs = 30
best_val_acc = 0.0
best_model_wts = None

# 训练循环
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # 训练阶段
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    # 验证阶段
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # 更新学习率
    scheduler.step(val_loss)

    # 保存最佳模型
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = model.state_dict().copy()

    print("-" * 40)

# 加载最佳模型
model.load_state_dict(best_model_wts)
print(f"最佳验证集准确率: {best_val_acc:.4f}")

# 在测试集上评估
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f"测试集准确率: {test_acc:.4f}")

Epoch 1/30




  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.4113, Train Acc: 0.4921


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 1.1287, Val Acc: 0.5911
----------------------------------------
Epoch 2/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.0484, Train Acc: 0.6276


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.8724, Val Acc: 0.7259
----------------------------------------
Epoch 3/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.8637, Train Acc: 0.6988


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6234, Val Acc: 0.7795
----------------------------------------
Epoch 4/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7444, Train Acc: 0.7432


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7764, Val Acc: 0.7462
----------------------------------------
Epoch 5/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6795, Train Acc: 0.7662


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.4663, Val Acc: 0.8398
----------------------------------------
Epoch 6/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6327, Train Acc: 0.7856


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.4819, Val Acc: 0.8306
----------------------------------------
Epoch 7/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.5836, Train Acc: 0.8021


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.4182, Val Acc: 0.8570
----------------------------------------
Epoch 8/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.5430, Train Acc: 0.8183


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5029, Val Acc: 0.8281
----------------------------------------
Epoch 9/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.5107, Train Acc: 0.8303


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.3651, Val Acc: 0.8746
----------------------------------------
Epoch 10/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.4909, Train Acc: 0.8380


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.3663, Val Acc: 0.8756
----------------------------------------
Epoch 11/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.4775, Train Acc: 0.8409


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.3129, Val Acc: 0.8995
----------------------------------------
Epoch 12/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.4469, Train Acc: 0.8514


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2874, Val Acc: 0.9059
----------------------------------------
Epoch 13/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.4263, Train Acc: 0.8584


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2300, Val Acc: 0.9193
----------------------------------------
Epoch 14/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.4151, Train Acc: 0.8606


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2589, Val Acc: 0.9099
----------------------------------------
Epoch 15/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.4025, Train Acc: 0.8680


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2396, Val Acc: 0.9198
----------------------------------------
Epoch 16/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3908, Train Acc: 0.8722


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2086, Val Acc: 0.9269
----------------------------------------
Epoch 17/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3903, Train Acc: 0.8704


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2287, Val Acc: 0.9232
----------------------------------------
Epoch 18/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3695, Train Acc: 0.8792


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1972, Val Acc: 0.9296
----------------------------------------
Epoch 19/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3555, Train Acc: 0.8813


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2044, Val Acc: 0.9331
----------------------------------------
Epoch 20/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3479, Train Acc: 0.8825


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2590, Val Acc: 0.9116
----------------------------------------
Epoch 21/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3547, Train Acc: 0.8825


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1984, Val Acc: 0.9336
----------------------------------------
Epoch 22/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3416, Train Acc: 0.8889


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1935, Val Acc: 0.9321
----------------------------------------
Epoch 23/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3335, Train Acc: 0.8903


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1840, Val Acc: 0.9400
----------------------------------------
Epoch 24/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3295, Train Acc: 0.8904


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1572, Val Acc: 0.9462
----------------------------------------
Epoch 25/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3239, Train Acc: 0.8947


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1510, Val Acc: 0.9494
----------------------------------------
Epoch 26/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3290, Train Acc: 0.8930


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1509, Val Acc: 0.9491
----------------------------------------
Epoch 27/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3117, Train Acc: 0.8975


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1763, Val Acc: 0.9422
----------------------------------------
Epoch 28/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3178, Train Acc: 0.8957


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.2090, Val Acc: 0.9284
----------------------------------------
Epoch 29/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3085, Train Acc: 0.8997


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1614, Val Acc: 0.9435
----------------------------------------
Epoch 30/30


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.3085, Train Acc: 0.8983


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.1692, Val Acc: 0.9398
----------------------------------------
最佳验证集准确率: 0.9494


  0%|          | 0/127 [00:00<?, ?it/s]

测试集准确率: 0.9363


In [8]:
from torchvision import transforms
import torch

# 创建更高级的数据增强
advanced_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),  # 添加随机旋转
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 添加颜色抖动
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # 添加平移变换
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 更新训练数据集
train_data = EuroSAT(dataset, advanced_train_transform)
train_data = data.Subset(train_data, indices=indices[:train_split])
train_loader = data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)

In [10]:
from torchvision import transforms

# 创建更强大的数据增强
advanced_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),  # 随机裁剪，保留70-100%的原图
    transforms.RandomHorizontalFlip(p=0.5),  # 50%几率水平翻转
    transforms.RandomVerticalFlip(p=0.5),    # 50%几率垂直翻转
    transforms.RandomRotation(30),           # 随机旋转30度以内
    transforms.ColorJitter(
        brightness=0.3,   # 亮度变化范围
        contrast=0.3,     # 对比度变化范围
        saturation=0.3,   # 饱和度变化范围
        hue=0.1           # 色调变化范围
    ),
    transforms.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),  # 在各方向平移10%以内
        scale=(0.9, 1.1)       # 缩放因子在0.9-1.1之间
    ),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [11]:
# 使用新的增强重新定义训练数据
train_data = EuroSAT(dataset, advanced_train_transform)
train_data = data.Subset(train_data, indices=indices[:train_split])

# 更新数据加载器
train_loader = data.DataLoader(
    train_data,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True
)

In [12]:
# 使用带标签平滑的损失函数
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # 0.1是标签平滑参数

In [13]:
# 初始化模型
model = AdvancedCNN(num_classes=10)
model = model.to(device)

# 设置优化器和学习率
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True
)

# 训练参数
num_epochs = 50  # 增加训练轮数
best_val_acc = 0.0
best_model_wts = None

# 训练循环
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # 训练阶段
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    # 验证阶段
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # 更新学习率
    scheduler.step(val_loss)

    # 保存最佳模型
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = model.state_dict().copy()
        # 连续10个epoch没有提升就提前停止
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= 10:
            print(f"早停：{epoch+1}轮后没有改进")
            break

    print("-" * 40)

# 加载最佳模型
model.load_state_dict(best_model_wts)
print(f"最佳验证集准确率: {best_val_acc:.4f}")

# 在测试集上评估
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f"测试集准确率: {test_acc:.4f}")

Epoch 1/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.7151, Train Acc: 0.4137


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 1.4045, Val Acc: 0.5825
----------------------------------------
Epoch 2/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.4170, Train Acc: 0.5861


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 1.3801, Val Acc: 0.6289
----------------------------------------
Epoch 3/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.3174, Train Acc: 0.6381


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 1.0634, Val Acc: 0.7526
----------------------------------------
Epoch 4/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.2249, Train Acc: 0.6835


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 1.2058, Val Acc: 0.6956
----------------------------------------
Epoch 5/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.1613, Train Acc: 0.7127


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 1.0381, Val Acc: 0.7731
----------------------------------------
Epoch 6/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.0842, Train Acc: 0.7495


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 1.0378, Val Acc: 0.7825
----------------------------------------
Epoch 7/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 1.0150, Train Acc: 0.7815


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.8911, Val Acc: 0.8348
----------------------------------------
Epoch 8/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.9697, Train Acc: 0.8014


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.8071, Val Acc: 0.8664
----------------------------------------
Epoch 9/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.9321, Train Acc: 0.8183


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7827, Val Acc: 0.8835
----------------------------------------
Epoch 10/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.8957, Train Acc: 0.8342


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7833, Val Acc: 0.8793
----------------------------------------
Epoch 11/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.8763, Train Acc: 0.8454


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7961, Val Acc: 0.8765
----------------------------------------
Epoch 12/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.8465, Train Acc: 0.8547


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7496, Val Acc: 0.8857
----------------------------------------
Epoch 13/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.8237, Train Acc: 0.8666


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7194, Val Acc: 0.9081
----------------------------------------
Epoch 14/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.8088, Train Acc: 0.8724


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7329, Val Acc: 0.9007
----------------------------------------
Epoch 15/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7932, Train Acc: 0.8807


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6978, Val Acc: 0.9128
----------------------------------------
Epoch 16/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7806, Train Acc: 0.8838


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7416, Val Acc: 0.8958
----------------------------------------
Epoch 17/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7793, Train Acc: 0.8870


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7126, Val Acc: 0.9091
----------------------------------------
Epoch 18/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7572, Train Acc: 0.8956


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7222, Val Acc: 0.9040
----------------------------------------
Epoch 19/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7570, Train Acc: 0.8936


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6639, Val Acc: 0.9351
----------------------------------------
Epoch 20/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7440, Train Acc: 0.9016


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.7022, Val Acc: 0.9133
----------------------------------------
Epoch 21/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7381, Train Acc: 0.9010


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6617, Val Acc: 0.9338
----------------------------------------
Epoch 22/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7259, Train Acc: 0.9087


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6330, Val Acc: 0.9385
----------------------------------------
Epoch 23/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7296, Train Acc: 0.9077


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6609, Val Acc: 0.9286
----------------------------------------
Epoch 24/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7149, Train Acc: 0.9130


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6736, Val Acc: 0.9240
----------------------------------------
Epoch 25/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7162, Train Acc: 0.9115


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6846, Val Acc: 0.9165
----------------------------------------
Epoch 26/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7103, Train Acc: 0.9128


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6916, Val Acc: 0.9215
----------------------------------------
Epoch 27/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7084, Train Acc: 0.9132


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6667, Val Acc: 0.9279
----------------------------------------
Epoch 28/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.7002, Train Acc: 0.9184


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6398, Val Acc: 0.9390
----------------------------------------
Epoch 29/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6703, Train Acc: 0.9302


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6085, Val Acc: 0.9568
----------------------------------------
Epoch 30/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6675, Train Acc: 0.9334


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6533, Val Acc: 0.9331
----------------------------------------
Epoch 31/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6634, Train Acc: 0.9330


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6090, Val Acc: 0.9565
----------------------------------------
Epoch 32/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6556, Train Acc: 0.9379


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6156, Val Acc: 0.9462
----------------------------------------
Epoch 33/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6595, Train Acc: 0.9340


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6479, Val Acc: 0.9353
----------------------------------------
Epoch 34/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6569, Train Acc: 0.9372


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5955, Val Acc: 0.9595
----------------------------------------
Epoch 35/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6575, Train Acc: 0.9368


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6422, Val Acc: 0.9422
----------------------------------------
Epoch 36/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6576, Train Acc: 0.9376


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5971, Val Acc: 0.9615
----------------------------------------
Epoch 37/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6495, Train Acc: 0.9410


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6160, Val Acc: 0.9489
----------------------------------------
Epoch 38/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6481, Train Acc: 0.9398


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6442, Val Acc: 0.9385
----------------------------------------
Epoch 39/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6455, Train Acc: 0.9409


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6174, Val Acc: 0.9484
----------------------------------------
Epoch 40/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6454, Train Acc: 0.9435


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5882, Val Acc: 0.9630
----------------------------------------
Epoch 41/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6428, Train Acc: 0.9422


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6292, Val Acc: 0.9440
----------------------------------------
Epoch 42/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6415, Train Acc: 0.9437


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5909, Val Acc: 0.9605
----------------------------------------
Epoch 43/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6408, Train Acc: 0.9435


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6056, Val Acc: 0.9575
----------------------------------------
Epoch 44/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6378, Train Acc: 0.9435


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6046, Val Acc: 0.9558
----------------------------------------
Epoch 45/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6373, Train Acc: 0.9458


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.6398, Val Acc: 0.9412
----------------------------------------
Epoch 46/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6358, Train Acc: 0.9457


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5895, Val Acc: 0.9647
----------------------------------------
Epoch 47/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6229, Train Acc: 0.9505


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5983, Val Acc: 0.9570
----------------------------------------
Epoch 48/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6191, Train Acc: 0.9535


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5847, Val Acc: 0.9640
----------------------------------------
Epoch 49/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6200, Train Acc: 0.9538


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5870, Val Acc: 0.9644
----------------------------------------
Epoch 50/50


  0%|          | 0/591 [00:00<?, ?it/s]

Train Loss: 0.6170, Train Acc: 0.9532


  0%|          | 0/127 [00:00<?, ?it/s]

Val Loss: 0.5996, Val Acc: 0.9573
----------------------------------------
最佳验证集准确率: 0.9647


  0%|          | 0/127 [00:00<?, ?it/s]

测试集准确率: 0.9504
