In [1]:
#数据清洗与标注
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import json
from collections import Counter
import shutil

In [2]:
# Configure matplotlib so Chinese labels render: use the SimHei font for
# sans-serif text and draw the minus sign as plain ASCII instead of U+2212.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

In [None]:
# Dataset exploration
def explore_dataset(data_path):
    """Summarise a class-per-folder image dataset and print the per-class counts.

    Every non-hidden sub-directory of ``data_path`` is treated as one class.
    Inside each class folder, entries whose name ends in ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) are counted as images.

    Args:
        data_path: Root directory that holds one sub-directory per class.

    Returns:
        tuple: ``(dataset_info, class_dirs)`` where ``dataset_info`` maps each
        class name to its image count and ``class_dirs`` is the alphabetically
        sorted list of class names.
    """
    image_exts = ('.jpg', '.jpeg', '.png')

    # os.listdir() returns entries in arbitrary order, so sort to make the class
    # order reproducible between runs and machines. Hidden folders (for example
    # the ``.ipynb_checkpoints`` directory Jupyter creates) are not classes.
    class_dirs = sorted(
        d for d in os.listdir(data_path)
        if not d.startswith('.') and os.path.isdir(os.path.join(data_path, d))
    )

    dataset_info = {}

    print("=== 数据集概览 ===")
    for class_name in class_dirs:
        class_path = os.path.join(data_path, class_name)
        image_count = sum(
            1 for f in os.listdir(class_path) if f.lower().endswith(image_exts)
        )
        dataset_info[class_name] = image_count
        print(f"{class_name}: {image_count} 张图片")

    total_images = sum(dataset_info.values())
    print(f"\n总计: {total_images} 张图片，{len(class_dirs)} 个类别")
    return dataset_info, class_dirs

# Explore the raw dataset; the returned class list is reused by later cells
dataset_info, class_names = explore_dataset("data/raw/melon17_full")

In [None]:
# Dataset exploration
# NOTE(review): this cell repeats the explore_dataset definition from the
# preceding cell. Being the later definition, it is the one in effect after
# "Run All"; consider deleting one of the two cells.
def explore_dataset(data_path):
    """Summarise a class-per-folder image dataset and print the per-class counts.

    Every non-hidden sub-directory of ``data_path`` is treated as one class.
    Inside each class folder, entries whose name ends in ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) are counted as images.

    Args:
        data_path: Root directory that holds one sub-directory per class.

    Returns:
        tuple: ``(dataset_info, class_dirs)`` where ``dataset_info`` maps each
        class name to its image count and ``class_dirs`` is the alphabetically
        sorted list of class names.
    """
    image_exts = ('.jpg', '.jpeg', '.png')

    # os.listdir() returns entries in arbitrary order, so sort to make the class
    # order reproducible between runs and machines. Hidden folders (for example
    # the ``.ipynb_checkpoints`` directory Jupyter creates) are not classes.
    class_dirs = sorted(
        d for d in os.listdir(data_path)
        if not d.startswith('.') and os.path.isdir(os.path.join(data_path, d))
    )

    dataset_info = {}

    print("=== 数据集概览 ===")
    for class_name in class_dirs:
        class_path = os.path.join(data_path, class_name)
        image_count = sum(
            1 for f in os.listdir(class_path) if f.lower().endswith(image_exts)
        )
        dataset_info[class_name] = image_count
        print(f"{class_name}: {image_count} 张图片")

    total_images = sum(dataset_info.values())
    print(f"\n总计: {total_images} 张图片，{len(class_dirs)} 个类别")
    return dataset_info, class_dirs

# Explore the raw dataset
# NOTE(review): same call as in the preceding cell; running it again rescans
# the directory and rebinds dataset_info / class_names.
dataset_info, class_names = explore_dataset("data/raw/melon17_full")

In [5]:
# Data quality check
def check_image_quality(data_path, class_names):
    """Scan every class folder and report which images can be decoded.

    Only files ending in .jpg/.jpeg/.png (case-insensitive) are inspected. A
    file counts as corrupted when ``cv2.imread`` returns ``None`` or when any
    exception is raised while handling it.

    Args:
        data_path: Root directory that holds one sub-directory per class.
        class_names: Names of the class sub-directories to scan.

    Returns:
        dict with the keys
        ``'corrupted_images'`` (list of paths that failed to load),
        ``'size_distribution'`` (list of ``(width, height)`` per readable image),
        ``'format_distribution'`` (Counter of lower-cased file extensions of
        readable images) and
        ``'class_quality'`` (per class ``{'valid': n, 'corrupted': m}``).
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')

    broken_paths = []
    dimensions = []
    suffix_counts = Counter()
    per_class = {}

    for class_name in class_names:
        folder = os.path.join(data_path, class_name)
        n_ok = 0
        n_bad = 0

        for entry in os.listdir(folder):
            if not entry.lower().endswith(image_suffixes):
                continue

            full_path = os.path.join(folder, entry)
            try:
                pixels = cv2.imread(full_path)
                if pixels is not None:
                    # shape is (rows, cols, ...), i.e. height first
                    height, width = pixels.shape[:2]
                    dimensions.append((width, height))
                    suffix_counts[os.path.splitext(entry)[1].lower()] += 1
                    n_ok += 1
                else:
                    broken_paths.append(full_path)
                    n_bad += 1
            except Exception:
                # Any failure while reading is treated the same as an
                # undecodable file.
                broken_paths.append(full_path)
                n_bad += 1

        per_class[class_name] = {'valid': n_ok, 'corrupted': n_bad}

    return {
        'corrupted_images': broken_paths,
        'size_distribution': dimensions,
        'format_distribution': suffix_counts,
        'class_quality': per_class,
    }

In [None]:
# Run the quality check on the raw dataset and print a short report
print("正在检查图像质量...")
quality_report = check_image_quality("data/raw/melon17_full", class_names)

corrupted_total = len(quality_report['corrupted_images'])
format_counts = dict(quality_report['format_distribution'])

print("\n=== 质量检查报告 ===")
print(f"损坏图像数量: {corrupted_total}")
print(f"格式分布: {format_counts}")

# Size statistics only make sense when at least one image could be read.
if quality_report['size_distribution']:
    # One row per readable image: column 0 is width, column 1 is height
    sizes = np.array(quality_report['size_distribution'])
    mean_size = sizes.mean(axis=0).astype(int)
    min_size = sizes.min(axis=0)
    max_size = sizes.max(axis=0)
    print("图像尺寸统计:")
    print(f"  平均尺寸: {mean_size}")
    print(f"  最小尺寸: {min_size}")
    print(f"  最大尺寸: {max_size}")

In [None]:
# Data cleaning: drop unreadable images, unify size and file format
def _unique_jpg_name(stem, taken):
    """Return ``<stem>.jpg``, appending ``_<n>`` while the name is already in ``taken``.

    ``taken`` holds lower-cased names and is updated in place with the result.
    """
    candidate = f"{stem}.jpg"
    counter = 1
    while candidate.lower() in taken:
        candidate = f"{stem}_{counter}.jpg"
        counter += 1
    taken.add(candidate.lower())
    return candidate


def clean_dataset(data_path, class_names, output_path, target_size=(224, 224)):
    """Copy a class-per-folder dataset to ``output_path`` as uniformly sized JPEGs.

    For every .jpg/.jpeg/.png file (case-insensitive) in each class folder the
    image is read with OpenCV, resized to ``target_size`` when its dimensions
    differ, and written as a quality-95 JPEG into the matching class folder
    under ``output_path``. Files that cannot be read or written are skipped.

    Args:
        data_path: Root directory of the raw dataset (one folder per class).
        class_names: Names of the class sub-directories to process.
        output_path: Root directory for the cleaned dataset; created if missing.
        target_size: Output size as ``(width, height)``, the order that
            ``cv2.resize`` expects for its ``dsize`` argument.

    Returns:
        dict: ``{'processed': n, 'skipped': n, 'resized': n}`` where
        ``processed`` counts images written successfully, ``resized`` counts
        those of them that had to be resized, and ``skipped`` counts images
        that could not be read or written.
    """
    target_w, target_h = (int(v) for v in target_size)

    # Create the output root for the cleaned data
    os.makedirs(output_path, exist_ok=True)

    cleaning_stats = {
        'processed': 0,
        'skipped': 0,
        'resized': 0
    }

    for class_name in class_names:
        class_input_path = os.path.join(data_path, class_name)
        class_output_path = os.path.join(output_path, class_name)
        os.makedirs(class_output_path, exist_ok=True)

        print(f"正在处理类别: {class_name}")

        # Output names already used in this class. Inputs such as "x.png" and
        # "x.jpg" would otherwise both be written to "x.jpg" and overwrite each
        # other. Sorting the listing makes the disambiguation deterministic.
        taken_names = set()

        for img_file in sorted(os.listdir(class_input_path)):
            if not img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue

            input_img_path = os.path.join(class_input_path, img_file)

            try:
                img = cv2.imread(input_img_path)
                if img is None:
                    # The file could not be decoded
                    cleaning_stats['skipped'] += 1
                    continue

                # img.shape is (height, width, ...) while target_size is
                # (width, height); compare like with like.
                height, width = img.shape[:2]
                needs_resize = (width, height) != (target_w, target_h)
                if needs_resize:
                    img = cv2.resize(img, (target_w, target_h))

                # Always store the result as .jpg
                output_filename = _unique_jpg_name(
                    os.path.splitext(img_file)[0], taken_names
                )
                output_img_path = os.path.join(class_output_path, output_filename)

                # cv2.imwrite signals failure through its return value rather
                # than an exception; route it through the handler below.
                if not cv2.imwrite(output_img_path, img, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                    raise OSError(f"cv2.imwrite failed for {output_img_path}")

                cleaning_stats['processed'] += 1
                if needs_resize:
                    cleaning_stats['resized'] += 1

            except Exception as e:
                print(f"处理 {input_img_path} 时出错: {e}")
                cleaning_stats['skipped'] += 1

    return cleaning_stats

# Run the cleaning step: raw images in, uniformly sized JPEGs out
print("开始数据清洗...")
cleaning_stats = clean_dataset(
    "data/raw/melon17_full",
    class_names,
    "data/processed/melon17_clean",
)

# Report the counters returned by clean_dataset
processed_total = cleaning_stats['processed']
skipped_total = cleaning_stats['skipped']
resized_total = cleaning_stats['resized']

print("\n=== 数据清洗报告 ===")
print(f"成功处理: {processed_total} 张图片")
print(f"跳过损坏: {skipped_total} 张图片")
print(f"调整尺寸: {resized_total} 张图片")

In [None]:

# Build the annotation files
def create_annotations(data_path, class_names, output_dir='data/annotations'):
    """Index the ``.jpg`` images of a class-per-folder dataset and save the index.

    Class ids are assigned by alphabetical order of ``class_names``. For every
    file ending in ``.jpg`` (case-insensitive) one record with its path relative
    to ``data_path``, its class name and its class id is produced. The result is
    written to ``annotations.json`` and ``annotations.csv`` inside
    ``output_dir``.

    Args:
        data_path: Root directory of the dataset (one folder per class).
        class_names: Names of the class sub-directories to index.
        output_dir: Directory that receives the two annotation files. It is
            created when it does not exist yet.

    Returns:
        dict with the keys ``'class_to_idx'``, ``'idx_to_class'``,
        ``'annotations'`` (list of record dicts), ``'num_classes'`` and
        ``'total_images'``. Note that JSON only supports string keys, so the
        integer keys of ``idx_to_class`` become strings in the saved JSON file.
    """
    # Class-name <-> id mappings
    class_to_idx = {class_name: idx for idx, class_name in enumerate(sorted(class_names))}
    idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()}

    # Collect one record per image
    annotations = []

    for class_name in class_names:
        class_path = os.path.join(data_path, class_name)
        class_idx = class_to_idx[class_name]

        # os.listdir() order is arbitrary; sort so the saved files are
        # reproducible between runs.
        for img_file in sorted(os.listdir(class_path)):
            if not img_file.lower().endswith('.jpg'):
                continue

            img_path = os.path.join(class_path, img_file)
            annotations.append({
                'image_path': os.path.relpath(img_path, data_path),
                'class_name': class_name,
                'class_id': class_idx
            })

    annotations_data = {
        'class_to_idx': class_to_idx,
        'idx_to_class': idx_to_class,
        'annotations': annotations,
        'num_classes': len(class_names),
        'total_images': len(annotations)
    }

    # On a fresh checkout the output directory does not exist yet; without this
    # the open() below fails with FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    # Save as JSON
    with open(os.path.join(output_dir, 'annotations.json'), 'w', encoding='utf-8') as f:
        json.dump(annotations_data, f, ensure_ascii=False, indent=2)

    # Save as CSV; explicit columns keep the header even when no image was found
    df = pd.DataFrame(annotations, columns=['image_path', 'class_name', 'class_id'])
    df.to_csv(os.path.join(output_dir, 'annotations.csv'), index=False, encoding='utf-8')

    return annotations_data

# Build the annotation files for the cleaned dataset
annotations_data = create_annotations(
    "data/processed/melon17_clean",
    class_names,
)

# Echo the key figures of the generated index
num_classes_found = annotations_data['num_classes']
num_images_found = annotations_data['total_images']
class_mapping = annotations_data['class_to_idx']

print("=== 标注创建完成 ===")
print(f"类别数量: {num_classes_found}")
print(f"图像总数: {num_images_found}")
print(f"类别映射: {class_mapping}")


In [9]:
# Dataset visualisation
def visualize_dataset(data_path, annotations_data, samples_per_class=3,
                      save_path='results/dataset_samples.png'):
    """Plot a grid of sample images, one row per class, then save and show it.

    For each class in ``annotations_data['class_to_idx']`` the first
    ``samples_per_class`` ``.jpg`` files (case-insensitive, alphabetical order)
    of the class folder are shown. Cells without an image, because the class
    has too few files or a file cannot be read, are left blank.

    Args:
        data_path: Root directory of the dataset (one folder per class).
        annotations_data: Dict with a ``'class_to_idx'`` mapping whose keys are
            the class names to plot.
        samples_per_class: Number of columns, i.e. images shown per class.
        save_path: Where the figure is written. Its parent directory is
            created when missing.
    """
    class_names = list(annotations_data['class_to_idx'].keys())
    num_classes = len(class_names)

    # squeeze=False always yields a 2-D array of axes. With the default,
    # a single row or column collapses to 1-D (or to a bare Axes for a 1x1
    # grid) and axes[i][j] stops working.
    fig, axes = plt.subplots(num_classes, samples_per_class,
                             figsize=(samples_per_class*3, num_classes*3),
                             squeeze=False)

    for i, class_name in enumerate(class_names):
        class_path = os.path.join(data_path, class_name)
        # Sorted so the same samples are shown on every run
        images = sorted(f for f in os.listdir(class_path) if f.lower().endswith('.jpg'))

        for j in range(samples_per_class):
            ax = axes[i][j]
            # Hide ticks and frame on every cell, including unused ones
            ax.axis('off')

            if j >= len(images):
                continue

            img = cv2.imread(os.path.join(class_path, images[j]))
            if img is None:
                # Unreadable file: leave the cell blank rather than letting
                # cvtColor fail on None.
                continue

            # OpenCV loads BGR, matplotlib expects RGB
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax.set_title(f"{class_name}\n{images[j]}")

    plt.tight_layout()

    # Make sure the target folder exists before saving
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


In [None]:
# Visualize samples of the cleaned dataset
visualize_dataset("data/processed/melon17_clean", annotations_data)