In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from sklearn.metrics import classification_report, confusion_matrix
import time
import copy

# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)

# 配置参数
DATA_DIR = r"D:\homework\论文\论文\project\code\训练\dataset_split"
IMAGE_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 30
NUM_WORKERS = 4

# 模型配置（替换VGG16为ShuffleNet v2）
MODELS = {
    "efficientnet-b0": {
        "model": models.efficientnet_b0,
        "pretrained": True,
        "feature_size": 1280,
        "classifier": lambda feat_size, num_classes: nn.Linear(feat_size, num_classes)
    },
    "resnet50": {
        "model": models.resnet50,
        "pretrained": True,
        "feature_size": 2048,
        "classifier": lambda feat_size, num_classes: nn.Linear(feat_size, num_classes)
    },
    "shufflenet_v2": {
        "model": models.shufflenet_v2_x1_0,
        "pretrained": True,
        "feature_size": 1024,
        "classifier": lambda feat_size, num_classes: nn.Linear(feat_size, num_classes)
    },
    "mobilenet_v3": {
        "model": models.mobilenet_v3_large,
        "pretrained": True,
        "feature_size": 960,  # 修正特征维度
        "classifier": lambda feat_size, num_classes: nn.Sequential(
            nn.Linear(feat_size, 1280),
            nn.Hardswish(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(1280, num_classes)
        )
    }
}

# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def create_data_transforms():
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(30),
            transforms.ColorJitter(
                brightness=0.3,
                contrast=0.3,
                saturation=0.3,
                hue=0.1
            ),
            transforms.RandomGrayscale(p=0.1),
            # transforms.RandomErasing(p=0.2, scale=(0.02, 0.2)),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
                shear=10
            ),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    return data_transforms

def create_data_loaders(data_dir):
    """创建数据加载器"""
    data_transforms = create_data_transforms()
    
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val', 'test']}
    
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                 shuffle=True if x == 'train' else False,
                                 num_workers=NUM_WORKERS)
                   for x in ['train', 'val', 'test']}
    
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}
    class_names = image_datasets['train'].classes
    
    return dataloaders, dataset_sizes, class_names

def initialize_model(model_name, num_classes):
    """初始化预训练模型"""
    model_info = MODELS[model_name]
    model = model_info["model"](pretrained=model_info["pretrained"])
    
    # 修改分类器
    if model_name == "resnet50":
        model.fc = model_info["classifier"](model_info["feature_size"], num_classes)
    elif model_name == "shufflenet_v2":
        model.fc = model_info["classifier"](model_info["feature_size"], num_classes)
    elif model_name == "efficientnet-b0":
        model.classifier[1] = model_info["classifier"](model_info["feature_size"], num_classes)
    elif model_name == "mobilenet_v3":
        model.classifier = model_info["classifier"](model_info["feature_size"], num_classes)
    
    # 打印模型结构（可选）
    print(f"\n{model_name} Classifier:")
    print(model.classifier if hasattr(model, 'classifier') else model.fc)
    
    return model.to(device)

def count_parameters(model):
    """计算模型参数量（MB）"""
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params * 4 / (1024 * 1024)  # 转换为MB

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, model_name="model"):
    """训练模型并记录指标"""
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # 记录训练历史
    history = {
        'train_loss': [], 'train_acc': [], 'train_precision': [], 'train_recall': [], 'train_f1': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [], 'val_recall': [], 'val_f1': []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        
        # 每个epoch都有一个训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 训练模式
            else:
                model.eval()   # 评估模式
            
            running_loss = 0.0
            running_corrects = 0
            all_preds = []
            all_labels = []
            
            # 迭代数据
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # 零梯度
                optimizer.zero_grad()
                
                # 前向传播
                # 只有在训练时才跟踪历史
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    # 只有在训练阶段才进行反向传播和优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
            
            # 计算epoch指标
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            # 计算分类报告
            report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
            macro_precision = report['macro avg']['precision']
            macro_recall = report['macro avg']['recall']
            macro_f1 = report['macro avg']['f1-score']
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} '
                  f'Precision: {macro_precision:.4f} Recall: {macro_recall:.4f} F1: {macro_f1:.4f}')
            
            # 记录历史
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())
            history[f'{phase}_precision'].append(macro_precision)
            history[f'{phase}_recall'].append(macro_recall)
            history[f'{phase}_f1'].append(macro_f1)
            
            # 深拷贝模型
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), f"{model_name}_best_model.pth")
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    
    # 加载最佳模型权重
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_model(model, dataloader, class_names, model_name):
    """评估模型"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # 打印分类报告
    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
    print(f"\n{model_name} Test Metrics:")
    print(f"Accuracy: {report['accuracy']:.4f}")
    print(f"Macro Precision: {report['macro avg']['precision']:.4f}")
    print(f"Macro Recall: {report['macro avg']['recall']:.4f}")
    print(f"Macro F1-Score: {report['macro avg']['f1-score']:.4f}")
    
    # 计算混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    
    return {
        'accuracy': report['accuracy'],
        'precision': report['macro avg']['precision'],
        'recall': report['macro avg']['recall'],
        'f1': report['macro avg']['f1-score'],
        'confusion_matrix': cm
    }

def plot_training_curves(histories, metrics, save_path):
    """绘制训练曲线"""
    plt.figure(figsize=(15, 10))
    
    for i, metric in enumerate(metrics):
        plt.subplot(2, 3, i+1)
        
        for model_name, history in histories.items():
            plt.plot(history[f'train_{metric}'], label=f'{model_name} Train')
            plt.plot(history[f'val_{metric}'], label=f'{model_name} Val')
        
        plt.title(f'Model {metric.capitalize()}')
        plt.xlabel('Epoch')
        plt.ylabel(metric.capitalize())
        plt.legend()
        plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def compare_models_table(model_metrics, model_sizes):
    """生成模型对比表格"""
    print("\n===== Model Comparison =====")
    print(f"{'Model':<15} {'Accuracy(%)':<12} {'Precision(%)':<12} {'Recall(%)':<12} {'F1-Score(%)':<12} {'Params(MB)':<12}")
    print("-" * 75)
    
    for model_name, metrics in model_metrics.items():
        print(f"{model_name:<15} {metrics['accuracy']*100:<12.2f} {metrics['precision']*100:<12.2f} "
              f"{metrics['recall']*100:<12.2f} {metrics['f1']*100:<12.2f} {model_sizes[model_name]:<12.2f}")

def main():
    """主函数"""
    # 创建数据加载器
    dataloaders, dataset_sizes, class_names = create_data_loaders(DATA_DIR)
    print(f"Classes: {class_names}")
    num_classes = len(class_names)
    
    all_histories = {}
    all_metrics = {}
    model_sizes = {}
    
    # 训练所有模型
    for model_name in MODELS.keys():
        print(f"\n===== Training {model_name} model =====")
        
        # 初始化模型
        model = initialize_model(model_name, num_classes)
        model = model.to(device)
        
        # 计算模型大小
        model_size = count_parameters(model)
        model_sizes[model_name] = model_size
        print(f"{model_name} Parameters: {model_size:.2f} MB")
        
        # 设置损失函数和优化器
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        # 训练模型
        model, history = train_model(
            model, dataloaders, criterion, optimizer,
            num_epochs=EPOCHS, model_name=model_name
        )
        
        # 评估模型
        print(f"\nEvaluating {model_name} model:")
        metrics = evaluate_model(model, dataloaders['test'], class_names, model_name)
        all_metrics[model_name] = metrics
        
        # 保存训练历史
        all_histories[model_name] = history
    
    # 绘制训练曲线
    plot_training_curves(
        all_histories, 
        metrics=['loss', 'acc', 'precision', 'recall', 'f1'],
        save_path='training_curves_comparison.png'
    )
    
    # 生成模型对比表格
    compare_models_table(all_metrics, model_sizes)
    
    print("\n===== Training and evaluation completed for all models =====")
    print("Training curves saved to 'training_curves_comparison.png'")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda:0
Classes: ['Bacterial Leaf Blight', 'Brown Spot', 'Healthy Rice Leaf', 'Leaf Blast', 'Leaf scald', 'Narrow Brown Leaf Spot', 'Neck Blast', 'Rice Hispa', 'Sheath Blight']

===== Training efficientnet-b0 model =====

efficientnet-b0 Classifier:
Sequential(
  (0): Dropout(p=0.2, inplace=True)
  (1): Linear(in_features=1280, out_features=9, bias=True)
)




efficientnet-b0 Parameters: 15.33 MB
Epoch 0/29
----------
train Loss: 0.9182 Acc: 0.6795 Precision: 0.6755 Recall: 0.6778 F1: 0.6745
val Loss: 0.4494 Acc: 0.8413 Precision: 0.8494 Recall: 0.8414 F1: 0.8417

Epoch 1/29
----------
train Loss: 0.5360 Acc: 0.8162 Precision: 0.8152 Recall: 0.8147 F1: 0.8145
val Loss: 0.3300 Acc: 0.8871 Precision: 0.8956 Recall: 0.8939 F1: 0.8886

Epoch 2/29
----------
train Loss: 0.4119 Acc: 0.8601 Precision: 0.8597 Recall: 0.8586 F1: 0.8590
val Loss: 0.2110 Acc: 0.9198 Precision: 0.9191 Recall: 0.9188 F1: 0.9175

Epoch 3/29
----------
train Loss: 0.3296 Acc: 0.8897 Precision: 0.8899 Recall: 0.8897 F1: 0.8896
val Loss: 0.2328 Acc: 0.9193 Precision: 0.9214 Recall: 0.9225 F1: 0.9184

Epoch 4/29
----------
train Loss: 0.3005 Acc: 0.8981 Precision: 0.8994 Recall: 0.8987 F1: 0.8989
val Loss: 0.1753 Acc: 0.9435 Precision: 0.9437 Recall: 0.9437 F1: 0.9428

Epoch 5/29
----------
train Loss: 0.2586 Acc: 0.9131 Precision: 0.9133 Recall: 0.9126 F1: 0.9129
val Loss: 0




resnet50 Classifier:
Linear(in_features=2048, out_features=9, bias=True)
resnet50 Parameters: 89.75 MB
Epoch 0/29
----------
train Loss: 1.4992 Acc: 0.4642 Precision: 0.4462 Recall: 0.4607 F1: 0.4459
val Loss: 1.0697 Acc: 0.6273 Precision: 0.6626 Recall: 0.6198 F1: 0.5963

Epoch 1/29
----------
train Loss: 1.0947 Acc: 0.6173 Precision: 0.6111 Recall: 0.6136 F1: 0.6059
val Loss: 0.7891 Acc: 0.7357 Precision: 0.7433 Recall: 0.7397 F1: 0.7366

Epoch 2/29
----------
train Loss: 0.9745 Acc: 0.6561 Precision: 0.6496 Recall: 0.6518 F1: 0.6477
val Loss: 0.9443 Acc: 0.6595 Precision: 0.6835 Recall: 0.6636 F1: 0.6579

Epoch 3/29
----------
train Loss: 0.8988 Acc: 0.6784 Precision: 0.6724 Recall: 0.6732 F1: 0.6701
val Loss: 0.6926 Acc: 0.7623 Precision: 0.7683 Recall: 0.7618 F1: 0.7596

Epoch 4/29
----------
train Loss: 0.7921 Acc: 0.7214 Precision: 0.7187 Recall: 0.7169 F1: 0.7157
val Loss: 0.9927 Acc: 0.6827 Precision: 0.7748 Recall: 0.6747 F1: 0.6584

Epoch 5/29
----------
train Loss: 0.7372 



train Loss: 1.1437 Acc: 0.6117 Precision: 0.6114 Recall: 0.5975 F1: 0.5831
val Loss: 0.6938 Acc: 0.7595 Precision: 0.7871 Recall: 0.7529 F1: 0.7494

Epoch 1/29
----------
train Loss: 0.6870 Acc: 0.7574 Precision: 0.7503 Recall: 0.7519 F1: 0.7499
val Loss: 0.3995 Acc: 0.8498 Precision: 0.8497 Recall: 0.8429 F1: 0.8427

Epoch 2/29
----------
train Loss: 0.5446 Acc: 0.8087 Precision: 0.8081 Recall: 0.8056 F1: 0.8061
val Loss: 0.3277 Acc: 0.8826 Precision: 0.8879 Recall: 0.8757 F1: 0.8783

Epoch 3/29
----------
train Loss: 0.4734 Acc: 0.8364 Precision: 0.8359 Recall: 0.8349 F1: 0.8350
val Loss: 0.2961 Acc: 0.9063 Precision: 0.9059 Recall: 0.9069 F1: 0.9054

Epoch 4/29
----------
train Loss: 0.4002 Acc: 0.8644 Precision: 0.8638 Recall: 0.8627 F1: 0.8630
val Loss: 0.2863 Acc: 0.8950 Precision: 0.8958 Recall: 0.8948 F1: 0.8941

Epoch 5/29
----------
train Loss: 0.3569 Acc: 0.8780 Precision: 0.8766 Recall: 0.8760 F1: 0.8761
val Loss: 0.2881 Acc: 0.9051 Precision: 0.9082 Recall: 0.9049 F1: 0.90



train Loss: 0.9253 Acc: 0.6716 Precision: 0.6663 Recall: 0.6681 F1: 0.6653
val Loss: 0.7242 Acc: 0.7708 Precision: 0.7867 Recall: 0.7763 F1: 0.7767

Epoch 1/29
----------
train Loss: 0.6078 Acc: 0.7925 Precision: 0.7909 Recall: 0.7902 F1: 0.7901
val Loss: 0.4118 Acc: 0.8532 Precision: 0.8581 Recall: 0.8585 F1: 0.8530

Epoch 2/29
----------
train Loss: 0.4731 Acc: 0.8409 Precision: 0.8407 Recall: 0.8400 F1: 0.8402
val Loss: 0.2722 Acc: 0.9068 Precision: 0.9098 Recall: 0.9026 F1: 0.9040

Epoch 3/29
----------
train Loss: 0.3682 Acc: 0.8734 Precision: 0.8740 Recall: 0.8725 F1: 0.8731
val Loss: 0.3018 Acc: 0.9119 Precision: 0.9237 Recall: 0.9079 F1: 0.9115

Epoch 4/29
----------
train Loss: 0.3176 Acc: 0.8939 Precision: 0.8934 Recall: 0.8933 F1: 0.8933
val Loss: 0.2964 Acc: 0.9001 Precision: 0.9063 Recall: 0.8953 F1: 0.8973

Epoch 5/29
----------
train Loss: 0.2858 Acc: 0.9025 Precision: 0.9029 Recall: 0.9023 F1: 0.9025
val Loss: 0.2770 Acc: 0.9006 Precision: 0.9083 Recall: 0.9042 F1: 0.90

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
from torchvision.models.efficientnet import MBConv
import time
import copy
import shutil
from tqdm import tqdm

# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)

# 配置参数
DATA_DIR = r"D:\homework\论文\论文\project\dataset\archive\Rice_Leaf_AUG\Rice_Leaf_AUG"  # 原始数据集路径
OUTPUT_DIR = "dataset_split"  # 划分后的数据集保存路径
IMAGE_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 30
NUM_WORKERS = 4  # 数据加载的线程数
SPLIT_RATIO = [0.7, 0.15, 0.15]  # 训练集、验证集、测试集比例

# 模型配置
MODELS = {
    "efficientnet-b0": {
        "model": models.efficientnet_b0,
        "pretrained": True,
        "feature_size": 1280  # EfficientNet-B0的特征维度
    },
}

# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# 注意力模块定义
class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // reduction_ratio, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return x * self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return x * self.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(channels, reduction_ratio)
        self.spatial_attention = SpatialAttention()
    
    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

def add_cbam_to_efficientnet(model):
    """为EfficientNet添加CBAM模块"""
    features = []
    for layer in model.features.children():
        features.append(layer)
        if isinstance(layer, MBConv):
            # 获取当前MBConv层的输出通道数
            out_channels = layer.out_channels
            # 添加CBAM模块
            features.append(CBAM(out_channels))
    # 重构features模块
    model.features = nn.Sequential(*features)
    return model

def split_dataset(data_dir, output_dir, split_ratio):
    """将数据集划分为训练集、验证集和测试集"""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    subdirs = [f.name for f in os.scandir(data_dir) if f.is_dir()]
    
    for subdir in subdirs:
        print(f"Processing class: {subdir}")
        class_dir = os.path.join(data_dir, subdir)
        files = [f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))]
        
        # 划分数据集
        train_files, test_files = train_test_split(files, test_size=split_ratio[2], random_state=42)
        train_files, val_files = train_test_split(train_files, test_size=split_ratio[1]/(split_ratio[0]+split_ratio[1]), random_state=42)
        
        # 创建输出目录
        for split in ["train", "val", "test"]:
            split_dir = os.path.join(output_dir, split, subdir)
            os.makedirs(split_dir, exist_ok=True)
        
        # 复制文件
        def copy_files(files_list, split_name):
            for file_name in tqdm(files_list, desc=f"Copying {split_name} files"):
                src = os.path.join(class_dir, file_name)
                dst = os.path.join(output_dir, split_name, subdir, file_name)
                shutil.copy(src, dst)
        
        copy_files(train_files, "train")
        copy_files(val_files, "val")
        copy_files(test_files, "test")
    
    print(f"Dataset split completed. Saved to {output_dir}")
    return output_dir

def create_data_transforms():
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(30),
            transforms.ColorJitter(
                brightness=0.3,
                contrast=0.3,
                saturation=0.3,
                hue=0.1
            ),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
                shear=10
            ),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    return data_transforms

def create_data_loaders(data_dir):
    """创建数据加载器"""
    data_transforms = create_data_transforms()
    
    # 创建数据集
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val', 'test']}
    
    # 创建数据加载器
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                 shuffle=True if x == 'train' else False,
                                 num_workers=NUM_WORKERS)
                   for x in ['train', 'val', 'test']}
    
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}
    class_names = image_datasets['train'].classes
    
    return dataloaders, dataset_sizes, class_names

def initialize_model(model_name, num_classes, feature_extract=True):
    """初始化预训练模型"""
    model_info = MODELS[model_name]
    model_ft = model_info["model"](pretrained=model_info["pretrained"])
    
    # 添加注意力模块（仅对efficientnet）
    if model_name == "efficientnet-b0":
        model_ft = add_cbam_to_efficientnet(model_ft)
    
    # 冻结预训练模型的参数
    if feature_extract:
        for param in model_ft.parameters():
            param.requires_grad = False
    
    # 修改分类器
    if model_name == "efficientnet-b0":
        in_features = model_ft.classifier[1].in_features
        model_ft.classifier[1] = nn.Linear(in_features, num_classes)
    
    return model_ft

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, model_name="model"):
    """训练模型"""
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # 记录训练历史
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        
        # 每个epoch都有一个训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            # 迭代数据
            progress_bar = tqdm(enumerate(dataloaders[phase]), total=len(dataloaders[phase]))
            for i, (inputs, labels) in progress_bar:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # 零梯度
                optimizer.zero_grad()
                
                # 前向传播
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    # 反向传播+优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                # 更新进度条
                progress_bar.set_description(f"{phase} Loss: {running_loss/(i*BATCH_SIZE+inputs.size(0)):.4f} Acc: {running_corrects/(i*BATCH_SIZE+inputs.size(0)):.4f}")
            
            if phase == 'train' and scheduler:
                scheduler.step()
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # 记录历史
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())
            
            # 保存最佳模型
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), f"{model_name}_best_model.pth")
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    
    # 加载最佳模型权重
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_model(model, dataloader, class_names, model_name):
    """评估模型"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # 打印分类报告+
    print(f"\n{classification_report(all_labels, all_preds, target_names=class_names)}")
    
    # 计算混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    plot_confusion_matrix(cm, class_names, f"{model_name}_confusion_matrix.png")
    
    return all_preds, all_labels

def plot_confusion_matrix(cm, classes, save_path, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """绘制混淆矩阵"""
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    
    plt.figure(figsize=(10, 10))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(save_path)
    plt.close()

def plot_training_curves(histories, save_path='training_curves.png'):
    """绘制训练曲线"""
    plt.figure(figsize=(12, 10))
    
    # 绘制准确率曲线
    plt.subplot(2, 1, 1)
    for model_name, history in histories.items():
        plt.plot(history['train_acc'], label=f'{model_name} Train')
        plt.plot(history['val_acc'], label=f'{model_name} Val')
    
    plt.title('Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(loc='lower right')
    plt.grid(True)
    
    # 绘制损失曲线
    plt.subplot(2, 1, 2)
    for model_name, history in histories.items():
        plt.plot(history['train_loss'], label=f'{model_name} Train')
        plt.plot(history['val_loss'], label=f'{model_name} Val')
    
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(loc='upper right')
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def main():
    """主函数"""
    # 划分数据集
    split_dir = split_dataset(DATA_DIR, OUTPUT_DIR, SPLIT_RATIO)
    
    # 创建数据加载器
    dataloaders, dataset_sizes, class_names = create_data_loaders(split_dir)
    print(f"Classes: {class_names}")
    num_classes = len(class_names)
    
    all_histories = {}
    
    # 训练所有模型
    for model_name in MODELS.keys():
        print(f"\n===== Training {model_name} model =====")
        
        # 初始化模型
        model_ft = initialize_model(model_name, num_classes, feature_extract=False)
        model_ft = model_ft.to(device)
        
        # 设置损失函数和优化器
        criterion = nn.CrossEntropyLoss()
        optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001)
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
        
        # 训练模型
        model_ft, history = train_model(
            model_ft, dataloaders, criterion, optimizer_ft, exp_lr_scheduler,
            num_epochs=EPOCHS, model_name=model_name
        )
        
        # 保存最终模型
        torch.save(model_ft.state_dict(), f"{model_name}_final_model.pth")
        
        # 评估模型
        print(f"\nEvaluating {model_name} model:")
        evaluate_model(model_ft, dataloaders['test'], class_names, model_name)
        
        # 保存训练历史
        all_histories[model_name] = history
    
    # 绘制训练曲线
    plot_training_curves(all_histories, save_path='all_models_training_curves.png')
    
    print("\n===== Training completed for all models =====")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda:0
GPU: NVIDIA GeForce RTX 3060 Ti
Processing class: Bacterial Leaf Blight


Copying train files: 100%|██████████| 837/837 [00:05<00:00, 149.62it/s]
Copying val files: 100%|██████████| 180/180 [00:01<00:00, 149.75it/s]
Copying test files: 100%|██████████| 180/180 [00:01<00:00, 154.15it/s]


Processing class: Brown Spot


Copying train files: 100%|██████████| 1082/1082 [00:06<00:00, 154.82it/s]
Copying val files: 100%|██████████| 232/232 [00:01<00:00, 158.13it/s]
Copying test files: 100%|██████████| 232/232 [00:01<00:00, 161.76it/s]


Processing class: Healthy Rice Leaf


Copying train files: 100%|██████████| 759/759 [00:05<00:00, 151.58it/s]
Copying val files: 100%|██████████| 163/163 [00:01<00:00, 155.35it/s]
Copying test files: 100%|██████████| 163/163 [00:01<00:00, 149.88it/s]


Processing class: Leaf Blast


Copying train files: 100%|██████████| 1222/1222 [00:07<00:00, 154.88it/s]
Copying val files: 100%|██████████| 263/263 [00:01<00:00, 152.45it/s]
Copying test files: 100%|██████████| 263/263 [00:01<00:00, 152.53it/s]


Processing class: Leaf scald


Copying train files: 100%|██████████| 932/932 [00:05<00:00, 156.28it/s]
Copying val files: 100%|██████████| 200/200 [00:01<00:00, 156.51it/s]
Copying test files: 100%|██████████| 200/200 [00:01<00:00, 155.77it/s]


Processing class: Narrow Brown Leaf Spot


Copying train files: 100%|██████████| 667/667 [00:04<00:00, 155.44it/s]
Copying val files: 100%|██████████| 143/143 [00:00<00:00, 160.08it/s]
Copying test files: 100%|██████████| 144/144 [00:00<00:00, 151.24it/s]


Processing class: Neck Blast


Copying train files: 100%|██████████| 700/700 [00:05<00:00, 122.61it/s]
Copying val files: 100%|██████████| 150/150 [00:01<00:00, 120.61it/s]
Copying test files: 100%|██████████| 150/150 [00:01<00:00, 120.36it/s]


Processing class: Rice Hispa


Copying train files: 100%|██████████| 909/909 [00:05<00:00, 157.99it/s]
Copying val files: 100%|██████████| 195/195 [00:01<00:00, 161.39it/s]
Copying test files: 100%|██████████| 195/195 [00:01<00:00, 159.17it/s]


Processing class: Sheath Blight


Copying train files: 100%|██████████| 1139/1139 [00:07<00:00, 154.80it/s]
Copying val files: 100%|██████████| 245/245 [00:01<00:00, 153.83it/s]
Copying test files: 100%|██████████| 245/245 [00:01<00:00, 158.84it/s]


Dataset split completed. Saved to dataset_split
Classes: ['Bacterial Leaf Blight', 'Brown Spot', 'Healthy Rice Leaf', 'Leaf Blast', 'Leaf scald', 'Narrow Brown Leaf Spot', 'Neck Blast', 'Rice Hispa', 'Sheath Blight']

===== Training efficientnet-b0 model =====
Epoch 0/29
----------


train Loss: 0.9142 Acc: 0.6846: 100%|██████████| 258/258 [01:05<00:00,  3.97it/s]


train Loss: 0.9142 Acc: 0.6846


val Loss: 0.5424 Acc: 0.8204: 100%|██████████| 56/56 [00:09<00:00,  5.78it/s]


val Loss: 0.5424 Acc: 0.8204

Epoch 1/29
----------


train Loss: 0.5424 Acc: 0.8152: 100%|██████████| 258/258 [00:55<00:00,  4.66it/s]


train Loss: 0.5424 Acc: 0.8152


val Loss: 0.3305 Acc: 0.8871: 100%|██████████| 56/56 [00:08<00:00,  6.76it/s]


val Loss: 0.3305 Acc: 0.8871

Epoch 2/29
----------


train Loss: 0.4165 Acc: 0.8598: 100%|██████████| 258/258 [00:55<00:00,  4.62it/s]


train Loss: 0.4165 Acc: 0.8598


val Loss: 0.3152 Acc: 0.9130: 100%|██████████| 56/56 [00:08<00:00,  6.53it/s]


val Loss: 0.3152 Acc: 0.9130

Epoch 3/29
----------


train Loss: 0.3429 Acc: 0.8836: 100%|██████████| 258/258 [00:54<00:00,  4.77it/s]


train Loss: 0.3429 Acc: 0.8836


val Loss: 0.1845 Acc: 0.9356: 100%|██████████| 56/56 [00:08<00:00,  7.00it/s]


val Loss: 0.1845 Acc: 0.9356

Epoch 4/29
----------


train Loss: 0.3020 Acc: 0.8952: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.3020 Acc: 0.8952


val Loss: 0.1369 Acc: 0.9571: 100%|██████████| 56/56 [00:08<00:00,  6.56it/s]


val Loss: 0.1369 Acc: 0.9571

Epoch 5/29
----------


train Loss: 0.2434 Acc: 0.9205: 100%|██████████| 258/258 [00:54<00:00,  4.69it/s]


train Loss: 0.2434 Acc: 0.9205


val Loss: 0.1997 Acc: 0.9413: 100%|██████████| 56/56 [00:08<00:00,  6.92it/s]


val Loss: 0.1997 Acc: 0.9413

Epoch 6/29
----------


train Loss: 0.2220 Acc: 0.9228: 100%|██████████| 258/258 [00:55<00:00,  4.65it/s]


train Loss: 0.2220 Acc: 0.9228


val Loss: 0.1019 Acc: 0.9667: 100%|██████████| 56/56 [00:08<00:00,  6.86it/s]


val Loss: 0.1019 Acc: 0.9667

Epoch 7/29
----------


train Loss: 0.1326 Acc: 0.9537: 100%|██████████| 258/258 [00:55<00:00,  4.67it/s]


train Loss: 0.1326 Acc: 0.9537


val Loss: 0.0449 Acc: 0.9848: 100%|██████████| 56/56 [00:08<00:00,  6.64it/s]


val Loss: 0.0449 Acc: 0.9848

Epoch 8/29
----------


train Loss: 0.0893 Acc: 0.9720: 100%|██████████| 258/258 [00:54<00:00,  4.73it/s]


train Loss: 0.0893 Acc: 0.9720


val Loss: 0.0273 Acc: 0.9910: 100%|██████████| 56/56 [00:08<00:00,  6.89it/s]


val Loss: 0.0273 Acc: 0.9910

Epoch 9/29
----------


train Loss: 0.0810 Acc: 0.9742: 100%|██████████| 258/258 [00:54<00:00,  4.72it/s]


train Loss: 0.0810 Acc: 0.9742


val Loss: 0.0252 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.85it/s]


val Loss: 0.0252 Acc: 0.9932

Epoch 10/29
----------


train Loss: 0.0683 Acc: 0.9777: 100%|██████████| 258/258 [00:54<00:00,  4.75it/s]


train Loss: 0.0683 Acc: 0.9777


val Loss: 0.0206 Acc: 0.9910: 100%|██████████| 56/56 [00:08<00:00,  6.94it/s]


val Loss: 0.0206 Acc: 0.9910

Epoch 11/29
----------


train Loss: 0.0698 Acc: 0.9762: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0698 Acc: 0.9762


val Loss: 0.0147 Acc: 0.9955: 100%|██████████| 56/56 [00:08<00:00,  6.68it/s]


val Loss: 0.0147 Acc: 0.9955

Epoch 12/29
----------


train Loss: 0.0533 Acc: 0.9823: 100%|██████████| 258/258 [00:54<00:00,  4.75it/s]


train Loss: 0.0533 Acc: 0.9823


val Loss: 0.0138 Acc: 0.9966: 100%|██████████| 56/56 [00:08<00:00,  6.88it/s]


val Loss: 0.0138 Acc: 0.9966

Epoch 13/29
----------


train Loss: 0.0536 Acc: 0.9845: 100%|██████████| 258/258 [00:54<00:00,  4.73it/s]


train Loss: 0.0536 Acc: 0.9845


val Loss: 0.0122 Acc: 0.9955: 100%|██████████| 56/56 [00:08<00:00,  6.73it/s]


val Loss: 0.0122 Acc: 0.9955

Epoch 14/29
----------


train Loss: 0.0479 Acc: 0.9859: 100%|██████████| 258/258 [00:53<00:00,  4.85it/s]


train Loss: 0.0479 Acc: 0.9859


val Loss: 0.0122 Acc: 0.9972: 100%|██████████| 56/56 [00:08<00:00,  6.85it/s]


val Loss: 0.0122 Acc: 0.9972

Epoch 15/29
----------


train Loss: 0.0447 Acc: 0.9853: 100%|██████████| 258/258 [00:54<00:00,  4.76it/s]


train Loss: 0.0447 Acc: 0.9853


val Loss: 0.0126 Acc: 0.9955: 100%|██████████| 56/56 [00:08<00:00,  6.72it/s]


val Loss: 0.0126 Acc: 0.9955

Epoch 16/29
----------


train Loss: 0.0438 Acc: 0.9851: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0438 Acc: 0.9851


val Loss: 0.0102 Acc: 0.9966: 100%|██████████| 56/56 [00:08<00:00,  6.89it/s]


val Loss: 0.0102 Acc: 0.9966

Epoch 17/29
----------


train Loss: 0.0416 Acc: 0.9858: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0416 Acc: 0.9858


val Loss: 0.0103 Acc: 0.9960: 100%|██████████| 56/56 [00:08<00:00,  6.75it/s]


val Loss: 0.0103 Acc: 0.9960

Epoch 18/29
----------


train Loss: 0.0411 Acc: 0.9871: 100%|██████████| 258/258 [00:53<00:00,  4.86it/s]


train Loss: 0.0411 Acc: 0.9871


val Loss: 0.0101 Acc: 0.9972: 100%|██████████| 56/56 [00:08<00:00,  6.81it/s]


val Loss: 0.0101 Acc: 0.9972

Epoch 19/29
----------


train Loss: 0.0477 Acc: 0.9851: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0477 Acc: 0.9851


val Loss: 0.0106 Acc: 0.9960: 100%|██████████| 56/56 [00:08<00:00,  6.98it/s]


val Loss: 0.0106 Acc: 0.9960

Epoch 20/29
----------


train Loss: 0.0414 Acc: 0.9867: 100%|██████████| 258/258 [00:54<00:00,  4.77it/s]


train Loss: 0.0414 Acc: 0.9867


val Loss: 0.0099 Acc: 0.9972: 100%|██████████| 56/56 [00:08<00:00,  6.90it/s]


val Loss: 0.0099 Acc: 0.9972

Epoch 21/29
----------


train Loss: 0.0429 Acc: 0.9853: 100%|██████████| 258/258 [00:53<00:00,  4.81it/s]


train Loss: 0.0429 Acc: 0.9853


val Loss: 0.0088 Acc: 0.9977: 100%|██████████| 56/56 [00:08<00:00,  6.90it/s]


val Loss: 0.0088 Acc: 0.9977

Epoch 22/29
----------


train Loss: 0.0423 Acc: 0.9870: 100%|██████████| 258/258 [00:54<00:00,  4.74it/s]


train Loss: 0.0423 Acc: 0.9870


val Loss: 0.0096 Acc: 0.9977: 100%|██████████| 56/56 [00:08<00:00,  6.77it/s]


val Loss: 0.0096 Acc: 0.9977

Epoch 23/29
----------


train Loss: 0.0399 Acc: 0.9880: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.0399 Acc: 0.9880


val Loss: 0.0090 Acc: 0.9977: 100%|██████████| 56/56 [00:08<00:00,  6.88it/s]


val Loss: 0.0090 Acc: 0.9977

Epoch 24/29
----------


train Loss: 0.0378 Acc: 0.9869: 100%|██████████| 258/258 [00:53<00:00,  4.85it/s]


train Loss: 0.0378 Acc: 0.9869


val Loss: 0.0092 Acc: 0.9972: 100%|██████████| 56/56 [00:07<00:00,  7.03it/s]


val Loss: 0.0092 Acc: 0.9972

Epoch 25/29
----------


train Loss: 0.0378 Acc: 0.9880: 100%|██████████| 258/258 [00:53<00:00,  4.84it/s]


train Loss: 0.0378 Acc: 0.9880


val Loss: 0.0090 Acc: 0.9977: 100%|██████████| 56/56 [00:07<00:00,  7.10it/s]


val Loss: 0.0090 Acc: 0.9977

Epoch 26/29
----------


train Loss: 0.0360 Acc: 0.9887: 100%|██████████| 258/258 [00:53<00:00,  4.78it/s]


train Loss: 0.0360 Acc: 0.9887


val Loss: 0.0088 Acc: 0.9983: 100%|██████████| 56/56 [00:08<00:00,  6.89it/s]


val Loss: 0.0088 Acc: 0.9983

Epoch 27/29
----------


train Loss: 0.0411 Acc: 0.9878: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.0411 Acc: 0.9878


val Loss: 0.0092 Acc: 0.9972: 100%|██████████| 56/56 [00:07<00:00,  7.03it/s]


val Loss: 0.0092 Acc: 0.9972

Epoch 28/29
----------


train Loss: 0.0388 Acc: 0.9873: 100%|██████████| 258/258 [00:53<00:00,  4.83it/s]


train Loss: 0.0388 Acc: 0.9873


val Loss: 0.0090 Acc: 0.9972: 100%|██████████| 56/56 [00:07<00:00,  7.04it/s]


val Loss: 0.0090 Acc: 0.9972

Epoch 29/29
----------


train Loss: 0.0400 Acc: 0.9863: 100%|██████████| 258/258 [00:53<00:00,  4.86it/s]


train Loss: 0.0400 Acc: 0.9863


val Loss: 0.0092 Acc: 0.9977: 100%|██████████| 56/56 [00:08<00:00,  6.86it/s]


val Loss: 0.0092 Acc: 0.9977

Training complete in 42m 13s
Best val Acc: 0.998306

Evaluating efficientnet-b0 model:


Evaluating: 100%|██████████| 56/56 [00:20<00:00,  2.79it/s]



                        precision    recall  f1-score   support

 Bacterial Leaf Blight       0.99      0.98      0.99       180
            Brown Spot       1.00      0.99      1.00       232
     Healthy Rice Leaf       1.00      1.00      1.00       163
            Leaf Blast       0.98      1.00      0.99       263
            Leaf scald       1.00      0.98      0.99       200
Narrow Brown Leaf Spot       0.99      1.00      0.99       144
            Neck Blast       1.00      1.00      1.00       150
            Rice Hispa       1.00      1.00      1.00       195
         Sheath Blight       1.00      1.00      1.00       245

              accuracy                           0.99      1772
             macro avg       0.99      0.99      0.99      1772
          weighted avg       0.99      0.99      0.99      1772

Confusion matrix, without normalization

===== Training completed for all models =====


: 

In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from sklearn.metrics import classification_report, confusion_matrix
import time
import copy

# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)

# 配置参数
DATA_DIR = r"D:\homework\论文\论文\project\code\训练\dataset_split"
IMAGE_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 1
NUM_WORKERS = 4
FREEZE_LAYERS = True  # 是否冻结部分层

# 模型配置
MODELS = {
        "efficientnet-b0": {
        "model": models.efficientnet_b0,
        "pretrained": True,
        "feature_size": 1280,
        "classifier": lambda feat_size, num_classes: nn.Linear(feat_size, num_classes),
        "freeze_func": lambda model: [
            # 冻结特征提取器的前半部分
            param.requires_grad_(False) for param in list(model.features.parameters())[:len(list(model.features.parameters()))//2]
        ]
    },
    "resnet50": {
        "model": models.resnet50,
        "pretrained": True,
        "feature_size": 2048,
        "classifier": lambda feat_size, num_classes: nn.Linear(feat_size, num_classes),
        "freeze_func": lambda model: [
            # 冻结前3个残差块（共4个）
            param.requires_grad_(False) for param in list(model.parameters())[:len(list(model.parameters()))//2]
        ]
    },
    "shufflenet_v2": {
        "model": models.shufflenet_v2_x1_0,
        "pretrained": True,
        "feature_size": 1024,
        "classifier": lambda feat_size, num_classes: nn.Linear(feat_size, num_classes),
        "freeze_func": lambda model: [
            # 冻结前半部分层
            param.requires_grad_(False) for param in list(model.parameters())[:len(list(model.parameters()))//2]
        ]
    },

    "mobilenet_v3": {
        "model": models.mobilenet_v3_large,
        "pretrained": True,
        "feature_size": 960,
        "classifier": lambda feat_size, num_classes: nn.Sequential(
            nn.Linear(feat_size, 1280),
            nn.Hardswish(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(1280, num_classes)
        ),
        "freeze_func": lambda model: [
            # 冻结特征提取器的前半部分
            param.requires_grad_(False) for param in list(model.features.parameters())[:len(list(model.features.parameters()))//2]
        ]
    }
}

# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def create_data_transforms():
    """创建数据转换"""
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    return data_transforms

def create_data_loaders(data_dir):
    """创建数据加载器"""
    data_transforms = create_data_transforms()
    
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val', 'test']}
    
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                 shuffle=True if x == 'train' else False,
                                 num_workers=NUM_WORKERS)
                   for x in ['train', 'val', 'test']}
    
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}
    class_names = image_datasets['train'].classes
    
    return dataloaders, dataset_sizes, class_names

def initialize_model(model_name, num_classes):
    """初始化预训练模型并冻结部分层"""
    model_info = MODELS[model_name]
    model = model_info["model"](pretrained=model_info["pretrained"])
    
    # 修改分类器
    if model_name == "resnet50":
        model.fc = model_info["classifier"](model_info["feature_size"], num_classes)
    elif model_name == "shufflenet_v2":
        model.fc = model_info["classifier"](model_info["feature_size"], num_classes)
    elif model_name == "efficientnet-b0":
        model.classifier[1] = model_info["classifier"](model_info["feature_size"], num_classes)
    elif model_name == "mobilenet_v3":
        model.classifier = model_info["classifier"](model_info["feature_size"], num_classes)
    
    # 冻结部分层
    if FREEZE_LAYERS:
        print(f"Freezing layers for {model_name}...")
        model_info["freeze_func"](model)
        
        # 计算可训练参数
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
    
    return model.to(device)

def count_parameters(model):
    """计算模型参数量（MB）"""
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params * 4 / (1024 * 1024)  # 转换为MB

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, model_name="model"):
    """训练模型并记录指标"""
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # 记录训练历史
    history = {
        'train_loss': [], 'train_acc': [], 'train_precision': [], 'train_recall': [], 'train_f1': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [], 'val_recall': [], 'val_f1': []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        
        # 每个epoch都有一个训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 训练模式
            else:
                model.eval()   # 评估模式
            
            running_loss = 0.0
            running_corrects = 0
            all_preds = []
            all_labels = []
            
            # 迭代数据
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # 零梯度
                optimizer.zero_grad()
                
                # 前向传播
                # 只有在训练时才跟踪历史
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    # 只有在训练阶段才进行反向传播和优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
            
            # 计算epoch指标
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            # 计算分类报告
            report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
            macro_precision = report['macro avg']['precision']
            macro_recall = report['macro avg']['recall']
            macro_f1 = report['macro avg']['f1-score']
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} '
                  f'Precision: {macro_precision:.4f} Recall: {macro_recall:.4f} F1: {macro_f1:.4f}')
            
            # 记录历史
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())
            history[f'{phase}_precision'].append(macro_precision)
            history[f'{phase}_recall'].append(macro_recall)
            history[f'{phase}_f1'].append(macro_f1)
            
            # 深拷贝模型
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), f"{model_name}_best_model.pth")
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    
    # 加载最佳模型权重
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_model(model, dataloader, class_names, model_name):
    """评估模型"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # 打印分类报告
    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
    print(f"\n{model_name} Test Metrics:")
    print(f"Accuracy: {report['accuracy']:.4f}")
    print(f"Macro Precision: {report['macro avg']['precision']:.4f}")
    print(f"Macro Recall: {report['macro avg']['recall']:.4f}")
    print(f"Macro F1-Score: {report['macro avg']['f1-score']:.4f}")
    
    # 计算混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    
    return {
        'accuracy': report['accuracy'],
        'precision': report['macro avg']['precision'],
        'recall': report['macro avg']['recall'],
        'f1': report['macro avg']['f1-score'],
        'confusion_matrix': cm
    }

def plot_training_curves(histories, metrics, save_path):
    """绘制训练曲线"""
    plt.figure(figsize=(15, 10))
    
    for i, metric in enumerate(metrics):
        plt.subplot(2, 3, i+1)
        
        for model_name, history in histories.items():
            plt.plot(history[f'train_{metric}'], label=f'{model_name} Train')
            plt.plot(history[f'val_{metric}'], label=f'{model_name} Val')
        
        plt.title(f'Model {metric.capitalize()}')
        plt.xlabel('Epoch')
        plt.ylabel(metric.capitalize())
        plt.legend()
        plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def compare_models_table(model_metrics, model_sizes):
    """生成模型对比表格"""
    print("\n===== Model Comparison =====")
    print(f"{'Model':<15} {'Accuracy(%)':<12} {'Precision(%)':<12} {'Recall(%)':<12} {'F1-Score(%)':<12} {'Params(MB)':<12}")
    print("-" * 75)
    
    for model_name, metrics in model_metrics.items():
        print(f"{model_name:<15} {metrics['accuracy']*100:<12.2f} {metrics['precision']*100:<12.2f} "
              f"{metrics['recall']*100:<12.2f} {metrics['f1']*100:<12.2f} {model_sizes[model_name]:<12.2f}")

def main():
    """主函数"""
    # 创建数据加载器
    dataloaders, dataset_sizes, class_names = create_data_loaders(DATA_DIR)
    print(f"Classes: {class_names}")
    num_classes = len(class_names)
    
    all_histories = {}
    all_metrics = {}
    model_sizes = {}
    
    # 训练所有模型
    for model_name in MODELS.keys():
        print(f"\n===== Training {model_name} model =====")
        
        # 初始化模型
        model = initialize_model(model_name, num_classes)
        model = model.to(device)
        
        # 计算模型大小
        model_size = count_parameters(model)
        model_sizes[model_name] = model_size
        print(f"{model_name} Trainable Parameters: {model_size:.2f} MB")
        
        # 设置损失函数和优化器
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
        
        # 训练模型
        model, history = train_model(
            model, dataloaders, criterion, optimizer,
            num_epochs=EPOCHS, model_name=model_name
        )
        
        # 评估模型
        print(f"\nEvaluating {model_name} model:")
        metrics = evaluate_model(model, dataloaders['test'], class_names, model_name)
        all_metrics[model_name] = metrics
        
        # 保存训练历史
        all_histories[model_name] = history
    
    # 绘制训练曲线
    plot_training_curves(
        all_histories, 
        metrics=['loss', 'acc', 'precision', 'recall', 'f1'],
        save_path='training_curves_comparison.png'
    )
    
    # 生成模型对比表格
    compare_models_table(all_metrics, model_sizes)
    
    print("\n===== Training and evaluation completed for all models =====")
    print("Training curves saved to 'training_curves_comparison.png'")

if __name__ == "__main__":
    main()

Using device: cuda:0
Classes: ['Bacterial Leaf Blight', 'Brown Spot', 'Healthy Rice Leaf', 'Leaf Blast', 'Leaf scald', 'Narrow Brown Leaf Spot', 'Neck Blast', 'Rice Hispa', 'Sheath Blight']

===== Training efficientnet-b0 model =====
Freezing layers for efficientnet-b0...
Trainable params: 3,672,017 (91.36%)
efficientnet-b0 Trainable Parameters: 14.01 MB
Epoch 0/0
----------


KeyboardInterrupt: 

最终优化 冻结层 CA注意力机制

In [None]:
# import os
# import numpy as np
# import matplotlib.pyplot as plt
# from sklearn.model_selection import train_test_split
# from sklearn.metrics import classification_report, confusion_matrix
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.optim import lr_scheduler
# from torch.utils.data import DataLoader, Subset
# from torchvision import datasets, models, transforms
# from torchvision.models.efficientnet import MBConv
# import time
# import copy
# import shutil
# from tqdm import tqdm

# # 设置随机种子确保结果可复现
# torch.manual_seed(42)
# np.random.seed(42)

# # 配置参数
# DATA_DIR = r"D:\homework\论文\论文\project\dataset\archive\Rice_Leaf_AUG\Rice_Leaf_AUG"  # 原始数据集路径
# OUTPUT_DIR = "dataset_split"  # 划分后的数据集保存路径
# IMAGE_SIZE = 224
# BATCH_SIZE = 32
# EPOCHS = 50
# NUM_WORKERS = 4  # 数据加载的线程数
# SPLIT_RATIO = [0.7, 0.15, 0.15]  # 训练集、验证集、测试集比例

# # 模型配置
# MODELS = {
#     "efficientnet-b0": {
#         "model": models.efficientnet_b0,
#         "pretrained": True,
#         "feature_size": 1280  # EfficientNet-B0的特征维度
#     },
# }

# # 检查GPU是否可用
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
# if device.type == 'cuda':
#     print(f"GPU: {torch.cuda.get_device_name(0)}")

# # 注意力模块定义
# class ChannelAttention(nn.Module):
#     def __init__(self, channels, reduction_ratio=16):
#         super(ChannelAttention, self).__init__()
#         self.avg_pool = nn.AdaptiveAvgPool2d(1)
#         self.max_pool = nn.AdaptiveMaxPool2d(1)
#         self.fc = nn.Sequential(
#             nn.Conv2d(channels, channels // reduction_ratio, 1, bias=False),
#             nn.ReLU(),
#             nn.Conv2d(channels // reduction_ratio, channels, 1, bias=False)
#         )
#         self.sigmoid = nn.Sigmoid()
    
#     def forward(self, x):
#         avg_out = self.fc(self.avg_pool(x))
#         max_out = self.fc(self.max_pool(x))
#         out = avg_out + max_out
#         return x * self.sigmoid(out)

# class SpatialAttention(nn.Module):
#     def __init__(self, kernel_size=7):
#         super(SpatialAttention, self).__init__()
#         self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
#         self.sigmoid = nn.Sigmoid()
    
#     def forward(self, x):
#         avg_out = torch.mean(x, dim=1, keepdim=True)
#         max_out, _ = torch.max(x, dim=1, keepdim=True)
#         out = torch.cat([avg_out, max_out], dim=1)
#         out = self.conv(out)
#         return x * self.sigmoid(out)

# class CBAM(nn.Module):
#     def __init__(self, channels, reduction_ratio=16):
#         super(CBAM, self).__init__()
#         self.channel_attention = ChannelAttention(channels, reduction_ratio)
#         self.spatial_attention = SpatialAttention()
    
#     def forward(self, x):
#         x = self.channel_attention(x)
#         x = self.spatial_attention(x)
#         return x

# def add_cbam_to_efficientnet(model):
#     """为EfficientNet添加CBAM模块"""
#     features = []
#     for layer in model.features.children():
#         features.append(layer)
#         if isinstance(layer, MBConv):
#             # 获取当前MBConv层的输出通道数
#             out_channels = layer.out_channels
#             # 添加CBAM模块
#             features.append(CBAM(out_channels))
#     # 重构features模块
#     model.features = nn.Sequential(*features)
#     return model

# def split_dataset(data_dir, output_dir, split_ratio):
#     """将数据集划分为训练集、验证集和测试集"""
#     if not os.path.exists(output_dir):
#         os.makedirs(output_dir)
    
#     subdirs = [f.name for f in os.scandir(data_dir) if f.is_dir()]
    
#     for subdir in subdirs:
#         print(f"Processing class: {subdir}")
#         class_dir = os.path.join(data_dir, subdir)
#         files = [f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))]
        
#         # 划分数据集
#         train_files, test_files = train_test_split(files, test_size=split_ratio[2], random_state=42)
#         train_files, val_files = train_test_split(train_files, test_size=split_ratio[1]/(split_ratio[0]+split_ratio[1]), random_state=42)
        
#         # 创建输出目录
#         for split in ["train", "val", "test"]:
#             split_dir = os.path.join(output_dir, split, subdir)
#             os.makedirs(split_dir, exist_ok=True)
        
#         # 复制文件
#         def copy_files(files_list, split_name):
#             for file_name in tqdm(files_list, desc=f"Copying {split_name} files"):
#                 src = os.path.join(class_dir, file_name)
#                 dst = os.path.join(output_dir, split_name, subdir, file_name)
#                 shutil.copy(src, dst)
        
#         copy_files(train_files, "train")
#         copy_files(val_files, "val")
#         copy_files(test_files, "test")
    
#     print(f"Dataset split completed. Saved to {output_dir}")
#     return output_dir

# def create_data_transforms():
#     data_transforms = {
#         'train': transforms.Compose([
#             transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
#             transforms.RandomHorizontalFlip(),
#             transforms.RandomVerticalFlip(p=0.1),
#             transforms.RandomRotation(30),
#             transforms.ColorJitter(
#                 brightness=0.3,
#                 contrast=0.3,
#                 saturation=0.3,
#                 hue=0.1
#             ),
#             transforms.RandomGrayscale(p=0.1),
#             transforms.RandomAffine(
#                 degrees=0,
#                 translate=(0.1, 0.1),
#                 scale=(0.9, 1.1),
#                 shear=10
#             ),
#             transforms.RandomPerspective(distortion_scale=0.2, p=0.2),
#             transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
#             transforms.ToTensor(),
#             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#         ]),
#         'val': transforms.Compose([
#             transforms.Resize(IMAGE_SIZE),
#             transforms.CenterCrop(IMAGE_SIZE),
#             transforms.ToTensor(),
#             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#         ]),
#         'test': transforms.Compose([
#             transforms.Resize(IMAGE_SIZE),
#             transforms.CenterCrop(IMAGE_SIZE),
#             transforms.ToTensor(),
#             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#         ]),
#     }
#     return data_transforms

# def create_data_loaders(data_dir):
#     """创建数据加载器"""
#     data_transforms = create_data_transforms()
    
#     # 创建数据集
#     image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
#                                               data_transforms[x])
#                       for x in ['train', 'val', 'test']}
    
#     # 创建数据加载器
#     dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
#                                  shuffle=True if x == 'train' else False,
#                                  num_workers=NUM_WORKERS)
#                    for x in ['train', 'val', 'test']}
    
#     dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}
#     class_names = image_datasets['train'].classes
    
#     return dataloaders, dataset_sizes, class_names

# def initialize_model(model_name, num_classes, feature_extract=True):
#     """初始化预训练模型"""
#     model_info = MODELS[model_name]
#     model_ft = model_info["model"](pretrained=model_info["pretrained"])
    
#     # 添加注意力模块（仅对efficientnet）
#     if model_name == "efficientnet-b0":
#         model_ft = add_cbam_to_efficientnet(model_ft)
    
#     # 冻结预训练模型的参数
#     if feature_extract:
#         for param in model_ft.parameters():
#             param.requires_grad = False
    
#     # 修改分类器
#     if model_name == "efficientnet-b0":
#         in_features = model_ft.classifier[1].in_features
#         model_ft.classifier[1] = nn.Linear(in_features, num_classes)
    
#     return model_ft

# def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, model_name="model"):
#     """训练模型"""
#     since = time.time()
    
#     best_model_wts = copy.deepcopy(model.state_dict())
#     best_acc = 0.0
    
#     # 记录训练历史
#     history = {
#         'train_loss': [],
#         'train_acc': [],
#         'val_loss': [],
#         'val_acc': []
#     }
    
#     for epoch in range(num_epochs):
#         print(f'Epoch {epoch}/{num_epochs-1}')
#         print('-' * 10)
        
#         # 每个epoch都有一个训练和验证阶段
#         for phase in ['train', 'val']:
#             if phase == 'train':
#                 model.train()
#             else:
#                 model.eval()
            
#             running_loss = 0.0
#             running_corrects = 0
            
#             # 迭代数据
#             progress_bar = tqdm(enumerate(dataloaders[phase]), total=len(dataloaders[phase]))
#             for i, (inputs, labels) in progress_bar:
#                 inputs = inputs.to(device)
#                 labels = labels.to(device)
                
#                 # 零梯度
#                 optimizer.zero_grad()
                
#                 # 前向传播
#                 with torch.set_grad_enabled(phase == 'train'):
#                     outputs = model(inputs)
#                     _, preds = torch.max(outputs, 1)
#                     loss = criterion(outputs, labels)
                    
#                     # 反向传播+优化
#                     if phase == 'train':
#                         loss.backward()
#                         optimizer.step()
                
#                 # 统计
#                 running_loss += loss.item() * inputs.size(0)
#                 running_corrects += torch.sum(preds == labels.data)
                
#                 # 更新进度条
#                 progress_bar.set_description(f"{phase} Loss: {running_loss/(i*BATCH_SIZE+inputs.size(0)):.4f} Acc: {running_corrects/(i*BATCH_SIZE+inputs.size(0)):.4f}")
            
#             if phase == 'train' and scheduler:
#                 scheduler.step()
            
#             epoch_loss = running_loss / len(dataloaders[phase].dataset)
#             epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
#             print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
#             # 记录历史
#             history[f'{phase}_loss'].append(epoch_loss)
#             history[f'{phase}_acc'].append(epoch_acc.item())
            
#             # 保存最佳模型
#             if phase == 'val' and epoch_acc > best_acc:
#                 best_acc = epoch_acc
#                 best_model_wts = copy.deepcopy(model.state_dict())
#                 torch.save(model.state_dict(), f"{model_name}_best_model.pth")
        
#         print()
    
#     time_elapsed = time.time() - since
#     print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
#     print(f'Best val Acc: {best_acc:4f}')
    
#     # 加载最佳模型权重
#     model.load_state_dict(best_model_wts)
#     return model, history

# def evaluate_model(model, dataloader, class_names, model_name):
#     """评估模型"""
#     model.eval()
#     all_preds = []
#     all_labels = []
    
#     with torch.no_grad():
#         for inputs, labels in tqdm(dataloader, desc="Evaluating"):
#             inputs = inputs.to(device)
#             labels = labels.to(device)
            
#             outputs = model(inputs)
#             _, preds = torch.max(outputs, 1)
            
#             all_preds.extend(preds.cpu().numpy())
#             all_labels.extend(labels.cpu().numpy())
    
#     # 打印分类报告+
#     print(f"\n{classification_report(all_labels, all_preds, target_names=class_names)}")
    
#     # 计算混淆矩阵
#     cm = confusion_matrix(all_labels, all_preds)
#     plot_confusion_matrix(cm, class_names, f"{model_name}_confusion_matrix.png")
    
#     return all_preds, all_labels

# def plot_confusion_matrix(cm, classes, save_path, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
#     """绘制混淆矩阵"""
#     if normalize:
#         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
#         print("Normalized confusion matrix")
#     else:
#         print('Confusion matrix, without normalization')
    
#     plt.figure(figsize=(10, 10))
#     plt.imshow(cm, interpolation='nearest', cmap=cmap)
#     plt.title(title)
#     plt.colorbar()
#     tick_marks = np.arange(len(classes))
#     plt.xticks(tick_marks, classes, rotation=45)
#     plt.yticks(tick_marks, classes)
    
#     fmt = '.2f' if normalize else 'd'
#     thresh = cm.max() / 2.
#     for i in range(cm.shape[0]):
#         for j in range(cm.shape[1]):
#             plt.text(j, i, format(cm[i, j], fmt),
#                      horizontalalignment="center",
#                      color="white" if cm[i, j] > thresh else "black")
    
#     plt.tight_layout()
#     plt.ylabel('True label')
#     plt.xlabel('Predicted label')
#     plt.savefig(save_path)
#     plt.close()

# def plot_training_curves(histories, save_path='training_curves.png'):
#     """绘制训练曲线"""
#     plt.figure(figsize=(12, 10))
    
#     # 绘制准确率曲线
#     plt.subplot(2, 1, 1)
#     for model_name, history in histories.items():
#         plt.plot(history['train_acc'], label=f'{model_name} Train')
#         plt.plot(history['val_acc'], label=f'{model_name} Val')
    
#     plt.title('Model Accuracy')
#     plt.ylabel('Accuracy')
#     plt.xlabel('Epoch')
#     plt.legend(loc='lower right')
#     plt.grid(True)
    
#     # 绘制损失曲线
#     plt.subplot(2, 1, 2)
#     for model_name, history in histories.items():
#         plt.plot(history['train_loss'], label=f'{model_name} Train')
#         plt.plot(history['val_loss'], label=f'{model_name} Val')
    
#     plt.title('Model Loss')
#     plt.ylabel('Loss')
#     plt.xlabel('Epoch')
#     plt.legend(loc='upper right')
#     plt.grid(True)
    
#     plt.tight_layout()
#     plt.savefig(save_path)
#     plt.close()

# def main():
#     """主函数"""
#     # 划分数据集
#     split_dir = split_dataset(DATA_DIR, OUTPUT_DIR, SPLIT_RATIO)
    
#     # 创建数据加载器
#     dataloaders, dataset_sizes, class_names = create_data_loaders(split_dir)
#     print(f"Classes: {class_names}")
#     num_classes = len(class_names)
    
#     all_histories = {}
    
#     # 训练所有模型
#     for model_name in MODELS.keys():
#         print(f"\n===== Training {model_name} model =====")
        
#         # 初始化模型
#         model_ft = initialize_model(model_name, num_classes, feature_extract=False)
#         model_ft = model_ft.to(device)
        
#         # 设置损失函数和优化器
#         criterion = nn.CrossEntropyLoss()
#         optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001)
#         exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
        
#         # 训练模型
#         model_ft, history = train_model(
#             model_ft, dataloaders, criterion, optimizer_ft, exp_lr_scheduler,
#             num_epochs=EPOCHS, model_name=model_name
#         )
        
#         # 保存最终模型
#         torch.save(model_ft.state_dict(), f"{model_name}_final_model.pth")
        
#         # 评估模型
#         print(f"\nEvaluating {model_name} model:")
#         evaluate_model(model_ft, dataloaders['test'], class_names, model_name)
        
#         # 保存训练历史
#         all_histories[model_name] = history
    
#     # 绘制训练曲线
#     plot_training_curves(all_histories, save_path='all_models_training_curves.png')
    
#     print("\n===== Training completed for all models =====")

# if __name__ == "__main__":
#     main()

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda:0
GPU: NVIDIA GeForce RTX 3060 Ti
Processing class: Bacterial Leaf Blight


Copying train files: 100%|██████████| 837/837 [00:00<00:00, 1087.08it/s]
Copying val files: 100%|██████████| 180/180 [00:00<00:00, 1098.26it/s]
Copying test files: 100%|██████████| 180/180 [00:00<00:00, 1173.72it/s]


Processing class: Brown Spot


Copying train files: 100%|██████████| 1082/1082 [00:00<00:00, 1082.01it/s]
Copying val files: 100%|██████████| 232/232 [00:00<00:00, 1212.35it/s]
Copying test files: 100%|██████████| 232/232 [00:00<00:00, 1216.81it/s]


Processing class: Healthy Rice Leaf


Copying train files: 100%|██████████| 759/759 [00:00<00:00, 994.88it/s] 
Copying val files: 100%|██████████| 163/163 [00:00<00:00, 1116.52it/s]
Copying test files: 100%|██████████| 163/163 [00:00<00:00, 941.21it/s]


Processing class: Leaf Blast


Copying train files: 100%|██████████| 1222/1222 [00:01<00:00, 1088.27it/s]
Copying val files: 100%|██████████| 263/263 [00:00<00:00, 1187.17it/s]
Copying test files: 100%|██████████| 263/263 [00:00<00:00, 1126.98it/s]


Processing class: Leaf scald


Copying train files: 100%|██████████| 932/932 [00:00<00:00, 1067.68it/s]
Copying val files: 100%|██████████| 200/200 [00:00<00:00, 987.32it/s]
Copying test files: 100%|██████████| 200/200 [00:00<00:00, 1240.27it/s]


Processing class: Narrow Brown Leaf Spot


Copying train files: 100%|██████████| 667/667 [00:00<00:00, 1017.94it/s]
Copying val files: 100%|██████████| 143/143 [00:00<00:00, 1076.84it/s]
Copying test files: 100%|██████████| 144/144 [00:00<00:00, 1081.32it/s]


Processing class: Neck Blast


Copying train files: 100%|██████████| 700/700 [00:01<00:00, 509.07it/s]
Copying val files: 100%|██████████| 150/150 [00:00<00:00, 484.44it/s]
Copying test files: 100%|██████████| 150/150 [00:00<00:00, 500.31it/s]


Processing class: Rice Hispa


Copying train files: 100%|██████████| 909/909 [00:00<00:00, 1133.50it/s]
Copying val files: 100%|██████████| 195/195 [00:00<00:00, 1237.68it/s]
Copying test files: 100%|██████████| 195/195 [00:00<00:00, 1106.85it/s]


Processing class: Sheath Blight


Copying train files: 100%|██████████| 1139/1139 [00:01<00:00, 992.93it/s] 
Copying val files: 100%|██████████| 245/245 [00:00<00:00, 1015.88it/s]
Copying test files: 100%|██████████| 245/245 [00:00<00:00, 1199.09it/s]


Dataset split completed. Saved to dataset_split
Classes: ['Bacterial Leaf Blight', 'Brown Spot', 'Healthy Rice Leaf', 'Leaf Blast', 'Leaf scald', 'Narrow Brown Leaf Spot', 'Neck Blast', 'Rice Hispa', 'Sheath Blight']

===== Training efficientnet-b0 model =====
Epoch 0/49
----------


train Loss: 0.9132 Acc: 0.6849: 100%|██████████| 258/258 [01:09<00:00,  3.69it/s]


train Loss: 0.9132 Acc: 0.6849


val Loss: 0.5736 Acc: 0.8063: 100%|██████████| 56/56 [00:10<00:00,  5.41it/s]


val Loss: 0.5736 Acc: 0.8063

Epoch 1/49
----------


train Loss: 0.5376 Acc: 0.8139: 100%|██████████| 258/258 [00:54<00:00,  4.73it/s]


train Loss: 0.5376 Acc: 0.8139


val Loss: 0.3479 Acc: 0.8842: 100%|██████████| 56/56 [00:08<00:00,  6.69it/s]


val Loss: 0.3479 Acc: 0.8842

Epoch 2/49
----------


train Loss: 0.4160 Acc: 0.8598: 100%|██████████| 258/258 [00:55<00:00,  4.64it/s]


train Loss: 0.4160 Acc: 0.8598


val Loss: 0.2462 Acc: 0.9142: 100%|██████████| 56/56 [00:08<00:00,  6.82it/s]


val Loss: 0.2462 Acc: 0.9142

Epoch 3/49
----------


train Loss: 0.3439 Acc: 0.8829: 100%|██████████| 258/258 [00:54<00:00,  4.78it/s]


train Loss: 0.3439 Acc: 0.8829


val Loss: 0.2100 Acc: 0.9294: 100%|██████████| 56/56 [00:08<00:00,  6.81it/s]


val Loss: 0.2100 Acc: 0.9294

Epoch 4/49
----------


train Loss: 0.3037 Acc: 0.8975: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.3037 Acc: 0.8975


val Loss: 0.1786 Acc: 0.9424: 100%|██████████| 56/56 [00:08<00:00,  6.93it/s]


val Loss: 0.1786 Acc: 0.9424

Epoch 5/49
----------


train Loss: 0.2497 Acc: 0.9156: 100%|██████████| 258/258 [00:54<00:00,  4.72it/s]


train Loss: 0.2497 Acc: 0.9156


val Loss: 0.1235 Acc: 0.9588: 100%|██████████| 56/56 [00:08<00:00,  6.73it/s]


val Loss: 0.1235 Acc: 0.9588

Epoch 6/49
----------


train Loss: 0.2426 Acc: 0.9162: 100%|██████████| 258/258 [00:55<00:00,  4.61it/s]


train Loss: 0.2426 Acc: 0.9162


val Loss: 0.1111 Acc: 0.9616: 100%|██████████| 56/56 [00:08<00:00,  6.78it/s]


val Loss: 0.1111 Acc: 0.9616

Epoch 7/49
----------


train Loss: 0.1382 Acc: 0.9548: 100%|██████████| 258/258 [00:55<00:00,  4.66it/s]


train Loss: 0.1382 Acc: 0.9548


val Loss: 0.0494 Acc: 0.9819: 100%|██████████| 56/56 [00:08<00:00,  6.68it/s]


val Loss: 0.0494 Acc: 0.9819

Epoch 8/49
----------


train Loss: 0.0909 Acc: 0.9694: 100%|██████████| 258/258 [00:55<00:00,  4.68it/s]


train Loss: 0.0909 Acc: 0.9694


val Loss: 0.0425 Acc: 0.9853: 100%|██████████| 56/56 [00:08<00:00,  6.72it/s]


val Loss: 0.0425 Acc: 0.9853

Epoch 9/49
----------


train Loss: 0.0810 Acc: 0.9741: 100%|██████████| 258/258 [00:55<00:00,  4.68it/s]


train Loss: 0.0810 Acc: 0.9741


val Loss: 0.0376 Acc: 0.9887: 100%|██████████| 56/56 [00:08<00:00,  6.75it/s]


val Loss: 0.0376 Acc: 0.9887

Epoch 10/49
----------


train Loss: 0.0698 Acc: 0.9774: 100%|██████████| 258/258 [00:55<00:00,  4.65it/s]


train Loss: 0.0698 Acc: 0.9774


val Loss: 0.0264 Acc: 0.9904: 100%|██████████| 56/56 [00:08<00:00,  6.71it/s]


val Loss: 0.0264 Acc: 0.9904

Epoch 11/49
----------


train Loss: 0.0625 Acc: 0.9796: 100%|██████████| 258/258 [00:54<00:00,  4.70it/s]


train Loss: 0.0625 Acc: 0.9796


val Loss: 0.0270 Acc: 0.9910: 100%|██████████| 56/56 [00:08<00:00,  6.81it/s]


val Loss: 0.0270 Acc: 0.9910

Epoch 12/49
----------


train Loss: 0.0564 Acc: 0.9808: 100%|██████████| 258/258 [00:55<00:00,  4.64it/s]


train Loss: 0.0564 Acc: 0.9808


val Loss: 0.0243 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.82it/s]


val Loss: 0.0243 Acc: 0.9927

Epoch 13/49
----------


train Loss: 0.0520 Acc: 0.9829: 100%|██████████| 258/258 [00:54<00:00,  4.72it/s]


train Loss: 0.0520 Acc: 0.9829


val Loss: 0.0219 Acc: 0.9910: 100%|██████████| 56/56 [00:08<00:00,  6.74it/s]


val Loss: 0.0219 Acc: 0.9910

Epoch 14/49
----------


train Loss: 0.0488 Acc: 0.9838: 100%|██████████| 258/258 [00:54<00:00,  4.77it/s]


train Loss: 0.0488 Acc: 0.9838


val Loss: 0.0235 Acc: 0.9904: 100%|██████████| 56/56 [00:08<00:00,  6.73it/s]


val Loss: 0.0235 Acc: 0.9904

Epoch 15/49
----------


train Loss: 0.0436 Acc: 0.9850: 100%|██████████| 258/258 [00:55<00:00,  4.67it/s]


train Loss: 0.0436 Acc: 0.9850


val Loss: 0.0207 Acc: 0.9921: 100%|██████████| 56/56 [00:08<00:00,  6.91it/s]


val Loss: 0.0207 Acc: 0.9921

Epoch 16/49
----------


train Loss: 0.0444 Acc: 0.9871: 100%|██████████| 258/258 [00:54<00:00,  4.73it/s]


train Loss: 0.0444 Acc: 0.9871


val Loss: 0.0182 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.72it/s]


val Loss: 0.0182 Acc: 0.9927

Epoch 17/49
----------


train Loss: 0.0466 Acc: 0.9842: 100%|██████████| 258/258 [00:54<00:00,  4.77it/s]


train Loss: 0.0466 Acc: 0.9842


val Loss: 0.0173 Acc: 0.9921: 100%|██████████| 56/56 [00:08<00:00,  6.76it/s]


val Loss: 0.0173 Acc: 0.9921

Epoch 18/49
----------


train Loss: 0.0390 Acc: 0.9886: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0390 Acc: 0.9886


val Loss: 0.0169 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.67it/s]


val Loss: 0.0169 Acc: 0.9932

Epoch 19/49
----------


train Loss: 0.0463 Acc: 0.9846: 100%|██████████| 258/258 [00:54<00:00,  4.74it/s]


train Loss: 0.0463 Acc: 0.9846


val Loss: 0.0186 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.81it/s]


val Loss: 0.0186 Acc: 0.9932

Epoch 20/49
----------


train Loss: 0.0411 Acc: 0.9858: 100%|██████████| 258/258 [00:54<00:00,  4.72it/s]


train Loss: 0.0411 Acc: 0.9858


val Loss: 0.0172 Acc: 0.9921: 100%|██████████| 56/56 [00:08<00:00,  6.72it/s]


val Loss: 0.0172 Acc: 0.9921

Epoch 21/49
----------


train Loss: 0.0432 Acc: 0.9853: 100%|██████████| 258/258 [00:54<00:00,  4.76it/s]


train Loss: 0.0432 Acc: 0.9853


val Loss: 0.0159 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.71it/s]


val Loss: 0.0159 Acc: 0.9927

Epoch 22/49
----------


train Loss: 0.0394 Acc: 0.9871: 100%|██████████| 258/258 [00:55<00:00,  4.68it/s]


train Loss: 0.0394 Acc: 0.9871


val Loss: 0.0166 Acc: 0.9938: 100%|██████████| 56/56 [00:08<00:00,  6.75it/s]


val Loss: 0.0166 Acc: 0.9938

Epoch 23/49
----------


train Loss: 0.0389 Acc: 0.9876: 100%|██████████| 258/258 [00:54<00:00,  4.75it/s]


train Loss: 0.0389 Acc: 0.9876


val Loss: 0.0165 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.77it/s]


val Loss: 0.0165 Acc: 0.9932

Epoch 24/49
----------


train Loss: 0.0405 Acc: 0.9871: 100%|██████████| 258/258 [00:54<00:00,  4.74it/s]


train Loss: 0.0405 Acc: 0.9871


val Loss: 0.0171 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.75it/s]


val Loss: 0.0171 Acc: 0.9927

Epoch 25/49
----------


train Loss: 0.0389 Acc: 0.9875: 100%|██████████| 258/258 [00:53<00:00,  4.81it/s]

train Loss: 0.0389 Acc: 0.9875



val Loss: 0.0170 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.93it/s]


val Loss: 0.0170 Acc: 0.9927

Epoch 26/49
----------


train Loss: 0.0402 Acc: 0.9859: 100%|██████████| 258/258 [00:54<00:00,  4.71it/s]


train Loss: 0.0402 Acc: 0.9859


val Loss: 0.0177 Acc: 0.9921: 100%|██████████| 56/56 [00:08<00:00,  6.89it/s]


val Loss: 0.0177 Acc: 0.9921

Epoch 27/49
----------


train Loss: 0.0396 Acc: 0.9874: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0396 Acc: 0.9874


val Loss: 0.0168 Acc: 0.9921: 100%|██████████| 56/56 [00:08<00:00,  6.70it/s]


val Loss: 0.0168 Acc: 0.9921

Epoch 28/49
----------


train Loss: 0.0405 Acc: 0.9879: 100%|██████████| 258/258 [00:54<00:00,  4.77it/s]


train Loss: 0.0405 Acc: 0.9879


val Loss: 0.0157 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.78it/s]


val Loss: 0.0157 Acc: 0.9932

Epoch 29/49
----------


train Loss: 0.0364 Acc: 0.9874: 100%|██████████| 258/258 [00:53<00:00,  4.81it/s]


train Loss: 0.0364 Acc: 0.9874


val Loss: 0.0160 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.68it/s]


val Loss: 0.0160 Acc: 0.9932

Epoch 30/49
----------


train Loss: 0.0443 Acc: 0.9861: 100%|██████████| 258/258 [00:54<00:00,  4.75it/s]


train Loss: 0.0443 Acc: 0.9861


val Loss: 0.0157 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.92it/s]


val Loss: 0.0157 Acc: 0.9932

Epoch 31/49
----------


train Loss: 0.0430 Acc: 0.9861: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0430 Acc: 0.9861


val Loss: 0.0150 Acc: 0.9944: 100%|██████████| 56/56 [00:08<00:00,  6.72it/s]


val Loss: 0.0150 Acc: 0.9944

Epoch 32/49
----------


train Loss: 0.0406 Acc: 0.9865: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0406 Acc: 0.9865


val Loss: 0.0161 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.89it/s]


val Loss: 0.0161 Acc: 0.9932

Epoch 33/49
----------


train Loss: 0.0394 Acc: 0.9867: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.0394 Acc: 0.9867


val Loss: 0.0167 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.70it/s]


val Loss: 0.0167 Acc: 0.9927

Epoch 34/49
----------


train Loss: 0.0364 Acc: 0.9893: 100%|██████████| 258/258 [00:53<00:00,  4.82it/s]


train Loss: 0.0364 Acc: 0.9893


val Loss: 0.0150 Acc: 0.9938: 100%|██████████| 56/56 [00:08<00:00,  6.70it/s]


val Loss: 0.0150 Acc: 0.9938

Epoch 35/49
----------


train Loss: 0.0461 Acc: 0.9863: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.0461 Acc: 0.9863


val Loss: 0.0158 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.81it/s]


val Loss: 0.0158 Acc: 0.9932

Epoch 36/49
----------


train Loss: 0.0430 Acc: 0.9884: 100%|██████████| 258/258 [00:53<00:00,  4.84it/s]


train Loss: 0.0430 Acc: 0.9884


val Loss: 0.0160 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.74it/s]


val Loss: 0.0160 Acc: 0.9932

Epoch 37/49
----------


train Loss: 0.0411 Acc: 0.9857: 100%|██████████| 258/258 [00:54<00:00,  4.77it/s]


train Loss: 0.0411 Acc: 0.9857


val Loss: 0.0167 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.81it/s]


val Loss: 0.0167 Acc: 0.9932

Epoch 38/49
----------


train Loss: 0.0391 Acc: 0.9876: 100%|██████████| 258/258 [00:54<00:00,  4.75it/s]


train Loss: 0.0391 Acc: 0.9876


val Loss: 0.0147 Acc: 0.9938: 100%|██████████| 56/56 [00:08<00:00,  6.75it/s]


val Loss: 0.0147 Acc: 0.9938

Epoch 39/49
----------


train Loss: 0.0403 Acc: 0.9865: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.0403 Acc: 0.9865


val Loss: 0.0176 Acc: 0.9921: 100%|██████████| 56/56 [00:08<00:00,  6.71it/s]


val Loss: 0.0176 Acc: 0.9921

Epoch 40/49
----------


train Loss: 0.0390 Acc: 0.9870: 100%|██████████| 258/258 [00:53<00:00,  4.80it/s]


train Loss: 0.0390 Acc: 0.9870


val Loss: 0.0171 Acc: 0.9938: 100%|██████████| 56/56 [00:08<00:00,  6.73it/s]


val Loss: 0.0171 Acc: 0.9938

Epoch 41/49
----------


train Loss: 0.0381 Acc: 0.9873: 100%|██████████| 258/258 [00:53<00:00,  4.78it/s]


train Loss: 0.0381 Acc: 0.9873


val Loss: 0.0150 Acc: 0.9944: 100%|██████████| 56/56 [00:08<00:00,  6.83it/s]


val Loss: 0.0150 Acc: 0.9944

Epoch 42/49
----------


train Loss: 0.0416 Acc: 0.9869: 100%|██████████| 258/258 [00:54<00:00,  4.77it/s]


train Loss: 0.0416 Acc: 0.9869


val Loss: 0.0186 Acc: 0.9915: 100%|██████████| 56/56 [00:08<00:00,  6.75it/s]


val Loss: 0.0186 Acc: 0.9915

Epoch 43/49
----------


train Loss: 0.0411 Acc: 0.9871: 100%|██████████| 258/258 [00:54<00:00,  4.72it/s]


train Loss: 0.0411 Acc: 0.9871


val Loss: 0.0170 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.91it/s]


val Loss: 0.0170 Acc: 0.9927

Epoch 44/49
----------


train Loss: 0.0417 Acc: 0.9861: 100%|██████████| 258/258 [00:53<00:00,  4.79it/s]


train Loss: 0.0417 Acc: 0.9861


val Loss: 0.0175 Acc: 0.9927: 100%|██████████| 56/56 [00:08<00:00,  6.90it/s]


val Loss: 0.0175 Acc: 0.9927

Epoch 45/49
----------


train Loss: 0.0386 Acc: 0.9865: 100%|██████████| 258/258 [00:54<00:00,  4.74it/s]


train Loss: 0.0386 Acc: 0.9865


val Loss: 0.0159 Acc: 0.9938: 100%|██████████| 56/56 [00:08<00:00,  6.69it/s]


val Loss: 0.0159 Acc: 0.9938

Epoch 46/49
----------


train Loss: 0.0454 Acc: 0.9842: 100%|██████████| 258/258 [00:53<00:00,  4.81it/s]


train Loss: 0.0454 Acc: 0.9842


val Loss: 0.0166 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.70it/s]


val Loss: 0.0166 Acc: 0.9932

Epoch 47/49
----------


train Loss: 0.0374 Acc: 0.9853: 100%|██████████| 258/258 [00:53<00:00,  4.78it/s]


train Loss: 0.0374 Acc: 0.9853


val Loss: 0.0165 Acc: 0.9932: 100%|██████████| 56/56 [00:08<00:00,  6.94it/s]


val Loss: 0.0165 Acc: 0.9932

Epoch 48/49
----------


train Loss: 0.0384 Acc: 0.9870: 100%|██████████| 258/258 [00:54<00:00,  4.74it/s]


train Loss: 0.0384 Acc: 0.9870


val Loss: 0.0164 Acc: 0.9938: 100%|██████████| 56/56 [00:08<00:00,  6.89it/s]


val Loss: 0.0164 Acc: 0.9938

Epoch 49/49
----------


train Loss: 0.0406 Acc: 0.9880: 100%|██████████| 258/258 [00:54<00:00,  4.75it/s]


train Loss: 0.0406 Acc: 0.9880


val Loss: 0.0164 Acc: 0.9938: 100%|██████████| 56/56 [00:08<00:00,  6.83it/s]


val Loss: 0.0164 Acc: 0.9938

Training complete in 70m 48s
Best val Acc: 0.994353

Evaluating efficientnet-b0 model:


Evaluating: 100%|██████████| 56/56 [00:20<00:00,  2.70it/s]



                        precision    recall  f1-score   support

 Bacterial Leaf Blight       0.99      0.98      0.99       180
            Brown Spot       1.00      1.00      1.00       232
     Healthy Rice Leaf       0.99      1.00      0.99       163
            Leaf Blast       0.98      1.00      0.99       263
            Leaf scald       1.00      0.98      0.99       200
Narrow Brown Leaf Spot       1.00      0.99      0.99       144
            Neck Blast       1.00      1.00      1.00       150
            Rice Hispa       1.00      1.00      1.00       195
         Sheath Blight       0.99      1.00      0.99       245

              accuracy                           0.99      1772
             macro avg       0.99      0.99      0.99      1772
          weighted avg       0.99      0.99      0.99      1772

Confusion matrix, without normalization

===== Training completed for all models =====


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
import time
import copy
import shutil
from tqdm import tqdm

# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)

# 配置参数
DATA_DIR = r"D:\homework\论文\论文\project\dataset\archive\Rice_Leaf_AUG\Rice_Leaf_AUG"  # 原始数据集路径
OUTPUT_DIR = "dataset_split"  # 划分后的数据集保存路径
IMAGE_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 1
NUM_WORKERS = 4  # 数据加载的线程数
SPLIT_RATIO = [0.7, 0.15, 0.15]  # 训练集、验证集、测试集比例

# 模型配置
MODELS = {
    #     "resnet50": {
    #     "model": models.resnet50,
    #     "pretrained": True,
    #     "feature_size": 2048  # ResNet50的特征维度
    # },
    # "vgg16": {
    #     "model": models.vgg16,
    #     "pretrained": True,
    #     "feature_size": 512 * 7 * 7  # VGG16的特征维度
    # },
    # "efficientnet-b0": {
    #     "model": models.efficientnet_b0,
    #     "pretrained": True,
    #     "feature_size": 1280  # EfficientNet-B0的特征维度
    # },
    "mobilenet_v3": {
        "model": models.mobilenet_v3_large,
        "pretrained": True,
        "feature_size": 960  # MobileNetV3的特征维度
    },

}

# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

def split_dataset(data_dir, output_dir, split_ratio):
    """将数据集划分为训练集、验证集和测试集"""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    subdirs = [f.name for f in os.scandir(data_dir) if f.is_dir()]
    
    for subdir in subdirs:
        print(f"Processing class: {subdir}")
        class_dir = os.path.join(data_dir, subdir)
        files = [f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))]
        
        # 划分数据集
        train_files, test_files = train_test_split(files, test_size=split_ratio[2], random_state=42)
        train_files, val_files = train_test_split(train_files, test_size=split_ratio[1]/(split_ratio[0]+split_ratio[1]), random_state=42)
        
        # 创建输出目录
        for split in ["train", "val", "test"]:
            split_dir = os.path.join(output_dir, split, subdir)
            os.makedirs(split_dir, exist_ok=True)
        
        # 复制文件
        def copy_files(files_list, split_name):
            for file_name in tqdm(files_list, desc=f"Copying {split_name} files"):
                src = os.path.join(class_dir, file_name)
                dst = os.path.join(output_dir, split_name, subdir, file_name)
                shutil.copy(src, dst)
        
        copy_files(train_files, "train")
        copy_files(val_files, "val")
        copy_files(test_files, "test")
    
    print(f"Dataset split completed. Saved to {output_dir}")
    return output_dir

def create_data_transforms():
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),  # 调整裁剪范围
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(30),  # 增加旋转角度
            transforms.ColorJitter(
                brightness=0.3,  # 增加亮度调整范围
                contrast=0.3,    # 增加对比度调整范围
                saturation=0.3,  # 增加饱和度调整范围
                hue=0.1          # 添加色调调整
            ),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomAffine(  # 仿射变换
                degrees=0,
                translate=(0.1, 0.1),  # 平移
                scale=(0.9, 1.1),      # 缩放
                shear=10               # 剪切
            ),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.2),  # 透视变换
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),  # 高斯模糊
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    return data_transforms

def create_data_loaders(data_dir):
    """创建数据加载器"""
    data_transforms = create_data_transforms()
    
    # 创建数据集
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val', 'test']}
    
    # 创建数据加载器
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                 shuffle=True if x == 'train' else False,
                                 num_workers=NUM_WORKERS)
                   for x in ['train', 'val', 'test']}
    
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}
    class_names = image_datasets['train'].classes
    
    return dataloaders, dataset_sizes, class_names

def initialize_model(model_name, num_classes, feature_extract=True):
    """初始化预训练模型"""
    model_info = MODELS[model_name]
    model_ft = model_info["model"](pretrained=model_info["pretrained"])
    
    # 冻结预训练模型的参数
    if feature_extract:
        for param in model_ft.parameters():
            param.requires_grad = False
    
    # 修改分类器
    if model_name == "efficientnet-b0":
        model_ft.classifier[1] = nn.Linear(model_info["feature_size"], num_classes)
    
    elif model_name == "mobilenet_v3":
        model_ft.classifier[3] = nn.Linear(model_info["feature_size"], num_classes)
    
    elif model_name == "resnet50":
        model_ft.fc = nn.Linear(model_info["feature_size"], num_classes)
    
    elif model_name == "vgg16":
        model_ft.classifier[6] = nn.Linear(model_info["feature_size"], num_classes)
    
    return model_ft

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, model_name="model"):
    """训练模型"""
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # 记录训练历史
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        
        # 每个epoch都有一个训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 训练模式
            else:
                model.eval()   # 评估模式
            
            running_loss = 0.0
            running_corrects = 0
            
            # 迭代数据
            progress_bar = tqdm(enumerate(dataloaders[phase]), total=len(dataloaders[phase]))
            for i, (inputs, labels) in progress_bar:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # 零梯度
                optimizer.zero_grad()
                
                # 前向传播
                # 只有在训练时才跟踪历史
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    # 只有在训练阶段才进行反向传播和优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                # 更新进度条
                progress_bar.set_description(f"{phase} Loss: {running_loss/(i*BATCH_SIZE+inputs.size(0)):.4f} Acc: {running_corrects/(i*BATCH_SIZE+inputs.size(0)):.4f}")
            
            if phase == 'train' and scheduler:
                scheduler.step()
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # 记录历史
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())
            
            # 深拷贝模型
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), f"{model_name}_best_model.pth")
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    
    # 加载最佳模型权重
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_model(model, dataloader, class_names, model_name):
    """评估模型"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # 打印分类报告
    print(f"\n{classification_report(all_labels, all_preds, target_names=class_names)}")
    
    # 计算混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    plot_confusion_matrix(cm, class_names, f"{model_name}_confusion_matrix.png")
    
    return all_preds, all_labels

def plot_confusion_matrix(cm, classes, save_path, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """绘制混淆矩阵"""
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    
    plt.figure(figsize=(10, 10))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(save_path)
    plt.close()

def plot_training_curves(histories, save_path='training_curves.png'):
    """绘制多个模型的训练曲线"""
    plt.figure(figsize=(12, 10))
    
    # 绘制准确率曲线
    plt.subplot(2, 1, 1)
    for model_name, history in histories.items():
        plt.plot(history['train_acc'], label=f'{model_name} Train')
        plt.plot(history['val_acc'], label=f'{model_name} Val')
    
    plt.title('Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(loc='lower right')
    plt.grid(True)
    
    # 绘制损失曲线
    plt.subplot(2, 1, 2)
    for model_name, history in histories.items():
        plt.plot(history['train_loss'], label=f'{model_name} Train')
        plt.plot(history['val_loss'], label=f'{model_name} Val')
    
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(loc='upper right')
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def main():
    """主函数"""
    # 划分数据集
    split_dir = split_dataset(DATA_DIR, OUTPUT_DIR, SPLIT_RATIO)
    
    # 创建数据加载器
    dataloaders, dataset_sizes, class_names = create_data_loaders(split_dir)
    print(f"Classes: {class_names}")
    num_classes = len(class_names)
    
    all_histories = {}
    
    # 训练所有模型
    for model_name in MODELS.keys():
        print(f"\n===== Training {model_name} model =====")
        
        # 初始化模型
        model_ft = initialize_model(model_name, num_classes, feature_extract=False)
        model_ft = model_ft.to(device)
        
        # 设置损失函数和优化器
        criterion = nn.CrossEntropyLoss()
        
        # 只优化最后一层的参数
        optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001)
        
        # 学习率调度器
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
        
        # 训练模型
        model_ft, history = train_model(
            model_ft, dataloaders, criterion, optimizer_ft, exp_lr_scheduler,
            num_epochs=EPOCHS, model_name=model_name
        )
        
        # 保存最终模型
        torch.save(model_ft.state_dict(), f"{model_name}_final_model.pth")
        
        # 评估模型
        print(f"\nEvaluating {model_name} model:")
        evaluate_model(model_ft, dataloaders['test'], class_names, model_name)
        
        # 保存训练历史
        all_histories[model_name] = history
    
    # 绘制所有模型的训练曲线
    plot_training_curves(all_histories, save_path='all_models_training_curves.png')
    
    print("\n===== Training completed for all models =====")

if __name__ == "__main__":
    main()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
import time
import copy
import shutil
from tqdm import tqdm

# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)

# 配置参数
DATA_DIR = r"D:\homework\论文\论文\project\dataset\archive\Rice_Leaf_AUG\Rice_Leaf_AUG"  # 原始数据集路径
OUTPUT_DIR = "dataset_split"  # 划分后的数据集保存路径
IMAGE_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 1
NUM_WORKERS = 4  # 数据加载的线程数
SPLIT_RATIO = [0.7, 0.15, 0.15]  # 训练集、验证集、测试集比例

# 模型配置
MODELS = {
    #     "resnet50": {
    #     "model": models.resnet50,
    #     "pretrained": True,
    #     "feature_size": 2048  # ResNet50的特征维度
    # },
    # "vgg16": {
    #     "model": models.vgg16,
    #     "pretrained": True,
    #     "feature_size": 512 * 7 * 7  # VGG16的特征维度
    # },
    # "efficientnet-b0": {
    #     "model": models.efficientnet_b0,
    #     "pretrained": True,
    #     "feature_size": 1280  # EfficientNet-B0的特征维度
    # },
    "mobilenet_v3": {
        "model": models.mobilenet_v3_large,
        "pretrained": True,
        "feature_size": 960  # MobileNetV3的特征维度
    },

}

# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

def split_dataset(data_dir, output_dir, split_ratio):
    """将数据集划分为训练集、验证集和测试集"""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    subdirs = [f.name for f in os.scandir(data_dir) if f.is_dir()]
    
    for subdir in subdirs:
        print(f"Processing class: {subdir}")
        class_dir = os.path.join(data_dir, subdir)
        files = [f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))]
        
        # 划分数据集
        train_files, test_files = train_test_split(files, test_size=split_ratio[2], random_state=42)
        train_files, val_files = train_test_split(train_files, test_size=split_ratio[1]/(split_ratio[0]+split_ratio[1]), random_state=42)
        
        # 创建输出目录
        for split in ["train", "val", "test"]:
            split_dir = os.path.join(output_dir, split, subdir)
            os.makedirs(split_dir, exist_ok=True)
        
        # 复制文件
        def copy_files(files_list, split_name):
            for file_name in tqdm(files_list, desc=f"Copying {split_name} files"):
                src = os.path.join(class_dir, file_name)
                dst = os.path.join(output_dir, split_name, subdir, file_name)
                shutil.copy(src, dst)
        
        copy_files(train_files, "train")
        copy_files(val_files, "val")
        copy_files(test_files, "test")
    
    print(f"Dataset split completed. Saved to {output_dir}")
    return output_dir

def create_data_transforms():
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),  # 调整裁剪范围
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(30),  # 增加旋转角度
            transforms.ColorJitter(
                brightness=0.3,  # 增加亮度调整范围
                contrast=0.3,    # 增加对比度调整范围
                saturation=0.3,  # 增加饱和度调整范围
                hue=0.1          # 添加色调调整
            ),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomAffine(  # 仿射变换
                degrees=0,
                translate=(0.1, 0.1),  # 平移
                scale=(0.9, 1.1),      # 缩放
                shear=10               # 剪切
            ),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.2),  # 透视变换
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),  # 高斯模糊
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    return data_transforms

def create_data_loaders(data_dir):
    """创建数据加载器"""
    data_transforms = create_data_transforms()
    
    # 创建数据集
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val', 'test']}
    
    # 创建数据加载器
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                 shuffle=True if x == 'train' else False,
                                 num_workers=NUM_WORKERS)
                   for x in ['train', 'val', 'test']}
    
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}
    class_names = image_datasets['train'].classes
    
    return dataloaders, dataset_sizes, class_names

def initialize_model(model_name, num_classes, feature_extract=True):
    """初始化预训练模型"""
    model_info = MODELS[model_name]
    model_ft = model_info["model"](pretrained=model_info["pretrained"])
    
    # 冻结预训练模型的参数
    if feature_extract:
        for param in model_ft.parameters():
            param.requires_grad = False
    
    # 修改分类器
    if model_name == "efficientnet-b0":
        model_ft.classifier[1] = nn.Linear(model_info["feature_size"], num_classes)
    
    elif model_name == "mobilenet_v3":
        model_ft.classifier[3] = nn.Linear(model_info["feature_size"], num_classes)
    
    elif model_name == "resnet50":
        model_ft.fc = nn.Linear(model_info["feature_size"], num_classes)
    
    elif model_name == "vgg16":
        model_ft.classifier[6] = nn.Linear(model_info["feature_size"], num_classes)
    
    return model_ft

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, model_name="model"):
    """训练模型"""
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # 记录训练历史
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        
        # 每个epoch都有一个训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 训练模式
            else:
                model.eval()   # 评估模式
            
            running_loss = 0.0
            running_corrects = 0
            
            # 迭代数据
            progress_bar = tqdm(enumerate(dataloaders[phase]), total=len(dataloaders[phase]))
            for i, (inputs, labels) in progress_bar:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # 零梯度
                optimizer.zero_grad()
                
                # 前向传播
                # 只有在训练时才跟踪历史
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    # 只有在训练阶段才进行反向传播和优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                # 更新进度条
                progress_bar.set_description(f"{phase} Loss: {running_loss/(i*BATCH_SIZE+inputs.size(0)):.4f} Acc: {running_corrects/(i*BATCH_SIZE+inputs.size(0)):.4f}")
            
            if phase == 'train' and scheduler:
                scheduler.step()
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # 记录历史
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())
            
            # 深拷贝模型
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), f"{model_name}_best_model.pth")
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    
    # 加载最佳模型权重
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_model(model, dataloader, class_names, model_name):
    """评估模型"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # 打印分类报告
    print(f"\n{classification_report(all_labels, all_preds, target_names=class_names)}")
    
    # 计算混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    plot_confusion_matrix(cm, class_names, f"{model_name}_confusion_matrix.png")
    
    return all_preds, all_labels

def plot_confusion_matrix(cm, classes, save_path, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """绘制混淆矩阵"""
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    
    plt.figure(figsize=(10, 10))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(save_path)
    plt.close()

def plot_training_curves(histories, save_path='training_curves.png'):
    """绘制多个模型的训练曲线"""
    plt.figure(figsize=(12, 10))
    
    # 绘制准确率曲线
    plt.subplot(2, 1, 1)
    for model_name, history in histories.items():
        plt.plot(history['train_acc'], label=f'{model_name} Train')
        plt.plot(history['val_acc'], label=f'{model_name} Val')
    
    plt.title('Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(loc='lower right')
    plt.grid(True)
    
    # 绘制损失曲线
    plt.subplot(2, 1, 2)
    for model_name, history in histories.items():
        plt.plot(history['train_loss'], label=f'{model_name} Train')
        plt.plot(history['val_loss'], label=f'{model_name} Val')
    
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(loc='upper right')
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def main():
    """主函数"""
    # 划分数据集
    split_dir = split_dataset(DATA_DIR, OUTPUT_DIR, SPLIT_RATIO)
    
    # 创建数据加载器
    dataloaders, dataset_sizes, class_names = create_data_loaders(split_dir)
    print(f"Classes: {class_names}")
    num_classes = len(class_names)
    
    all_histories = {}
    
    # 训练所有模型
    for model_name in MODELS.keys():
        print(f"\n===== Training {model_name} model =====")
        
        # 初始化模型
        model_ft = initialize_model(model_name, num_classes, feature_extract=False)
        model_ft = model_ft.to(device)
        
        # 设置损失函数和优化器
        criterion = nn.CrossEntropyLoss()
        
        # 只优化最后一层的参数
        optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001)
        
        # 学习率调度器
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
        
        # 训练模型
        model_ft, history = train_model(
            model_ft, dataloaders, criterion, optimizer_ft, exp_lr_scheduler,
            num_epochs=EPOCHS, model_name=model_name
        )
        
        # 保存最终模型
        torch.save(model_ft.state_dict(), f"{model_name}_final_model.pth")
        
        # 评估模型
        print(f"\nEvaluating {model_name} model:")
        evaluate_model(model_ft, dataloaders['test'], class_names, model_name)
        
        # 保存训练历史
        all_histories[model_name] = history
    
    # 绘制所有模型的训练曲线
    plot_training_curves(all_histories, save_path='all_models_training_curves.png')
    
    print("\n===== Training completed for all models =====")

if __name__ == "__main__":
    main()