In [1]:
# download dataset only neeeded for the first time
# ! chmod +x download.sh
# ! ./download.sh

In [42]:
import os
from pathlib import Path
import shutil
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn as nn
from collections import defaultdict
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from tqdm import tqdm
import os
from plant_dataset import PlantDataset
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_fscore_support
import seaborn as sns
import pandas as pd

In [2]:
dataset_path = Path("dataset")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps" if torch.backends.mps.is_available() else device)

batch_size = 32
num_epochs = 25
learning_rate = 0.001
early_stopping_patience = 3

In [3]:
def fix_folder_name(path):
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)
        if os.path.isdir(folder_path):
            if ",_" in folder:
                new_folder_name = folder.replace(",", "_")
                new_folder_path = os.path.join(path, new_folder_name)
                os.rename(folder_path, new_folder_path)
                folder_path = new_folder_path

            if folder[0] == "_":
                new_folder_name = folder[1:]
                new_folder_path = os.path.join(path, new_folder_name)
                os.rename(folder_path, new_folder_path)
                folder_path = new_folder_path

            fix_folder_name(folder_path)

def optimize_folder_structure(path):
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)

        if os.path.isdir(folder_path) and "___" in folder:
            plant, status = folder.split("___", 1)
            plant_folder = os.path.join(path, plant)
            os.makedirs(plant_folder, exist_ok=True)

            status_folder = os.path.join(plant_folder, folder)

            if os.path.exists(status_folder):
                for item in os.listdir(folder_path):
                    src = os.path.join(folder_path, item)
                    dst = os.path.join(status_folder, item)
                    if not os.path.exists(dst):
                        shutil.move(src, dst)
                os.rmdir(folder_path)
            else:
                shutil.move(folder_path, status_folder)

            print(f"Reorganized: {folder} -> {plant}/{status_folder}")


# fix folder name
fix_folder_name(dataset_path)

# get plant classes
train_path = os.path.join(dataset_path, "train")
valid_path = os.path.join(dataset_path, "valid")

optimize_folder_structure(train_path)
optimize_folder_structure(valid_path)

In [17]:
# 数据转换定义
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 创建数据集
train_dataset = PlantDataset(
    root_dir='dataset', transform=data_transforms['train'], mode='train')
val_dataset = PlantDataset(
    root_dir='dataset', transform=data_transforms['val'], mode='valid')

# 打印类别信息
print("植物种类:", train_dataset.plant_classes)
print("状态类别:", train_dataset.status_classes)
print("\n植物到状态的映射:")
for plant, statuses in train_dataset.plant_to_status.items():
    print(f"{plant}: {statuses}")

# 创建DataLoader
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4)

植物种类: ['Apple', 'Blueberry', 'Cherry_(including_sour)', 'Corn_(maize)', 'Grape', 'Orange', 'Peach', 'Pepper__bell', 'Potato', 'Raspberry', 'Soybean', 'Squash', 'Strawberry', 'Tomato']
状态类别: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato

In [9]:
class CustomResNet50(nn.Module):
    def __init__(self, num_plants, num_statuses):
        super().__init__()
        self.backbone = models.resnet50(
            weights=models.ResNet50_Weights.DEFAULT)
        num_ftrs = self.backbone.fc.in_features
        # 修改最后一层
        self.backbone.fc = nn.Identity()
        self.dropout = nn.Dropout(p=0.5)
        
        # 双输出头
        self.plant_head = nn.Linear(num_ftrs, num_plants)
        self.status_head = nn.Linear(num_ftrs, num_statuses)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        plant_out = self.plant_head(features)
        status_out = self.status_head(features)
        return plant_out, status_out

    def compute_loss(self, plant_scores, status_scores, plant_y, status_y):
        # 计算植物分类的损失
        plant_loss = F.cross_entropy(plant_scores, plant_y)

        # 计算状态分类的损失
        status_loss = F.cross_entropy(status_scores, status_y)

        # 返回总损失
        return plant_loss * 2 + status_loss


class CustomDenseNet121(nn.Module):
    def __init__(self, num_plants, num_statuses):
        super().__init__()
        self.backbone = models.densenet121(
            weights=models.DenseNet121_Weights.DEFAULT)
        num_ftrs = self.backbone.classifier.in_features

        # 修改最后一层
        self.backbone.classifier = nn.Identity()
        self.dropout = nn.Dropout(p=0.5)

        # 双输出头
        self.plant_head = nn.Linear(num_ftrs, num_plants)
        self.status_head = nn.Linear(num_ftrs, num_statuses)

    def forward(self, x):
        features = self.backbone(x)
        features = self.dropout(features)
        plant_out = self.plant_head(features)
        status_out = self.status_head(features)
        return plant_out, status_out

    def compute_loss(self, plant_scores, status_scores, plant_y, status_y):
        # 计算植物分类的损失
        plant_loss = F.cross_entropy(plant_scores, plant_y)

        # 计算状态分类的损失
        status_loss = F.cross_entropy(status_scores, status_y)

        # 返回总损失
        return plant_loss * 2 + status_loss

# 初始化模型
resnet50_model = CustomResNet50(
    num_plants=len(train_dataset.plant_classes),
    num_statuses=len(train_dataset.status_classes)
)
denseNet121_model = CustomDenseNet121(
    num_plants=len(train_dataset.plant_classes),
    num_statuses=len(train_dataset.status_classes)
)

In [40]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, device, config):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config

        # 初始化优化器和学习率调度器
        self.optimizer = Adam(
            model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay']
        )
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='max',    # 监控指标越大越好
            patience=3,    # 3个epoch无改善则降低LR
            factor=0.5     # LR衰减因子
        )

        # 记录训练历史
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'plant_acc': [],
            'status_acc': [],
            'lr': []
        }

        # 创建保存目录
        os.makedirs(config['save_dir'], exist_ok=True)

    def train(self):
        best_combined_acc = 0.0
        self.model.train()
        early_stopping_counter = 0

        for epoch in range(self.config['epochs']):
            print(f"\nEpoch {epoch+1}/{self.config['epochs']}")

            # 训练阶段
            train_loss, train_plant_acc, train_status_acc = self.train_epoch()

            # 验证阶段
            val_loss, val_plant_acc, val_status_acc = self.validate()

            # 记录当前学习率
            current_lr = self.optimizer.param_groups[0]['lr']
            self.history['lr'].append(current_lr)

            # 打印训练信息
            print(f"Train Loss: {train_loss:.4f} | "
                  f"Plant Acc: {train_plant_acc:.4f} | "
                  f"Status Acc: {train_status_acc:.4f} | "
                  f"LR: {current_lr:.2e}")

            print(f"Val Loss: {val_loss:.4f} | "
                  f"Plant Acc: {val_plant_acc:.4f} | "
                  f"Status Acc: {val_status_acc:.4f}")

            # 学习率调度（基于植物和病害分类的平均准确率）
            combined_acc = (val_plant_acc + val_status_acc) / 2
            self.scheduler.step(combined_acc)

            # 保存最佳模型
            if combined_acc > best_combined_acc:
                early_stopping_counter = 0
                best_combined_acc = combined_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_acc': best_combined_acc,
                }, os.path.join(self.config['save_dir'], 'best_model.pth'))
                print(
                    f"Saved new best model with combined acc: {best_combined_acc:.4f}")
            else:
                early_stopping_counter += 1

            # 提前停止
            if early_stopping_counter >= early_stopping_patience:
                print(
                    f"Early stopping triggered after {early_stopping_patience} epochs without improvement.")
                break
        print("Training complete.")
        # 加载最佳模型
        checkpoint = torch.load(os.path.join(
            self.config['save_dir'], 'best_model.pth'))
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])

        # 绘制训练曲线
    def train_epoch(self):
        self.model.train()
        total_loss = 0
        plant_correct = 0
        status_correct = 0
        total_samples = 0

        progress_bar = tqdm(self.train_loader, desc='Training')
        for images, plant_labels, status_labels in progress_bar:
            images = images.to(self.device)
            plant_labels = plant_labels.to(self.device)
            status_labels = status_labels.to(self.device)

            # 前向传播
            self.optimizer.zero_grad()
            plant_scores, status_scores = self.model(images)

            # 计算损失
            loss = self.model.compute_loss(
                plant_scores, status_scores, plant_labels, status_labels)

            # 反向传播
            loss.backward()
            self.optimizer.step()

            # 统计指标
            with torch.no_grad():
                total_loss += loss.item() * images.size(0)
                plant_pred = plant_scores.argmax(dim=1)
                status_pred = status_scores.argmax(dim=1)
                plant_correct += (plant_pred == plant_labels).sum().item()
                status_correct += (status_pred == status_labels).sum().item()
                total_samples += images.size(0)

            # 更新进度条
            progress_bar.set_postfix({
                'loss': total_loss / total_samples,
                'plant_acc': plant_correct / total_samples,
                'status_acc': status_correct / total_samples
            })

        avg_loss = total_loss / total_samples
        plant_acc = plant_correct / total_samples
        status_acc = status_correct / total_samples

        self.history['train_loss'].append(avg_loss)
        return avg_loss, plant_acc, status_acc

    def validate(self):
        self.model.eval()
        total_loss = 0
        plant_correct = 0
        status_correct = 0
        total_samples = 0
        progress_bar = tqdm(self.val_loader, desc='Validating')
        for images, plant_labels, status_labels in progress_bar:
            images = images.to(self.device)
            plant_labels = plant_labels.to(self.device)
            status_labels = status_labels.to(self.device)

            plant_scores, status_scores = self.model(images)
            loss = self.model.compute_loss(
                plant_scores, status_scores, plant_labels, status_labels)

            with torch.no_grad():
                total_loss += loss.item() * images.size(0)
                plant_pred = plant_scores.argmax(dim=1)
                status_pred = status_scores.argmax(dim=1)
                plant_correct += (plant_pred == plant_labels).sum().item()
                status_correct += (status_pred == status_labels).sum().item()
                total_samples += images.size(0)

            # 更新进度条
            progress_bar.set_postfix({
                'loss': total_loss / total_samples,
                'plant_acc': plant_correct / total_samples,
                'status_acc': status_correct / total_samples
            })
        avg_loss = total_loss / total_samples
        plant_acc = plant_correct / total_samples
        status_acc = status_correct / total_samples
        self.history['val_loss'].append(avg_loss)
        self.history['plant_acc'].append(plant_acc)
        self.history['status_acc'].append(status_acc)
        return avg_loss, plant_acc, status_acc

    def plot_history(self):
        plt.figure(figsize=(12, 5))

        # 损失曲线
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='Train Loss')
        plt.plot(self.history['val_loss'], label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # 准确率曲线
        plt.subplot(1, 2, 2)
        plt.plot(self.history['plant_acc'], label='Plant Acc')
        plt.plot(self.history['status_acc'], label='Status Acc')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(
            self.config['save_dir'], 'training_history.png'))
        plt.close()

    def evaluate(self, data_loader=None, plant_classes=None, status_classes=None, verbose=True):
        if data_loader is None:
            data_loader = self.val_loader

        self.model.eval()
        total_loss = 0
        plant_correct = 0
        status_correct = 0
        total_samples = 0

        # 用于计算指标
        all_plant_labels = []
        all_plant_preds = []
        all_status_labels = []
        all_status_preds = []

        with torch.no_grad():
            for images, plant_labels, status_labels in tqdm(data_loader, desc='Evaluating'):
                images = images.to(self.device)
                plant_labels = plant_labels.to(self.device)
                status_labels = status_labels.to(self.device)

                plant_scores, status_scores = self.model(images)
                loss = self.model.compute_loss(
                    plant_scores, status_scores, plant_labels, status_labels)

                total_loss += loss.item() * images.size(0)
                plant_pred = plant_scores.argmax(dim=1)
                status_pred = status_scores.argmax(dim=1)

                plant_correct += (plant_pred == plant_labels).sum().item()
                status_correct += (status_pred == status_labels).sum().item()
                total_samples += images.size(0)

                # 收集所有预测和标签
                all_plant_labels.extend(plant_labels.cpu().numpy())
                all_plant_preds.extend(plant_pred.cpu().numpy())
                all_status_labels.extend(status_labels.cpu().numpy())
                all_status_preds.extend(status_pred.cpu().numpy())

        # 计算指标
        avg_loss = total_loss / total_samples
        plant_acc = plant_correct / total_samples
        status_acc = status_correct / total_samples

        # 计算分类报告
        plant_report = classification_report(all_plant_labels, all_plant_preds,
                                             target_names=plant_classes,
                                             zero_division=0)
        status_report = classification_report(all_status_labels, all_status_preds,
                                              target_names=status_classes,
                                              zero_division=0)

        # 计算混淆矩阵
        plant_cm = confusion_matrix(all_plant_labels, all_plant_preds)
        status_cm = confusion_matrix(all_status_labels, all_status_preds)
        
        plant_precision, plant_recall, plant_f1, _ = precision_recall_fscore_support(
            all_plant_labels, all_plant_preds, average='weighted', zero_division=0)
        
        status_precision, status_recall, status_f1, _ = precision_recall_fscore_support(
            all_status_labels, all_status_preds, average='weighted', zero_division=0)
        print(f"Plant Precision: {plant_precision:.4f}, Recall: {plant_recall:.4f}, F1: {plant_f1:.4f}")
        print(f"Status Precision: {status_precision:.4f}, Recall: {status_recall:.4f}, F1: {status_f1:.4f}")

        if verbose:
            # print("\n" + "="*50)
            # print("Plant Classification Metrics:")
            # print(f"Accuracy: {plant_acc*100:.2f}%")
            # print(plant_report)
            self._plot_confusion_matrix(
                plant_cm, plant_classes, "Plant Confusion Matrix")

            # print("\n" + "="*50)
            # print("Status Classification Metrics:")
            # print(f"Accuracy: {status_acc*100:.2f}%")
            # print(status_report)
            self._plot_confusion_matrix(
                status_cm, status_classes, "Status Confusion Matrix")

            # print("\n" + "="*50)
            # print(f"Average Loss: {avg_loss:.4f}")
            # print("="*50 + "\n")

        return {
            'plant_accuracy': plant_acc,
            'status_accuracy': status_acc,
            'plant_report': plant_report,
            'status_report': status_report,
            'plant_cm': plant_cm,
            'status_cm': status_cm,
            'avg_loss': avg_loss
        }

    def _plot_confusion_matrix(self, cm, classes, title):
        """绘制混淆矩阵"""
        plt.figure(figsize=(12, 10))
        df_cm = pd.DataFrame(cm, index=classes, columns=classes)
        sns.heatmap(df_cm, annot=True, fmt='d', cmap='Blues')
        plt.title(title)
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()

        # 保存图像
        save_path = os.path.join(
            self.config['save_dir'], f"{title.lower().replace(' ', '_')}.png")
        plt.savefig(save_path)
        plt.close()
        print(f"Confusion matrix saved to: {save_path}")


# 配置参数
config = {
    'lr': 1e-4,
    'weight_decay': 1e-5,
    'epochs': 20,
    'save_dir': './resnet50_model'
}

# 初始化训练器
trainer = Trainer(
    model=resnet50_model,  # 或 denseNet121_model
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=config
)

# 开始训练
# trainer.train()

# 绘制训练曲线
# trainer.plot_history()

In [None]:
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

denseNet121_model = denseNet121_model.to(device)

# 配置参数
resnet50_config = {
    'lr': 1e-4,
    'weight_decay': 1e-5,
    'epochs': num_epochs,
    'save_dir': './resnet50_model'
}

# 初始化训练器
trainer = Trainer(
    model=resnet50_model,  # 或 denseNet121_model
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=resnet50_config
)

# 开始训练
trainer.train()

# 绘制训练曲线
trainer.plot_history()

In [None]:
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

denseNet121_model = denseNet121_model.to(device)

# 配置参数
denseNet121_config = {
    'lr': 1e-4,
    'weight_decay': 1e-5,
    'epochs': num_epochs,
    'save_dir': './denseNet121_model'
}

# 初始化训练器
trainer = Trainer(
    model=denseNet121_model,  # 或 denseNet121_model
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=denseNet121_config
)

# 开始训练
trainer.train()

# 绘制训练曲线

trainer.plot_history()

In [43]:
# evaluate resnet50
model = CustomResNet50(
    num_plants=len(train_dataset.plant_classes),
    num_statuses=len(train_dataset.status_classes)
)
model.load_state_dict(torch.load(os.path.join(
    resnet50_config['save_dir'], 'best_model.pth'))['model_state_dict'])
model.to(device)
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=resnet50_config
)

result = trainer.evaluate(
    data_loader=val_loader,
    plant_classes=train_dataset.plant_classes,
    status_classes=train_dataset.status_classes
)

print(result)

Evaluating: 100%|██████████| 550/550 [01:58<00:00,  4.63it/s]


Plant Precision: 0.9995, Recall: 0.9995, F1: 0.9995
Status Precision: 0.9980, Recall: 0.9980, F1: 0.9980
Confusion matrix saved to: ./resnet50_model/plant_confusion_matrix.png
Confusion matrix saved to: ./resnet50_model/status_confusion_matrix.png
{'plant_accuracy': 0.9994878215342591, 'status_accuracy': 0.9979512861370362, 'plant_report': '                         precision    recall  f1-score   support\n\n                  Apple       1.00      1.00      1.00      1943\n              Blueberry       1.00      1.00      1.00       454\nCherry_(including_sour)       1.00      1.00      1.00       877\n           Corn_(maize)       1.00      1.00      1.00      1829\n                  Grape       1.00      1.00      1.00      1805\n                 Orange       1.00      1.00      1.00       503\n                  Peach       1.00      1.00      1.00       891\n           Pepper__bell       1.00      1.00      1.00       975\n                 Potato       1.00      1.00      1.00      1

In [44]:
# evaluate densenet121
model = CustomDenseNet121(
    num_plants=len(train_dataset.plant_classes),
    num_statuses=len(train_dataset.status_classes)
)
model.load_state_dict(torch.load(os.path.join(
    denseNet121_config['save_dir'], 'best_model.pth'))['model_state_dict'])
model.to(device)
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=denseNet121_config
)

result = trainer.evaluate(
    data_loader=val_loader,
    plant_classes=train_dataset.plant_classes,
    status_classes=train_dataset.status_classes
)

print(result)

Evaluating: 100%|██████████| 550/550 [01:51<00:00,  4.92it/s]


Plant Precision: 0.9997, Recall: 0.9997, F1: 0.9997
Status Precision: 0.9965, Recall: 0.9965, F1: 0.9965
Confusion matrix saved to: ./denseNet121_model/plant_confusion_matrix.png
Confusion matrix saved to: ./denseNet121_model/status_confusion_matrix.png
{'plant_accuracy': 0.9997154564079217, 'status_accuracy': 0.996471659458229, 'plant_report': '                         precision    recall  f1-score   support\n\n                  Apple       1.00      1.00      1.00      1943\n              Blueberry       1.00      1.00      1.00       454\nCherry_(including_sour)       1.00      1.00      1.00       877\n           Corn_(maize)       1.00      1.00      1.00      1829\n                  Grape       1.00      1.00      1.00      1805\n                 Orange       1.00      1.00      1.00       503\n                  Peach       1.00      1.00      1.00       891\n           Pepper__bell       1.00      1.00      1.00       975\n                 Potato       1.00      1.00      1.00  