# In [None]:  (notebook cell prompt, kept as a comment so the file parses as Python)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm  # 用于显示训练进度条
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import os
import json
from PIL import Image

def _unwrap_model(model):
    """Return the underlying network when *model* is wrapped in ``nn.DataParallel``."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _snapshot_state(model):
    """Return an independent CPU copy of the model weights.

    ``state_dict()`` hands back references to the live parameter tensors, so
    the values must be cloned or later optimizer steps overwrite the snapshot.
    The wrapper is stripped first so the keys carry no ``module.`` prefix.
    """
    return {
        name: tensor.detach().cpu().clone()
        for name, tensor in _unwrap_model(model).state_dict().items()
    }


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader*.

    Returns:
        ``(mean per-batch loss, accuracy in percent)`` on the training data.
    """
    model.train()
    running_loss = 0.0
    total_samples = 0
    correct_samples = 0

    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Accuracy bookkeeping.
        predicted = outputs.detach().argmax(dim=1)
        total_samples += labels.size(0)
        correct_samples += (predicted == labels).sum().item()

    return running_loss / len(loader), 100 * correct_samples / total_samples


def _evaluate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without tracking gradients.

    Returns:
        ``(mean per-batch loss, accuracy in percent, true labels, predictions)``
        where the last two are flat lists of class indices.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Keep every label/prediction for the reports.
            all_labels.extend(labels.cpu().tolist())
            all_predictions.extend(predicted.cpu().tolist())

    return loss_sum / len(loader), 100 * correct / total, all_labels, all_predictions


def _log_confusion_matrix(writer, all_labels, all_predictions, class_names):
    """Plot the confusion matrix, save it as a PNG and add it to TensorBoard."""
    # Pin the label set so the matrix always matches the tick labels, even
    # when some class never occurs in the evaluated split.
    class_ids = list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix')

    # Save BEFORE show(): show() may release the figure, and a later
    # savefig() would then write an empty image.
    confusion_matrix_image = 'confusion_matrix.png'
    plt.savefig(confusion_matrix_image)
    plt.show()
    plt.close(fig)

    # Re-load the PNG as a tensor; drop the alpha channel so the image has
    # the 3 channels add_image documents as its default layout.
    with Image.open(confusion_matrix_image) as image:
        image_tensor = transforms.ToTensor()(image.convert("RGB"))
    writer.add_image('Confusion Matrix', image_tensor)


def train_model():
    """Fine-tune a ResNet-50 classifier and evaluate it on the test split.

    Reads ``train``/``validation``/``test`` ImageFolder directories from the
    current working directory, writes the index-to-class mapping to
    ``class_indices.json``, resumes from ``best_model.pth`` when that file
    exists (otherwise starts from the pretrained weights), trains for up to
    10 epochs with early stopping, checkpoints the weights with the best
    validation accuracy, and finally reports test loss, accuracy, a
    classification report and a confusion matrix. Scalars and the confusion
    matrix image are logged to TensorBoard under ``tf-logs``.
    """
    # Select the compute device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset locations (relative to the current working directory).
    base_dir = os.getcwd()
    train_dir = os.path.join(base_dir, "train")
    validation_dir = os.path.join(base_dir, "validation")
    test_dir = os.path.join(base_dir, "test")

    # Pre-processing pipelines: random augmentation for training,
    # deterministic resize + center crop for evaluation.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # Load the three splits. The test split must use the deterministic
    # evaluation transform; random crops/flips would make test metrics noisy.
    train_dataset = datasets.ImageFolder(root=train_dir, transform=data_transform["train"])
    validation_dataset = datasets.ImageFolder(root=validation_dir, transform=data_transform["val"])
    test_dataset = datasets.ImageFolder(root=test_dir, transform=data_transform["val"])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    class_names = train_dataset.classes
    num_classes = len(class_names)
    class_ids = list(range(num_classes))

    # Persist the inverted mapping (index -> class name) as JSON.
    class_dict = {index: name for name, index in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(class_dict, indent=4))

    # Build the model: resume from a previous checkpoint when available.
    model_path = 'best_model.pth'
    if os.path.exists(model_path):
        print(f"加载已保存的模型：{model_path}")
        model = models.resnet50()
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        saved_state = torch.load(model_path, map_location=device)
        # Checkpoints saved from a DataParallel-wrapped model prefix every key
        # with "module."; strip it so they load into the bare network.
        prefix = "module."
        saved_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in saved_state.items()
        }
        model.load_state_dict(saved_state)
    else:
        print("加载预训练的 ResNet50 模型")
        # NOTE: newer torchvision prefers the `weights=` argument; `pretrained=True`
        # is kept so the script still runs on older installs.
        model = models.resnet50(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)  # replace the classification head

    # Spread batches over all visible GPUs when there is more than one.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.to(device)

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # TensorBoard writer; closed in `finally` so logs are flushed on errors too.
    log_dir = os.path.join(base_dir, "tf-logs")
    writer = SummaryWriter(log_dir)
    try:
        num_epochs = 10
        best_accuracy = 0.0
        best_model_state = None
        accuracy_threshold = 95.0  # stop as soon as validation accuracy reaches this
        patience = 5               # epochs without improvement tolerated before stopping
        no_improve_epochs = 0

        for epoch in range(num_epochs):
            avg_loss, train_accuracy = _train_one_epoch(
                model, train_loader, criterion, optimizer, device)
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {train_accuracy:.2f}%')

            writer.add_scalar('Loss/train', avg_loss, epoch)
            writer.add_scalar('Accuracy/train', train_accuracy, epoch)

            # Validate after every epoch.
            avg_val_loss, val_accuracy, all_labels, all_predictions = _evaluate(
                model, validation_loader, criterion, device)
            print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

            writer.add_scalar('Loss/validation', avg_val_loss, epoch)
            writer.add_scalar('Accuracy/validation', val_accuracy, epoch)

            # `labels=` keeps the report aligned with `target_names` even when
            # a class is absent from this split.
            print(classification_report(all_labels, all_predictions,
                                        labels=class_ids, target_names=class_names))

            # Track the best weights seen so far and checkpoint them.
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_model_state = _snapshot_state(model)
                torch.save(best_model_state, model_path)
                no_improve_epochs = 0
            else:
                no_improve_epochs += 1

            # Early-stopping checks.
            if val_accuracy >= accuracy_threshold:
                print(f"验证准确率达到 {accuracy_threshold}%，提前停止训练。")
                break
            elif no_improve_epochs >= patience:
                print(f"在 {patience} 个 epoch 中没有改进，提前停止训练。")
                break

        # Restore the best weights; skipped when no epoch ever improved on the
        # initial best_accuracy (the snapshot is then still None).
        if best_model_state is not None:
            _unwrap_model(model).load_state_dict(best_model_state)
        print("训练完成！")

        # Final evaluation on the held-out test split.
        avg_test_loss, test_accuracy, all_labels, all_predictions = _evaluate(
            model, test_loader, criterion, device)

        _log_confusion_matrix(writer, all_labels, all_predictions, class_names)

        print(f'Test Loss: {avg_test_loss:.4f}, Accuracy: {test_accuracy:.2f}%')

        writer.add_scalar('Loss/test', avg_test_loss)
        writer.add_scalar('Accuracy/test', test_accuracy)

        print(classification_report(all_labels, all_predictions,
                                    labels=class_ids, target_names=class_names))

        print("Training complete！")  # marks the end of the run
    finally:
        writer.close()

# Script entry point: start the training run only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train_model()
