In [1]:
# Download the dataset (only needed the first time)
# ! chmod +x download.sh
# ! ./download.sh

In [None]:
import os
from pathlib import Path
import shutil
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn as nn
from collections import defaultdict
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from tqdm import tqdm
import os
from plant_dataset import PlantDataset

In [None]:
# Paths and runtime configuration.
dataset_path = Path("dataset")

# Device selection: Apple MPS takes precedence over CUDA, which takes precedence over CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed torch's global RNG so weight initialisation, shuffling and random
# augmentations repeat across runs (GPU kernels may still be nondeterministic).
seed = 42
torch.manual_seed(seed)

# Training hyper-parameters.
batch_size = 32
num_epochs = 100
learning_rate = 0.001
early_stopping_patience = 10

In [None]:
def fix_folder_name(path):
    """Recursively normalise sub-directory names under ``path`` (in place).

    For every sub-directory, at any depth:
      * if its name contains ",_", every "," is replaced by "_"
        (e.g. "Pepper,_bell" -> "Pepper__bell");
      * then, if the resulting name starts with "_", that single leading
        underscore is removed (a bare "_" is left unchanged).

    Regular files are never renamed.

    Args:
        path: directory (str or os.PathLike) whose sub-directories are fixed.
    """
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue

        name = folder
        if ",_" in name:
            name = name.replace(",", "_")
        # Work on the already-fixed name; slicing the original `folder` here
        # would re-introduce the commas removed above.
        if name.startswith("_") and len(name) > 1:
            name = name[1:]

        if name != folder:
            new_folder_path = os.path.join(path, name)
            os.rename(folder_path, new_folder_path)
            folder_path = new_folder_path

        fix_folder_name(folder_path)

def optimize_folder_structure(path):
    """Group top-level "<plant>___<status>" directories under a per-plant folder.

    Each directory ``path/<plant>___<status>`` is moved to
    ``path/<plant>/<plant>___<status>``; the full original name is kept as the
    status folder name. If that destination already exists, the contents are
    merged item by item and existing destination items are never overwritten.
    The source directory is removed only if it ends up empty; otherwise it is
    left in place and a warning is printed.

    Plain files, directories without "___", and names with an empty plant
    prefix (starting with "___") are ignored.

    Args:
        path: directory (str or os.PathLike) containing the class folders.
    """
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)
        if not (os.path.isdir(folder_path) and "___" in folder):
            continue

        plant = folder.split("___", 1)[0]
        if not plant:
            # "___x" would map onto `path` itself; nothing sensible to do.
            continue

        plant_folder = os.path.join(path, plant)
        os.makedirs(plant_folder, exist_ok=True)
        status_folder = os.path.join(plant_folder, folder)

        if os.path.exists(status_folder):
            skipped = []
            for item in os.listdir(folder_path):
                src = os.path.join(folder_path, item)
                dst = os.path.join(status_folder, item)
                if os.path.exists(dst):
                    skipped.append(item)  # never overwrite existing data
                else:
                    shutil.move(src, dst)
            if skipped:
                # os.rmdir would raise on the non-empty directory.
                print(f"Warning: {len(skipped)} item(s) already exist in "
                      f"{status_folder}; left them in {folder_path}")
            else:
                os.rmdir(folder_path)
        else:
            shutil.move(folder_path, status_folder)

        print(f"Reorganized: {folder} -> {plant}/{folder}")


# Step 1: normalise directory names throughout the dataset tree.
fix_folder_name(dataset_path)

# Step 2: in each split, nest "<plant>___<status>" folders under "<plant>/".
train_path, valid_path = (
    os.path.join(dataset_path, split_name) for split_name in ("train", "valid")
)
optimize_folder_structure(train_path)
optimize_folder_structure(valid_path)

In [None]:

# Image preprocessing. Training uses random resized crops for augmentation;
# validation uses a deterministic center crop. Both normalise with the
# ImageNet mean/std used by the pretrained backbones.
imagenet_normalize = transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
}

# Build the datasets from the configured dataset directory rather than a
# second hard-coded path.
train_dataset = PlantDataset(
    root_dir=str(dataset_path), transform=data_transforms['train'], mode='train')
val_dataset = PlantDataset(
    root_dir=str(dataset_path), transform=data_transforms['val'], mode='valid')

# Show the plant classes, status classes and the plant -> status mapping.
print("植物种类:", train_dataset.plant_classes)
print("状态类别:", train_dataset.status_classes)
print("\n植物到状态的映射:")
for plant, statuses in train_dataset.plant_to_status.items():
    print(f"{plant}: {statuses}")

# DataLoaders use `batch_size` from the configuration cell instead of redefining it.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4)

植物种类: ['Apple', 'Blueberry', 'Cherry_(including_sour)', 'Corn_(maize)', 'Grape', 'Orange', 'Peach', 'Pepper__bell', 'Potato', 'Raspberry', 'Soybean', 'Squash', 'Strawberry', 'Tomato']
状态类别: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato

In [None]:
class CustomResNet50(nn.Module):
    """ResNet-50 backbone (torchvision DEFAULT weights) with two linear heads.

    Args:
        num_plants: number of plant classes (output size of ``plant_head``).
        num_statuses: number of status classes (output size of ``status_head``).
    """

    def __init__(self, num_plants, num_statuses):
        super().__init__()
        self.backbone = models.resnet50(
            weights=models.ResNet50_Weights.DEFAULT)
        num_ftrs = self.backbone.fc.in_features
        # Drop the ImageNet classifier so the backbone outputs its feature vector.
        self.backbone.fc = nn.Identity()

        # Two heads share the same backbone features.
        self.plant_head = nn.Linear(num_ftrs, num_plants)
        self.status_head = nn.Linear(num_ftrs, num_statuses)

    def forward(self, x):
        """Return ``(plant_logits, status_logits)`` for the image batch ``x``."""
        features = self.backbone(x)
        return self.plant_head(features), self.status_head(features)

    def compute_loss(self, plant_scores, status_scores, plant_y, status_y,
                     plant_loss_weight=2.0):
        """Weighted sum of the plant and status cross-entropy losses.

        Args:
            plant_scores, status_scores: logits returned by ``forward``.
            plant_y, status_y: integer class targets for each head.
            plant_loss_weight: multiplier on the plant loss (default 2.0, the
                weighting this model has always used).

        Returns:
            Scalar tensor ``plant_loss_weight * CE(plant) + CE(status)``.
        """
        plant_loss = F.cross_entropy(plant_scores, plant_y)
        status_loss = F.cross_entropy(status_scores, status_y)
        return plant_loss * plant_loss_weight + status_loss


class CustomDenseNet121(nn.Module):
    """DenseNet-121 backbone (torchvision DEFAULT weights) with two linear heads.

    Args:
        num_plants: number of plant classes (output size of ``plant_head``).
        num_statuses: number of status classes (output size of ``status_head``).
    """

    def __init__(self, num_plants, num_statuses):
        super().__init__()
        self.backbone = models.densenet121(
            weights=models.DenseNet121_Weights.DEFAULT)
        num_ftrs = self.backbone.classifier.in_features

        # Drop the ImageNet classifier so the backbone outputs its feature vector.
        self.backbone.classifier = nn.Identity()

        # Two heads share the same backbone features.
        self.plant_head = nn.Linear(num_ftrs, num_plants)
        self.status_head = nn.Linear(num_ftrs, num_statuses)

    def forward(self, x):
        """Return ``(plant_logits, status_logits)`` for the image batch ``x``."""
        features = self.backbone(x)
        return self.plant_head(features), self.status_head(features)

    def compute_loss(self, plant_scores, status_scores, plant_y, status_y,
                     plant_loss_weight=2.0):
        """Weighted sum of the plant and status cross-entropy losses.

        Args:
            plant_scores, status_scores: logits returned by ``forward``.
            plant_y, status_y: integer class targets for each head.
            plant_loss_weight: multiplier on the plant loss (default 2.0, the
                weighting this model has always used).

        Returns:
            Scalar tensor ``plant_loss_weight * CE(plant) + CE(status)``.
        """
        plant_loss = F.cross_entropy(plant_scores, plant_y)
        status_loss = F.cross_entropy(status_scores, status_y)
        return plant_loss * plant_loss_weight + status_loss

# Instantiate both architectures, sizing their heads from the training set's
# plant and status class lists (ResNet-50 first, then DenseNet-121).
resnet50_model, denseNet121_model = (
    model_cls(
        num_plants=len(train_dataset.plant_classes),
        num_statuses=len(train_dataset.status_classes),
    )
    for model_cls in (CustomResNet50, CustomDenseNet121)
)

In [None]:
class Trainer:
    """Train a two-head (plant / status) classifier and keep per-epoch history.

    Expected interfaces:
        * ``model(images)`` returns ``(plant_scores, status_scores)``, and the
          model provides ``compute_loss(plant_scores, status_scores, plant_y, status_y)``.
        * ``train_loader`` / ``val_loader`` yield ``(images, plant_labels, status_labels)``.

    ``config`` keys:
        lr, weight_decay: Adam hyper-parameters.
        epochs: maximum number of epochs.
        save_dir: directory for ``best_model.pth`` and ``training_history.png``.
        early_stopping_patience (optional): stop after this many consecutive
            epochs without a new best combined validation accuracy; absent or
            ``None`` disables early stopping.
    """

    def __init__(self, model, train_loader, val_loader, device, config):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # Optimizer and an LR scheduler that decays the LR when the monitored
        # validation metric stops improving.
        self.optimizer = Adam(
            self.model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay']
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',    # the monitored metric (accuracy) should increase
            patience=3,    # wait 3 epochs without improvement before decaying
            factor=0.5     # multiply the LR by 0.5 on each decay
        )

        # Per-epoch history. 'plant_acc' / 'status_acc' hold validation accuracies.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'plant_acc': [],
            'status_acc': [],
            'lr': []
        }

        os.makedirs(config['save_dir'], exist_ok=True)

    def train(self):
        """Run the epoch loop: train, validate, step the scheduler, checkpoint.

        The weights with the best mean of validation plant and status accuracy
        are saved to ``<save_dir>/best_model.pth``.
        """
        best_combined_acc = 0.0
        patience = self.config.get('early_stopping_patience')
        epochs_without_improvement = 0

        for epoch in range(self.config['epochs']):
            print(f"\nEpoch {epoch+1}/{self.config['epochs']}")

            train_loss, train_plant_acc, train_status_acc = self.train_epoch()
            val_loss, val_plant_acc, val_status_acc = self.validate()

            # LR used during this epoch (recorded before the scheduler step).
            current_lr = self.optimizer.param_groups[0]['lr']
            self.history['lr'].append(current_lr)

            print(f"Train Loss: {train_loss:.4f} | "
                  f"Plant Acc: {train_plant_acc:.4f} | "
                  f"Status Acc: {train_status_acc:.4f} | "
                  f"LR: {current_lr:.2e}")

            print(f"Val Loss: {val_loss:.4f} | "
                  f"Plant Acc: {val_plant_acc:.4f} | "
                  f"Status Acc: {val_status_acc:.4f}")

            # Schedule the LR on the mean of plant and status validation accuracy.
            combined_acc = (val_plant_acc + val_status_acc) / 2
            self.scheduler.step(combined_acc)

            if combined_acc > best_combined_acc:
                best_combined_acc = combined_acc
                epochs_without_improvement = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_acc': best_combined_acc,
                }, os.path.join(self.config['save_dir'], 'best_model.pth'))
                print(
                    f"Saved new best model with combined acc: {best_combined_acc:.4f}")
            else:
                epochs_without_improvement += 1
                if patience is not None and epochs_without_improvement >= patience:
                    print(f"Early stopping: no improvement for {patience} epochs")
                    break

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(avg_loss, plant_acc, status_acc)`` averaged over all training
            samples; ``avg_loss`` is also appended to ``history['train_loss']``.

        Raises:
            ValueError: if the loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        plant_correct = 0
        status_correct = 0
        total_samples = 0

        progress_bar = tqdm(self.train_loader, desc='Training')
        for images, plant_labels, status_labels in progress_bar:
            images = images.to(self.device)
            plant_labels = plant_labels.to(self.device)
            status_labels = status_labels.to(self.device)

            self.optimizer.zero_grad()
            plant_scores, status_scores = self.model(images)
            loss = self.model.compute_loss(
                plant_scores, status_scores, plant_labels, status_labels)
            loss.backward()
            self.optimizer.step()

            # Running metrics (no graph needed).
            with torch.no_grad():
                batch_size = images.size(0)
                total_loss += loss.item() * batch_size
                plant_correct += (plant_scores.argmax(dim=1) == plant_labels).sum().item()
                status_correct += (status_scores.argmax(dim=1) == status_labels).sum().item()
                total_samples += batch_size

            progress_bar.set_postfix({
                'loss': total_loss / total_samples,
                'plant_acc': plant_correct / total_samples,
                'status_acc': status_correct / total_samples
            })

        if total_samples == 0:
            raise ValueError("train_loader produced no samples")

        avg_loss = total_loss / total_samples
        plant_acc = plant_correct / total_samples
        status_acc = status_correct / total_samples

        self.history['train_loss'].append(avg_loss)
        return avg_loss, plant_acc, status_acc

    def validate(self):
        """Evaluate on ``val_loader`` without updating weights.

        Returns:
            ``(avg_loss, plant_acc, status_acc)`` over all validation samples;
            the three values are appended to ``history['val_loss']``,
            ``history['plant_acc']`` and ``history['status_acc']``.

        Raises:
            ValueError: if the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        plant_correct = 0
        status_correct = 0
        total_samples = 0

        with torch.no_grad():
            for images, plant_labels, status_labels in self.val_loader:
                images = images.to(self.device)
                plant_labels = plant_labels.to(self.device)
                status_labels = status_labels.to(self.device)

                plant_scores, status_scores = self.model(images)
                loss = self.model.compute_loss(
                    plant_scores, status_scores, plant_labels, status_labels)

                batch_size = images.size(0)
                total_loss += loss.item() * batch_size
                plant_correct += (plant_scores.argmax(dim=1) == plant_labels).sum().item()
                status_correct += (status_scores.argmax(dim=1) == status_labels).sum().item()
                total_samples += batch_size

        if total_samples == 0:
            raise ValueError("val_loader produced no samples")

        avg_loss = total_loss / total_samples
        plant_acc = plant_correct / total_samples
        status_acc = status_correct / total_samples

        self.history['val_loss'].append(avg_loss)
        self.history['plant_acc'].append(plant_acc)
        self.history['status_acc'].append(status_acc)
        return avg_loss, plant_acc, status_acc

    def plot_history(self):
        """Save loss and validation-accuracy curves to ``<save_dir>/training_history.png``."""
        import matplotlib.pyplot as plt

        fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

        loss_ax.plot(self.history['train_loss'], label='Train Loss')
        loss_ax.plot(self.history['val_loss'], label='Val Loss')
        loss_ax.set(title='Loss', xlabel='Epoch', ylabel='Loss')
        loss_ax.legend()

        acc_ax.plot(self.history['plant_acc'], label='Plant Acc')
        acc_ax.plot(self.history['status_acc'], label='Status Acc')
        acc_ax.set(title='Validation accuracy', xlabel='Epoch', ylabel='Accuracy')
        acc_ax.legend()

        fig.tight_layout()
        fig.savefig(os.path.join(
            self.config['save_dir'], 'training_history.png'))
        plt.close(fig)


# 配置参数
config = {
    'lr': 1e-4,
    'weight_decay': 1e-5,
    'epochs': 20,
    'save_dir': './checkpoints'
}

# 初始化训练器
trainer = Trainer(
    model=resnet50_model,  # 或 denseNet121_model
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=config
)

# 开始训练
# trainer.train()

# 绘制训练曲线
# trainer.plot_history()

: 

In [None]:
# Train the ResNet-50 model.
# The Trainer moves the model to `device` and builds its own Adam optimizer and
# ReduceLROnPlateau scheduler from `config`, and the model computes its own
# loss. No separate optimizer, scheduler or criterion is created here, and the
# `train_loader` built in the data cell is reused.
# NOTE(review): `lr` and `epochs` below differ from `learning_rate` and
# `num_epochs` in the configuration cell; confirm which values are intended.
config = {
    'lr': 1e-4,
    'weight_decay': 1e-5,
    'epochs': 20,
    'save_dir': './checkpoints'
}

trainer = Trainer(
    model=resnet50_model,  # or denseNet121_model
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=config
)

trainer.train()

# Save the training curves to <save_dir>/training_history.png.
trainer.plot_history()


Epoch 1/20


Training:   0%|          | 4/2197 [00:29<4:09:54,  6.84s/it, loss=8.91, plant_acc=0.133, status_acc=0.0156]