In [66]:
# download dataset only neeeded for the first time
# ! chmod +x download.sh
# ! ./download.sh

In [None]:
import os
from pathlib import Path
import shutil
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn as nn
from collections import defaultdict
from PIL import Image
import numpy as np
import torch

In [68]:
dataset_path = Path("dataset")

In [69]:
def fix_folder_name(path):
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)
        if os.path.isdir(folder_path):
            if ",_" in folder:
                new_folder_name = folder.replace(",", "_")
                new_folder_path = os.path.join(path, new_folder_name)
                os.rename(folder_path, new_folder_path)
                folder_path = new_folder_path

            if folder[0] == "_":
                new_folder_name = folder[1:]
                new_folder_path = os.path.join(path, new_folder_name)
                os.rename(folder_path, new_folder_path)
                folder_path = new_folder_path

            fix_folder_name(folder_path)

def optimize_folder_structure(path):
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)

        if os.path.isdir(folder_path) and "___" in folder:
            plant, status = folder.split("___", 1)
            plant_folder = os.path.join(path, plant)
            os.makedirs(plant_folder, exist_ok=True)

            status_folder = os.path.join(plant_folder, folder)

            if os.path.exists(status_folder):
                for item in os.listdir(folder_path):
                    src = os.path.join(folder_path, item)
                    dst = os.path.join(status_folder, item)
                    if not os.path.exists(dst):
                        shutil.move(src, dst)
                os.rmdir(folder_path)
            else:
                shutil.move(folder_path, status_folder)

            print(f"Reorganized: {folder} -> {plant}/{status_folder}")


# fix folder name
fix_folder_name(dataset_path)

# get plant classes
train_path = os.path.join(dataset_path, "train")
valid_path = os.path.join(dataset_path, "valid")

optimize_folder_structure(train_path)
optimize_folder_structure(valid_path)

In [None]:
class PlantDataset(Dataset):
    def __init__(self, root_dir, transform=None, mode='train'):
        self.root_dir = os.path.join(root_dir, mode)
        self.transform = transform
        self.plant_to_status = defaultdict(list)
        self.status_to_plant = {}
        self.samples = []

        # 验证并构建层级关系
        self._validate_structure()

        # 获取类别信息
        self.plant_classes = sorted(self.plant_to_status.keys())
        self.status_classes = sorted(self.status_to_plant.keys())

        # 创建映射字典
        self.plant_to_idx = {plant: idx for idx,
                             plant in enumerate(self.plant_classes)}
        self.status_to_idx = {status: idx for idx,
                              status in enumerate(self.status_classes)}

    def _validate_structure(self):
        """验证文件夹结构并建立植物-状态映射关系"""
        for plant in os.listdir(self.root_dir):
            plant_path = os.path.join(self.root_dir, plant)
            if not os.path.isdir(plant_path):
                continue

            for status in os.listdir(plant_path):
                status_path = os.path.join(plant_path, status)
                if not os.path.isdir(status_path):
                    continue

                # 验证状态是否属于该植物的有效状态
                if plant in self.plant_to_status and status in self.plant_to_status[plant]:
                    pass  # 已经存在的关系
                else:
                    self.plant_to_status[plant].append(status)
                    self.status_to_plant[status] = plant

                # 添加样本
                for img_name in os.listdir(status_path):
                    if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                        img_path = os.path.join(status_path, img_name)
                        self.samples.append((img_path, plant, status))

    def is_valid_status(self, plant, status):
        return plant in status

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, plant, status = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # 返回图像、植物类别索引、状态类别索引
        return image, self.plant_to_idx[plant], self.status_to_idx[status]


# 数据转换定义
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 创建数据集
train_dataset = PlantDataset(
    root_dir='dataset', transform=data_transforms['train'], mode='train')
val_dataset = PlantDataset(
    root_dir='dataset', transform=data_transforms['val'], mode='valid')

# 打印类别信息
print("植物种类:", train_dataset.plant_classes)
print("状态类别:", train_dataset.status_classes)
print("\n植物到状态的映射:")
for plant, statuses in train_dataset.plant_to_status.items():
    print(f"{plant}: {statuses}")

# 创建DataLoader
batch_size = 32
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4)

植物种类: ['Apple', 'Blueberry', 'Cherry_(including_sour)', 'Corn_(maize)', 'Grape', 'Orange', 'Peach', 'Pepper__bell', 'Potato', 'Raspberry', 'Soybean', 'Squash', 'Strawberry', 'Tomato']
状态类别: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato

In [None]:
def get_resnet50_model():
    # Load pre-trained ResNet50 model
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    return model


def get_densenet121_model():
    # Load pre-trained DenseNet121 model
    model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
    return model


class ResNet50_With_Dual_SVM():
    def __init__(self, num_plants, num_statuses):
        super().__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        num_ftrs = self.backbone.fc.in_features
        
        # 使用SVM替换最后的全连接层
        
        
        # 双输出头
        self.plant_head = nn.Linear(num_ftrs, num_plants)
        self.status_head = nn.Linear(num_ftrs, num_statuses)
        # SVM层
        self.plant_svm = nn.Linear(num_plants, num_plants)
        self.status_svm = nn.Linear(num_statuses, num_statuses)


class DualHeadModel(nn.Module):
    def __init__(self, model, num_plants, num_statuses):
        super().__init__()
        self.backbone = model
        if hasattr(self.backbone, 'fc'):
            num_ftrs = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        else:
            num_ftrs = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()

        # 双输出头
        self.plant_head = nn.Linear(num_ftrs, num_plants)
        self.status_head = nn.Linear(num_ftrs, num_statuses)

    def forward(self, x):
        features = self.backbone(x)
        plant_out = self.plant_head(features)
        status_out = self.status_head(features)
        return plant_out, status_out


# 初始化模型
resnet50_model = DualHeadModel(
    model=get_resnet50_model(),
    num_plants=len(train_dataset.plant_classes),
    num_statuses=len(train_dataset.status_classes)
)
denseNet121_model = DualHeadModel(
    model=get_densenet121_model(),
    num_plants=len(train_dataset.plant_classes),
    num_statuses=len(train_dataset.status_classes)
)

AttributeError: 'Sequential' object has no attribute 'in_features'