# 版本一：类别名称使用默认编号（Class_<编号>）

In [8]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
通用遥感影像监督分类系统
-------------------------------------------------
功能：
1. 自动读取多波段遥感影像；
2. 从矢量样本中提取训练/验证数据；
3. 支持随机森林 / SVM / XGBoost 分类；
4. 采用分块预测模式；
5. 输出分类结果 GeoTIFF；
6. 自动生成分类报告与混淆矩阵；
7. 显示分类影像和精度评价结果。
"""

import os
import time
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
from shapely.geometry import mapping
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from sklearn.inspection import permutation_importance
from tqdm import tqdm

plt.rcParams["font.sans-serif"] = ["SimHei", "DejaVu Sans"]  # CJK-capable font first so the Chinese labels in the figures render
plt.rcParams["axes.unicode_minus"] = False  # Render minus signs as ASCII '-' so they still display with the CJK font

# ------------------ Configuration ------------------
# Input image stack and the training / validation sample shapefiles.
IMAGE_PATH = r"D:\code313\Geo_programe\rasterio\RF\data\2017_09_05_stack.tif"
TRAIN_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\cal.shp"
VAL_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\val.shp"  # main() skips validation if this file does not exist
ATTRIBUTE = "class"  # Field whose values are burned into the label rasters; 0 is reserved for "no sample"
OUT_DIR = Path("./results_v2")  # Every output (rasters, figures, reports, log file) is written here

CLASSIFIER = "rf"  # One of: "rf", "svm", "xgb" (see get_classifier)
N_ESTIMATORS = 300  # n_estimators passed to the "rf" and "xgb" classifiers
BLOCK_SIZE = 512  # Rows per prediction block in predict_by_block
USE_GPU = False  # NOTE(review): not referenced by any code visible in this script

# Colours assigned, cycling, to the sorted class ids by get_class_info_from_shp
COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow', 
                'darkred', 'purple', 'orange', 'pink', 'brown', 
                'cyan', 'magenta', 'lime', 'navy', 'teal']

OUT_DIR.mkdir(exist_ok=True)  # Must exist before the log FileHandler below opens its file

# ------------------ Logging ------------------
# Log to a UTF-8 file inside OUT_DIR and to the console.
# NOTE(review): basicConfig does nothing if the root logger already has handlers
# (e.g. when a notebook cell is re-run in the same kernel).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(OUT_DIR / "classification_log.txt", encoding="utf-8"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# ------------------ 辅助函数 ------------------
def get_class_info_from_shp(shp_path, attribute):
    """Derive default display names and colours for every class in a shapefile.

    Parameters
    ----------
    shp_path : str or Path
        Sample shapefile to read.
    attribute : str
        Column holding the class id.

    Returns
    -------
    (class_names, class_colors, unique_classes)
        ``class_names`` maps id -> ``"Class_<id>"``; ``class_colors`` maps
        id -> a COLOR_PALETTE entry, cycling through the palette in sorted-id
        order; ``unique_classes`` is the sorted list of distinct ids.
    """
    unique_classes = sorted(gpd.read_file(shp_path)[attribute].unique())
    palette_len = len(COLOR_PALETTE)

    class_names = {}
    class_colors = {}
    for position, class_id in enumerate(unique_classes):
        class_names[class_id] = f'Class_{class_id}'
        class_colors[class_id] = COLOR_PALETTE[position % palette_len]

    return class_names, class_colors, unique_classes

def rasterize_samples(shp, ref_img, attr):
    """Burn vector samples into a label raster on the grid of ``ref_img``.

    The shapefile is reprojected to ``ref_img``'s CRS. Each geometry is burned
    with the value of its ``attr`` field, and every pixel a geometry touches
    (``all_touched=True``) receives that value. All other pixels stay 0.

    Returns a uint16 array of shape ``ref_img.shape[1:]``, i.e. (rows, cols)
    assuming ``ref_img`` is laid out as (band, y, x).
    """
    import rasterio.features

    samples = gpd.read_file(shp).to_crs(ref_img.rio.crs)
    geometry_value_pairs = zip(samples.geometry, samples[attr])

    return rasterio.features.rasterize(
        shapes=geometry_value_pairs,
        out_shape=ref_img.shape[1:],
        transform=ref_img.rio.transform(),
        fill=0,
        all_touched=True,
        dtype="uint16",
    )

def extract_samples(image, mask):
    """Collect per-pixel feature vectors and labels for every labelled pixel.

    Parameters
    ----------
    image : object with a ``.values`` array laid out (bands, rows, cols)
        e.g. the DataArray returned by ``rxr.open_rasterio``.
    mask : ndarray (rows, cols)
        Label raster; pixels with value > 0 are samples.

    Returns
    -------
    X : ndarray (n_samples, bands)
    y : ndarray (n_samples,)

    Pixels whose bands are *all* non-finite are dropped as nodata. When the
    image is opened with ``masked=True``, nodata becomes NaN.

    Any remaining NaN/inf values are replaced by ``np.nan_to_num``. This is
    the same cleaning ``predict_by_block`` applies before predicting, so the
    model is trained on the same kind of input it is later asked to classify.
    """
    mask = np.asarray(mask)
    data = np.moveaxis(np.asarray(image.values), 0, -1)  # (bands, rows, cols) -> (rows, cols, bands)
    labelled = mask > 0
    X = data[labelled]
    y = mask[labelled]

    # A pixel with at least one finite band still carries usable information.
    has_data = np.isfinite(X).any(axis=1)
    return np.nan_to_num(X[has_data]), y[has_data]

def get_classifier(name):
    """Build an unfitted classifier for the configured algorithm.

    Parameters
    ----------
    name : str
        "rf"  -> RandomForestClassifier (N_ESTIMATORS trees, out-of-bag score on),
        "svm" -> RBF-kernel SVC with probability estimates,
        "xgb" -> XGBoost classifier (optional dependency).

    Returns
    -------
    estimator
        Object with scikit-learn style ``fit`` / ``predict``.

    Raises
    ------
    ImportError
        If "xgb" is requested but xgboost is not installed.
    ValueError
        If ``name`` is not a supported identifier.
    """
    if name == "rf":
        return RandomForestClassifier(
            n_estimators=N_ESTIMATORS, n_jobs=-1, oob_score=True, verbose=1
        )
    if name == "svm":
        return SVC(kernel="rbf", probability=True)
    if name == "xgb":
        try:
            from xgboost import XGBClassifier
        except ImportError as err:
            raise ImportError("未安装 xgboost，请先运行 pip install xgboost") from err
        from sklearn.preprocessing import LabelEncoder

        class LabelEncodedXGBClassifier(XGBClassifier):
            """XGBClassifier that accepts arbitrary class ids.

            Current xgboost releases require ``y`` to be 0..n_classes-1. The
            sample rasters here use ids starting at 1 (0 means "no sample"),
            so labels are encoded on ``fit`` and decoded again on ``predict``.

            NOTE: being defined locally, instances cannot be pickled.
            """

            def fit(self, X, y, **kwargs):
                self.label_encoder_ = LabelEncoder().fit(y)
                return super().fit(X, self.label_encoder_.transform(y), **kwargs)

            def predict(self, X, **kwargs):
                encoded = np.asarray(super().predict(X, **kwargs)).astype(np.int64)
                return self.label_encoder_.inverse_transform(encoded)

        return LabelEncodedXGBClassifier(
            n_estimators=N_ESTIMATORS, learning_rate=0.1, max_depth=8, n_jobs=-1
        )
    raise ValueError(f"未知分类器类型: {name}")

def plot_confusion_matrix(y_true, y_pred, class_names, save_path, labels=None):
    """Save count and row-normalised (percent) confusion-matrix heatmaps.

    Parameters
    ----------
    y_true, y_pred : array-like
        Reference and predicted class ids.
    class_names : sequence of str
        Display names, one per entry of ``labels`` and in the same order.
    save_path : str or Path
        PNG path for the count matrix. The percent matrix is written beside it
        with ``_percent`` appended to the file stem.
    labels : array-like, optional
        Class ids that ``class_names`` describe. Defaults to the sorted unique
        values of ``y_true``, which is how ``main`` builds its name lists.

    Predicted ids that do not appear in ``labels`` get extra rows/columns named
    ``Class_<id>``. This keeps exactly one tick label per matrix row and column.

    Raises
    ------
    ValueError
        If ``labels`` and ``class_names`` differ in length.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    class_names = list(class_names)
    labels = np.unique(y_true) if labels is None else np.asarray(labels)
    if len(labels) != len(class_names):
        raise ValueError(
            f"class_names has {len(class_names)} entries but there are {len(labels)} labels"
        )

    # Without an explicit label list, sklearn sizes the matrix by the union of
    # y_true and y_pred. That union stops lining up with class_names as soon as
    # a class is predicted that never occurs in the reference data.
    extra = np.setdiff1d(np.unique(y_pred), labels)
    all_labels = np.concatenate([labels, extra])
    tick_names = class_names + [f'Class_{c}' for c in extra]
    cm = confusion_matrix(y_true, y_pred, labels=all_labels)

    # Row-normalise to percentages; rows without reference samples stay 0 instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm * 100.0, row_totals,
        out=np.zeros(cm.shape, dtype=float), where=row_totals > 0,
    )

    def _draw(matrix, fmt, cbar_label, title, path):
        """Render one annotated heatmap to ``path``."""
        plt.figure(figsize=(10, 8))
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                    xticklabels=tick_names, yticklabels=tick_names,
                    cbar_kws={'label': cbar_label})
        plt.xlabel('预测类别', fontsize=12)
        plt.ylabel('真实类别', fontsize=12)
        plt.title(title, fontsize=14, fontweight='bold')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
        plt.close()

    save_path = Path(save_path)
    _draw(cm, 'd', '样本数量', '混淆矩阵', save_path)
    _draw(cm_percent, '.1f', '百分比 (%)', '混淆矩阵 (百分比)',
          save_path.with_name(save_path.stem + '_percent.png'))

def comprehensive_evaluation(y_true, y_pred, class_names, save_path, labels=None):
    """Compute accuracy metrics, write a text report and a per-class bar chart.

    Parameters
    ----------
    y_true, y_pred : array-like
        Reference and predicted class ids.
    class_names : sequence of str
        Display names, one per entry of ``labels`` and in the same order.
    save_path : str or Path
        Text report path. The bar chart is written beside it with
        ``_metrics.png`` replacing the suffix.
    labels : array-like, optional
        Class ids that ``class_names`` describe. Defaults to the sorted unique
        values of ``y_true``.

    Returns
    -------
    (overall_accuracy, kappa, eval_df)
        ``eval_df`` has one row per label: precision, recall, F1 and the number
        of reference samples.

    Raises
    ------
    ValueError
        If ``labels`` and ``class_names`` differ in length.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    class_names = list(class_names)
    labels = np.unique(y_true) if labels is None else np.asarray(labels)
    if len(labels) != len(class_names):
        raise ValueError(
            f"class_names has {len(class_names)} entries but there are {len(labels)} labels"
        )

    # Overall metrics
    overall_accuracy = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    # Per-class metrics restricted to ``labels``, so every array lines up with
    # class_names even when y_pred contains classes absent from y_true.
    precision = precision_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    recall = recall_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    f1 = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)

    report = classification_report(
        y_true, y_pred, labels=labels, target_names=class_names, digits=4, zero_division=0
    )

    # Reference sample count per class, looked up by id, so non-contiguous ids are counted correctly.
    support = np.array([np.count_nonzero(y_true == lab) for lab in labels], dtype=np.int64)

    eval_df = pd.DataFrame({
        '类别': class_names,
        '精确率 (Precision)': precision,
        '召回率 (Recall)': recall,
        'F1分数': f1,
        '样本数量': support
    })

    # Text report
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("="*60 + "\n")
        f.write("           遥感影像分类精度评价报告\n")
        f.write("="*60 + "\n\n")

        f.write(f"总体精度 (Overall Accuracy): {overall_accuracy:.4f}\n")
        f.write(f"Kappa系数: {kappa:.4f}\n\n")

        f.write("各类别精度评价:\n")
        f.write("-"*60 + "\n")
        f.write(eval_df.to_string(index=False, float_format='%.4f'))
        f.write("\n\n")

        f.write("详细分类报告:\n")
        f.write("-"*60 + "\n")
        f.write(report)

    # Grouped bar chart of precision / recall / F1 per class
    plt.figure(figsize=(12, 6))
    x = np.arange(len(class_names))
    width = 0.25

    plt.bar(x - width, precision, width, label='精确率', alpha=0.8)
    plt.bar(x, recall, width, label='召回率', alpha=0.8)
    plt.bar(x + width, f1, width, label='F1分数', alpha=0.8)

    plt.xlabel('地物类别')
    plt.ylabel('分数')
    plt.title('各类别分类精度指标')
    plt.xticks(x, class_names, rotation=45)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    report_path = Path(save_path)
    plt.savefig(report_path.with_name(report_path.stem + '_metrics.png'), dpi=300, bbox_inches='tight')
    plt.close()

    return overall_accuracy, kappa, eval_df

def plot_classification_results(original_img, classified_img, class_names, class_colors, save_path):
    """Save a side-by-side figure of the input image and the classification map.

    Parameters
    ----------
    original_img : DataArray-like laid out (bands, rows, cols)
        With three or more bands, the first three are shown as an RGB composite
        stretched between the 2nd and 98th percentiles. NaN pixels are ignored
        by the stretch and drawn black. A single-band image is shown in grey.
    classified_img : DataArray-like whose ``.values`` squeezes to (rows, cols)
        Class ids. 0 is treated as background/nodata and left transparent.
    class_names, class_colors : dict
        id -> legend label / matplotlib colour. Missing ids fall back to
        ``Class_<id>`` / black.
    save_path : str or Path
        Output PNG.
    """
    from matplotlib.patches import Patch

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    # Left panel: input image
    if original_img.shape[0] >= 3:
        rgb_data = np.moveaxis(np.asarray(original_img.values[:3], dtype=float), 0, -1)
        # nanpercentile: images opened with masked=True carry NaN nodata.
        p2, p98 = np.nanpercentile(rgb_data, (2, 98))
        span = p98 - p2
        if not np.isfinite(span) or span <= 0:
            span = 1.0  # flat or empty image: avoid dividing by zero
        rgb_display = np.nan_to_num(np.clip((rgb_data - p2) / span, 0, 1), nan=0.0)
        ax1.imshow(rgb_display)
    else:
        ax1.imshow(original_img.values[0], cmap='gray')

    ax1.set_title('原始遥感影像', fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Right panel: classification map
    classified_data = np.asarray(classified_img.values).squeeze()
    classes = np.unique(classified_data)
    classes = classes[classes > 0]  # 0 is background

    if classes.size > 0:
        colors = [class_colors.get(c, 'black') for c in classes]
        labels = [class_names.get(c, f'Class_{c}') for c in classes]

        cmap = mcolors.ListedColormap(colors)
        # One colour bin per class id present; int64 so id+1 cannot wrap for uint16 rasters.
        class_ids = classes.astype(np.int64)
        bounds = np.append(class_ids, class_ids[-1] + 1) - 0.5
        norm = mcolors.BoundaryNorm(bounds, cmap.N)
        # Mask background so it is not painted with the first class colour.
        ax2.imshow(np.ma.masked_less_equal(classified_data, 0), cmap=cmap, norm=norm)

        legend_elements = [Patch(facecolor=color, label=label)
                           for color, label in zip(colors, labels)]
        ax2.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.35, 1))

    ax2.set_title('分类结果', fontsize=14, fontweight='bold')
    ax2.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def plot_feature_importance(clf, feature_names, save_path):
    """Save a bar chart of ``clf.feature_importances_`` sorted from high to low.

    Estimators without a ``feature_importances_`` attribute (e.g. SVC) are
    skipped: nothing is drawn or written and the function returns None.
    """
    if not hasattr(clf, 'feature_importances_'):
        return

    scores = clf.feature_importances_
    order = np.argsort(scores)[::-1]
    positions = range(len(scores))
    ranked_names = [feature_names[idx] for idx in order]

    plt.figure(figsize=(10, 6))
    plt.title('特征重要性排序', fontsize=14, fontweight='bold')
    plt.bar(positions, scores[order])
    plt.xticks(positions, ranked_names, rotation=45)
    plt.xlabel('特征波段')
    plt.ylabel('重要性')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def predict_by_block(model, image, out_path, block_size=BLOCK_SIZE):
    """Classify the whole image in horizontal strips and write a GeoTIFF.

    Parameters
    ----------
    model : fitted estimator with ``predict``
    image : rioxarray DataArray laid out (band, y, x)
    out_path : str or Path
        Output single-band uint16 GeoTIFF (LZW, tiled, nodata=0).
    block_size : int
        Number of rows classified per strip.

    Pixels whose bands are all non-finite (NaN nodata when the image was opened
    with ``masked=True``) are written as 0, the declared nodata value. Other
    NaN/inf values are replaced by ``np.nan_to_num`` before prediction.

    Returns
    -------
    out_path
    """
    import rasterio
    from rasterio.windows import Window

    height, width = image.shape[1], image.shape[2]

    # GeoTIFF tile dimensions must be multiples of 16; tiles may exceed the raster size.
    tile_size = max(16, (block_size // 16) * 16)
    profile = {
        'driver': 'GTiff',
        'dtype': 'uint16',
        'nodata': 0,
        'width': width,
        'height': height,
        'count': 1,
        'crs': image.rio.crs,
        'transform': image.rio.transform(),
        'compress': 'lzw',
        'tiled': True,
        'blockxsize': tile_size,
        'blockysize': tile_size
    }

    with rasterio.open(out_path, "w", **profile) as dst:
        for row_off in tqdm(range(0, height, block_size), desc="Block predicting"):
            h = min(block_size, height - row_off)

            # (bands, h, width) -> (h * width, bands)
            block = np.moveaxis(image.isel(y=slice(row_off, row_off + h)).values, 0, -1)
            pixels = block.reshape(-1, block.shape[-1])

            # Only predict pixels that carry data; all-nodata pixels stay 0.
            has_data = np.isfinite(pixels).any(axis=1)
            preds = np.zeros(pixels.shape[0], dtype="uint16")
            if has_data.any():
                preds[has_data] = model.predict(np.nan_to_num(pixels[has_data])).astype("uint16")

            dst.write(preds.reshape(block.shape[0], block.shape[1]), 1,
                      window=Window(0, row_off, width, h))

    return out_path

# ------------------ 主流程 ------------------
def main():
    """Run the complete supervised-classification workflow.

    Steps, in order:
      0. read class ids / names / colours from TRAIN_SHP;
      1. open IMAGE_PATH with NaN-masked nodata;
      2. rasterize TRAIN_SHP and extract training pixels;
      3. fit the classifier selected by CLASSIFIER;
      4. evaluate on the training pixels (report + confusion matrices);
      5. plot feature importance when the model exposes it;
      6. predict the whole image block by block to OUT_DIR / "classified_result.tif";
      7. plot the input image next to the classification map;
      8. if VAL_SHP exists, evaluate the saved map against it and write a
         combined report;
      9. write per-class pixel counts / shares and an area pie chart;
     10. write the class id / name / colour table.
    """
    t0 = time.time()
    logger.info("开始监督分类任务...")

    # 0. Class information from the training shapefile
    logger.info("正在读取类别信息...")
    class_names, class_colors, train_classes = get_class_info_from_shp(TRAIN_SHP, ATTRIBUTE)
    logger.info(f"检测到类别: {list(class_names.values())}")

    # 1. Read the image; masked=True turns nodata into NaN
    img = rxr.open_rasterio(IMAGE_PATH, masked=True)
    logger.info(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")

    # 2. Rasterize the training polygons onto the image grid and pull out labelled pixels
    logger.info("正在处理训练样本...")
    train_mask = rasterize_samples(TRAIN_SHP, img, ATTRIBUTE)
    X_train, y_train = extract_samples(img, train_mask)
    logger.info(f"训练样本数: {len(y_train)}")

    # 3. Train the classifier
    clf = get_classifier(CLASSIFIER)
    logger.info(f"使用分类器: {clf.__class__.__name__}")
    clf.fit(X_train, y_train)
    logger.info("模型训练完成。")

    # 4. Accuracy on the training pixels themselves. This is optimistic because
    # the model has already seen these samples.
    # NOTE(review): for "rf", get_classifier enables oob_score, so clf.oob_score_
    # would be a less biased estimate worth logging.
    y_pred_train = clf.predict(X_train)
    
    # Class ids actually present in the training labels; names follow the same sorted order
    actual_train_classes = sorted(np.unique(y_train))
    train_class_names = [class_names.get(c, f'Class_{c}') for c in actual_train_classes if c > 0]
    
    # Full per-class evaluation report and metrics chart
    overall_acc, kappa, eval_df = comprehensive_evaluation(
        y_train, y_pred_train, train_class_names, OUT_DIR / "train_evaluation.txt"
    )
    logger.info(f"训练集总体精度: {overall_acc:.4f}, Kappa: {kappa:.4f}")
    
    # Training-set confusion matrices (counts and percentages)
    plot_confusion_matrix(y_train, y_pred_train, train_class_names, OUT_DIR / "train_cm.png")

    # 5. Feature importance (only estimators exposing feature_importances_, e.g. RF / XGBoost)
    if hasattr(clf, 'feature_importances_'):
        feature_names = [f'波段{i+1}' for i in range(X_train.shape[1])]
        plot_feature_importance(clf, feature_names, OUT_DIR / "feature_importance.png")

    # 6. Predict the whole image block by block into a GeoTIFF
    logger.info("开始分块预测...")
    classified_path = OUT_DIR / "classified_result.tif"
    predict_by_block(clf, img, classified_path)
    logger.info(f"分类结果保存至: {classified_path}")

    # 7. Visualise the classification next to the input image
    logger.info("生成分类结果可视化...")
    # NOTE(review): this handle is never explicitly closed; it is reused in step 9.
    classified_img = rxr.open_rasterio(classified_path)
    plot_classification_results(img, classified_img, class_names, class_colors, OUT_DIR / "classification_visualization.png")

    # 8. Independent validation: compare the saved map with the rasterized VAL_SHP labels
    if os.path.exists(VAL_SHP):
        logger.info("正在进行验证...")
        val_mask = rasterize_samples(VAL_SHP, img, ATTRIBUTE)
        with rxr.open_rasterio(classified_path) as pred_img:
            pred_arr = pred_img.values.squeeze()
        
        # Despite the name, Xv holds the *predicted* labels at validation pixels; yv holds the reference labels
        Xv = pred_arr[val_mask > 0]
        yv = val_mask[val_mask > 0]
        
        # Full evaluation on the validation pixels
        val_classes = sorted(np.unique(yv))
        val_class_names = [class_names.get(c, f'Class_{c}') for c in val_classes if c > 0]
        
        val_overall_acc, val_kappa, val_eval_df = comprehensive_evaluation(
            yv, Xv, val_class_names, OUT_DIR / "validation_evaluation.txt"
        )
        logger.info(f"验证集总体精度: {val_overall_acc:.4f}, Kappa: {val_kappa:.4f}")
        
        # Validation confusion matrices
        plot_confusion_matrix(yv, Xv, val_class_names, OUT_DIR / "val_cm.png")

        # Combined summary report (training + validation)
        with open(OUT_DIR / "comprehensive_report.txt", "w", encoding="utf-8") as f:
            f.write("遥感影像分类综合报告\n")
            f.write("="*50 + "\n")
            f.write(f"分类器: {clf.__class__.__name__}\n")
            f.write(f"训练样本数: {len(y_train)}\n")
            f.write(f"验证样本数: {len(yv)}\n")
            f.write(f"类别属性字段: {ATTRIBUTE}\n")
            f.write(f"检测到的类别: {list(class_names.values())}\n\n")
            
            f.write("精度评价汇总:\n")
            f.write("-"*30 + "\n")
            f.write(f"训练集总体精度: {overall_acc:.4f}\n")
            f.write(f"训练集Kappa系数: {kappa:.4f}\n")
            f.write(f"验证集总体精度: {val_overall_acc:.4f}\n")
            f.write(f"验证集Kappa系数: {val_kappa:.4f}\n\n")
            
            f.write("各类别验证精度:\n")
            f.write("-"*30 + "\n")
            f.write(val_eval_df.to_string(index=False, float_format='%.4f'))

    # 9. Per-class pixel counts and shares of the classified map (0 = nodata excluded)
    logger.info("生成分类统计报告...")
    classified_data = classified_img.values.squeeze()
    unique, counts = np.unique(classified_data[classified_data > 0], return_counts=True)
    total_pixels = np.sum(counts)
    
    stats_df = pd.DataFrame({
        '类别编号': unique,
        '类别名称': [class_names.get(c, f'Class_{c}') for c in unique],
        '像元数量': counts,
        '面积占比 (%)': (counts / total_pixels * 100).round(2)
    })
    
    stats_df.to_csv(OUT_DIR / "classification_statistics.csv", index=False, encoding='utf-8-sig')
    
    # Pie chart of the per-class share
    plt.figure(figsize=(10, 8))
    plt.pie(stats_df['面积占比 (%)'], labels=stats_df['类别名称'], autopct='%1.1f%%', startangle=90)
    plt.title('分类结果面积占比分布', fontsize=14, fontweight='bold')
    plt.axis('equal')
    plt.tight_layout()
    plt.savefig(OUT_DIR / "area_distribution.png", dpi=300, bbox_inches='tight')
    plt.close()

    # 10. Save the class id / name / colour lookup table
    class_info_df = pd.DataFrame({
        '类别编号': list(class_names.keys()),
        '类别名称': list(class_names.values()),
        '显示颜色': [class_colors.get(c, 'black') for c in class_names.keys()]
    })
    class_info_df.to_csv(OUT_DIR / "class_information.csv", index=False, encoding='utf-8-sig')
    
    logger.info(f"全部任务完成，用时 {time.time()-t0:.1f} 秒。")
    logger.info(f"所有结果已保存至: {OUT_DIR.absolute()}")

# Run the workflow when executed as a script or notebook cell, but not when imported.
if __name__ == "__main__":
    main()

2025-10-15 14:34:03,665 [INFO] 开始监督分类任务...
2025-10-15 14:34:03,666 [INFO] 正在读取类别信息...
2025-10-15 14:34:03,673 [INFO] 检测到类别: ['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9', 'Class_10', 'Class_11']
2025-10-15 14:34:03,683 [INFO] 影像尺寸: (14, 1024, 2098), 波段数: 14
2025-10-15 14:34:03,684 [INFO] 正在处理训练样本...
2025-10-15 14:34:03,949 [INFO] 训练样本数: 15041
2025-10-15 14:34:03,951 [INFO] 使用分类器: RandomForestClassifier
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 48 concurrent workers.
[Parallel(n_jobs=-1)]: Done 104 tasks      | elapsed:    0.5s
[Parallel(n_jobs=-1)]: Done 300 out of 300 | elapsed:    1.2s finished
2025-10-15 14:34:05,992 [INFO] 模型训练完成。
[Parallel(n_jobs=48)]: Using backend ThreadingBackend with 48 concurrent workers.
[Parallel(n_jobs=48)]: Done 104 tasks      | elapsed:    0.0s
[Parallel(n_jobs=48)]: Done 300 out of 300 | elapsed:    0.1s finished
2025-10-15 14:34:06,800 [INFO] 训练集总体精度: 1.0000, Kappa: 1.0000
2025-10-15

# 版本二：类别名称读取自训练样本的 name 字段，并对分类影像进行后处理（去除小图斑、形态学操作）

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
通用遥感影像监督分类系统
-------------------------------------------------
功能：
1. 自动读取多波段遥感影像；
2. 从矢量样本中提取训练/验证数据；
3. 支持随机森林 / SVM / XGBoost 分类；
4. 采用分块预测模式；
5. 输出分类结果 GeoTIFF；
6. 自动生成分类报告与混淆矩阵；
7. 显示分类影像和精度评价结果；
8. 分类面积统计（平方千米）；
9. 后处理功能（去除小图斑、形态学操作）。
"""

import os
import time
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
from shapely.geometry import mapping
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from sklearn.inspection import permutation_importance
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology

plt.rcParams["font.sans-serif"] = ["SimHei", "DejaVu Sans"]  # CJK-capable font first so the Chinese labels in the figures render
plt.rcParams["axes.unicode_minus"] = False  # Render minus signs as ASCII '-' so they still display with the CJK font

# ------------------ Configuration ------------------
# Input image stack and the training / validation sample shapefiles.
IMAGE_PATH = r"D:\code313\Geo_programe\rasterio\RF\data\2017_09_05_stack.tif"
TRAIN_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\cal.shp"
VAL_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\val.shp"
CLASS_ATTRIBUTE = "class"  # Field holding the class id burned into the label rasters
NAME_ATTRIBUTE = "name"    # Field holding the class display name; get_class_info_from_shp falls back to "Class_<id>" if absent
OUT_DIR = Path("./results_v2")  # Every output (rasters, figures, reports, log file) is written here

CLASSIFIER = "rf"  # One of: "rf", "svm", "xgb" (see get_classifier)
N_ESTIMATORS = 300  # n_estimators passed to the "rf" and "xgb" classifiers
BLOCK_SIZE = 512  # Rows per prediction block in predict_by_block
USE_GPU = False  # NOTE(review): not referenced by the code visible here

# Post-processing parameters (consumed by postprocess_classification)
POSTPROCESSING = True  # Whether post-processing should be run
MIN_PATCH_SIZE = 10    # Connected patches with fewer pixels than this are removed (set to 0)
MORPHOLOGY_OPERATION = "opening"  # "opening", "closing", "both" (opening then closing) or "none"
MORPHOLOGY_SIZE = 3     # Side length of the square structuring element

# Keyword -> colour map. get_class_info_from_shp gives a class the colour of the
# first key (in this insertion order) that is a substring of the class name.
LANDUSE_COLORS = {
    # Water
    "水体": "lightblue",
    "河流": "blue",
    "湖泊": "deepskyblue",
    "水库": "dodgerblue",
    "海洋": "navy",
    
    # Vegetation
    "植被": "forestgreen",
    "森林": "darkgreen",
    "草地": "limegreen",
    "农田": "yellowgreen",
    "耕地": "olivedrab",
    
    # Built-up
    "建筑": "gray",
    "城市": "dimgray",
    "居民地": "slategray",
    "工业区": "darkgray",
    
    # Other land cover
    "裸地": "tan",
    "沙地": "wheat",
    "岩石": "sienna",
    "雪": "white",
    "云": "ghostwhite",
    
    # Only applies to names containing "其他"; names matching no key fall back to COLOR_PALETTE
    "其他": "darkred"
}

# Fallback colours (cycled by class position) for names that match no LANDUSE_COLORS key
COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow', 
                'darkred', 'purple', 'orange', 'pink', 'brown', 
                'cyan', 'magenta', 'lime', 'navy', 'teal']

OUT_DIR.mkdir(exist_ok=True)  # Must exist before the log FileHandler below opens its file

# ------------------ Logging ------------------
# Log to a UTF-8 file inside OUT_DIR and to the console.
# NOTE(review): basicConfig does nothing if the root logger already has handlers
# (e.g. when a notebook cell is re-run in the same kernel).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(OUT_DIR / "classification_log.txt", encoding="utf-8"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# ------------------ 辅助函数 ------------------
def get_class_info_from_shp(shp_path, class_attr, name_attr):
    """Read class ids and names from a shapefile and pick a display colour for each.

    Parameters
    ----------
    shp_path : str or Path
        Sample shapefile.
    class_attr : str
        Column with the class id.
    name_attr : str
        Column with the class name. If the column is missing, names become
        ``"Class_<id>"``. Empty (NaN/None) names also fall back to
        ``"Class_<id>"``, and non-string names are converted with ``str``.

    Returns
    -------
    (class_names, class_colors, unique_classes)
        ``class_names`` maps id -> name. If one id has several names, the last
        one wins and a warning is logged.

        ``class_colors`` maps id -> colour: the colour of the first
        LANDUSE_COLORS key contained in the name, otherwise a COLOR_PALETTE
        entry chosen by the class's position.

        ``unique_classes`` is the sorted list of ids.
    """
    gdf = gpd.read_file(shp_path)

    if name_attr not in gdf.columns:
        logger.warning(f"shp文件中没有找到 '{name_attr}' 字段，将使用类别编号作为名称")
        gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")

    # Distinct (id, name) pairs -> id: name
    class_info = gdf[[class_attr, name_attr]].drop_duplicates()
    class_names = {}
    for class_id, raw_name in zip(class_info[class_attr], class_info[name_attr]):
        # The keyword match below needs a string; empty attribute cells come back as NaN/None.
        name = f"Class_{class_id}" if pd.isna(raw_name) else str(raw_name)
        if class_id in class_names and class_names[class_id] != name:
            logger.warning(f"类别 {class_id} 存在多个名称，使用 '{name}'")
        class_names[class_id] = name

    # Colour: first keyword hit in LANDUSE_COLORS, otherwise cycle through COLOR_PALETTE
    class_colors = {}
    for i, (class_id, class_name) in enumerate(class_names.items()):
        matched = next(
            (color for key, color in LANDUSE_COLORS.items() if key in class_name), None
        )
        class_colors[class_id] = matched if matched is not None else COLOR_PALETTE[i % len(COLOR_PALETTE)]

    unique_classes = sorted(class_names.keys())

    return class_names, class_colors, unique_classes

def rasterize_samples(shp, ref_img, attr):
    """Burn vector samples into a label raster on the grid of ``ref_img``.

    The shapefile is reprojected to ``ref_img``'s CRS. Each geometry is burned
    with the value of its ``attr`` field, and every pixel a geometry touches
    (``all_touched=True``) receives that value. All other pixels stay 0.

    Returns a uint16 array of shape ``ref_img.shape[1:]``, i.e. (rows, cols)
    assuming ``ref_img`` is laid out as (band, y, x).
    """
    import rasterio.features

    samples = gpd.read_file(shp).to_crs(ref_img.rio.crs)
    geometry_value_pairs = zip(samples.geometry, samples[attr])

    return rasterio.features.rasterize(
        shapes=geometry_value_pairs,
        out_shape=ref_img.shape[1:],
        transform=ref_img.rio.transform(),
        fill=0,
        all_touched=True,
        dtype="uint16",
    )

def extract_samples(image, mask):
    """Collect per-pixel feature vectors and labels for every labelled pixel.

    Parameters
    ----------
    image : object with a ``.values`` array laid out (bands, rows, cols)
        e.g. the DataArray returned by ``rxr.open_rasterio``.
    mask : ndarray (rows, cols)
        Label raster; pixels with value > 0 are samples.

    Returns
    -------
    X : ndarray (n_samples, bands)
    y : ndarray (n_samples,)

    Pixels whose bands are *all* non-finite are dropped as nodata. When the
    image is opened with ``masked=True``, nodata becomes NaN.

    Any remaining NaN/inf values are replaced by ``np.nan_to_num``. This is
    the same cleaning ``predict_by_block`` applies before predicting, so the
    model is trained on the same kind of input it is later asked to classify.
    """
    mask = np.asarray(mask)
    data = np.moveaxis(np.asarray(image.values), 0, -1)  # (bands, rows, cols) -> (rows, cols, bands)
    labelled = mask > 0
    X = data[labelled]
    y = mask[labelled]

    # A pixel with at least one finite band still carries usable information.
    has_data = np.isfinite(X).any(axis=1)
    return np.nan_to_num(X[has_data]), y[has_data]

def get_classifier(name):
    """Build an unfitted classifier for the configured algorithm.

    Parameters
    ----------
    name : str
        "rf"  -> RandomForestClassifier (N_ESTIMATORS trees, out-of-bag score on),
        "svm" -> RBF-kernel SVC with probability estimates,
        "xgb" -> XGBoost classifier (optional dependency).

    Returns
    -------
    estimator
        Object with scikit-learn style ``fit`` / ``predict``.

    Raises
    ------
    ImportError
        If "xgb" is requested but xgboost is not installed.
    ValueError
        If ``name`` is not a supported identifier.
    """
    if name == "rf":
        return RandomForestClassifier(
            n_estimators=N_ESTIMATORS, n_jobs=-1, oob_score=True, verbose=1
        )
    if name == "svm":
        return SVC(kernel="rbf", probability=True)
    if name == "xgb":
        try:
            from xgboost import XGBClassifier
        except ImportError as err:
            raise ImportError("未安装 xgboost，请先运行 pip install xgboost") from err
        from sklearn.preprocessing import LabelEncoder

        class LabelEncodedXGBClassifier(XGBClassifier):
            """XGBClassifier that accepts arbitrary class ids.

            Current xgboost releases require ``y`` to be 0..n_classes-1. The
            sample rasters here use ids starting at 1 (0 means "no sample"),
            so labels are encoded on ``fit`` and decoded again on ``predict``.

            NOTE: being defined locally, instances cannot be pickled.
            """

            def fit(self, X, y, **kwargs):
                self.label_encoder_ = LabelEncoder().fit(y)
                return super().fit(X, self.label_encoder_.transform(y), **kwargs)

            def predict(self, X, **kwargs):
                encoded = np.asarray(super().predict(X, **kwargs)).astype(np.int64)
                return self.label_encoder_.inverse_transform(encoded)

        return LabelEncodedXGBClassifier(
            n_estimators=N_ESTIMATORS, learning_rate=0.1, max_depth=8, n_jobs=-1
        )
    raise ValueError(f"未知分类器类型: {name}")

def plot_confusion_matrix(y_true, y_pred, class_names, save_path, labels=None):
    """Save count and row-normalised (percent) confusion-matrix heatmaps.

    Parameters
    ----------
    y_true, y_pred : array-like
        Reference and predicted class ids.
    class_names : sequence of str
        Display names, one per entry of ``labels`` and in the same order.
    save_path : str or Path
        PNG path for the count matrix. The percent matrix is written beside it
        with ``_percent`` appended to the file stem.
    labels : array-like, optional
        Class ids that ``class_names`` describe. Defaults to the sorted unique
        values of ``y_true``.

    Predicted ids that do not appear in ``labels`` get extra rows/columns named
    ``未知类别_<id>``. This keeps exactly one tick label per matrix row and column.

    Raises
    ------
    ValueError
        If ``labels`` and ``class_names`` differ in length.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    class_names = list(class_names)
    labels = np.unique(y_true) if labels is None else np.asarray(labels)
    if len(labels) != len(class_names):
        raise ValueError(
            f"class_names has {len(class_names)} entries but there are {len(labels)} labels"
        )

    # Without an explicit label list, sklearn sizes the matrix by the union of
    # y_true and y_pred. That union stops lining up with class_names as soon as
    # a class is predicted that never occurs in the reference data.
    extra = np.setdiff1d(np.unique(y_pred), labels)
    all_labels = np.concatenate([labels, extra])
    tick_names = class_names + [f'未知类别_{c}' for c in extra]
    cm = confusion_matrix(y_true, y_pred, labels=all_labels)

    # Row-normalise to percentages; rows without reference samples stay 0 instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm * 100.0, row_totals,
        out=np.zeros(cm.shape, dtype=float), where=row_totals > 0,
    )

    def _draw(matrix, fmt, cbar_label, title, path):
        """Render one annotated heatmap to ``path``."""
        plt.figure(figsize=(10, 8))
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                    xticklabels=tick_names, yticklabels=tick_names,
                    cbar_kws={'label': cbar_label})
        plt.xlabel('预测类别', fontsize=12)
        plt.ylabel('真实类别', fontsize=12)
        plt.title(title, fontsize=14, fontweight='bold')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
        plt.close()

    save_path = Path(save_path)
    _draw(cm, 'd', '样本数量', '混淆矩阵', save_path)
    _draw(cm_percent, '.1f', '百分比 (%)', '混淆矩阵 (百分比)',
          save_path.with_name(save_path.stem + '_percent.png'))

def comprehensive_evaluation(y_true, y_pred, class_names, save_path, labels=None):
    """Compute accuracy metrics, write a text report and a per-class bar chart.

    Parameters
    ----------
    y_true, y_pred : array-like
        Reference and predicted class ids.
    class_names : sequence of str
        Display names, one per entry of ``labels`` and in the same order.
    save_path : str or Path
        Text report path. The bar chart is written beside it with
        ``_metrics.png`` replacing the suffix.
    labels : array-like, optional
        Class ids that ``class_names`` describe. Defaults to the sorted unique
        values of ``y_true``.

    Returns
    -------
    (overall_accuracy, kappa, eval_df)
        ``eval_df`` has one row per label: precision, recall, F1 and the number
        of reference samples.

    Raises
    ------
    ValueError
        If ``labels`` and ``class_names`` differ in length.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    class_names = list(class_names)
    labels = np.unique(y_true) if labels is None else np.asarray(labels)
    if len(labels) != len(class_names):
        raise ValueError(
            f"class_names has {len(class_names)} entries but there are {len(labels)} labels"
        )

    # Overall metrics
    overall_accuracy = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    # Per-class metrics restricted to ``labels``, so every array lines up with
    # class_names even when y_pred contains classes absent from y_true.
    precision = precision_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    recall = recall_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    f1 = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)

    report = classification_report(
        y_true, y_pred, labels=labels, target_names=class_names, digits=4, zero_division=0
    )

    # Reference sample count per class, looked up by id, so non-contiguous ids are counted correctly.
    support = np.array([np.count_nonzero(y_true == lab) for lab in labels], dtype=np.int64)

    eval_df = pd.DataFrame({
        '类别': class_names,
        '精确率 (Precision)': precision,
        '召回率 (Recall)': recall,
        'F1分数': f1,
        '样本数量': support
    })

    # Text report
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("="*60 + "\n")
        f.write("           遥感影像分类精度评价报告\n")
        f.write("="*60 + "\n\n")

        f.write(f"总体精度 (Overall Accuracy): {overall_accuracy:.4f}\n")
        f.write(f"Kappa系数: {kappa:.4f}\n\n")

        f.write("各类别精度评价:\n")
        f.write("-"*60 + "\n")
        f.write(eval_df.to_string(index=False, float_format='%.4f'))
        f.write("\n\n")

        f.write("详细分类报告:\n")
        f.write("-"*60 + "\n")
        f.write(report)

    # Grouped bar chart of precision / recall / F1 per class
    plt.figure(figsize=(12, 6))
    x = np.arange(len(class_names))
    width = 0.25

    plt.bar(x - width, precision, width, label='精确率', alpha=0.8)
    plt.bar(x, recall, width, label='召回率', alpha=0.8)
    plt.bar(x + width, f1, width, label='F1分数', alpha=0.8)

    plt.xlabel('地物类别')
    plt.ylabel('分数')
    plt.title('各类别分类精度指标')
    plt.xticks(x, class_names, rotation=45)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    report_path = Path(save_path)
    plt.savefig(report_path.with_name(report_path.stem + '_metrics.png'), dpi=300, bbox_inches='tight')
    plt.close()

    return overall_accuracy, kappa, eval_df

def plot_classification_results(original_img, classified_img, class_names, class_colors, save_path, title_suffix=""):
    """Save a side-by-side figure of the input image and the classification map.

    Parameters
    ----------
    original_img : DataArray-like laid out (bands, rows, cols)
        With three or more bands, the first three are shown as an RGB composite
        stretched between the 2nd and 98th percentiles. NaN pixels are ignored
        by the stretch and drawn black. A single-band image is shown in grey.
    classified_img : DataArray-like whose ``.values`` squeezes to (rows, cols)
        Class ids. 0 is treated as background/nodata and left transparent.
    class_names, class_colors : dict
        id -> legend label / matplotlib colour. Missing ids fall back to
        ``未知类别_<id>`` / black.
    save_path : str or Path
        Output PNG.
    title_suffix : str
        Appended to the right panel title (e.g. to mark a post-processed map).
    """
    from matplotlib.patches import Patch

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    # Left panel: input image
    if original_img.shape[0] >= 3:
        rgb_data = np.moveaxis(np.asarray(original_img.values[:3], dtype=float), 0, -1)
        # nanpercentile: images opened with masked=True carry NaN nodata.
        p2, p98 = np.nanpercentile(rgb_data, (2, 98))
        span = p98 - p2
        if not np.isfinite(span) or span <= 0:
            span = 1.0  # flat or empty image: avoid dividing by zero
        rgb_display = np.nan_to_num(np.clip((rgb_data - p2) / span, 0, 1), nan=0.0)
        ax1.imshow(rgb_display)
    else:
        ax1.imshow(original_img.values[0], cmap='gray')

    ax1.set_title('原始遥感影像', fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Right panel: classification map
    classified_data = np.asarray(classified_img.values).squeeze()
    classes = np.unique(classified_data)
    classes = classes[classes > 0]  # 0 is background

    if classes.size > 0:
        colors = [class_colors.get(c, 'black') for c in classes]
        labels = [class_names.get(c, f'未知类别_{c}') for c in classes]

        cmap = mcolors.ListedColormap(colors)
        # One colour bin per class id present; int64 so id+1 cannot wrap for uint16 rasters.
        class_ids = classes.astype(np.int64)
        bounds = np.append(class_ids, class_ids[-1] + 1) - 0.5
        norm = mcolors.BoundaryNorm(bounds, cmap.N)
        # Mask background so it is not painted with the first class colour.
        ax2.imshow(np.ma.masked_less_equal(classified_data, 0), cmap=cmap, norm=norm)

        legend_elements = [Patch(facecolor=color, label=label)
                           for color, label in zip(colors, labels)]
        ax2.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.35, 1))

    ax2.set_title('分类结果' + title_suffix, fontsize=14, fontweight='bold')
    ax2.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def plot_feature_importance(clf, feature_names, save_path):
    """Save a bar chart of ``clf.feature_importances_`` sorted from high to low.

    Estimators without a ``feature_importances_`` attribute (e.g. SVC) are
    skipped: nothing is drawn or written and the function returns None.
    """
    if not hasattr(clf, 'feature_importances_'):
        return

    scores = clf.feature_importances_
    order = np.argsort(scores)[::-1]
    positions = range(len(scores))
    ranked_names = [feature_names[idx] for idx in order]

    plt.figure(figsize=(10, 6))
    plt.title('特征重要性排序', fontsize=14, fontweight='bold')
    plt.bar(positions, scores[order])
    plt.xticks(positions, ranked_names, rotation=45)
    plt.xlabel('特征波段')
    plt.ylabel('重要性')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def calculate_pixel_area(transform):
    """Return the ground area covered by one pixel, in squared CRS units.

    Parameters
    ----------
    transform : sequence indexable as (a, b, c, d, e, f, ...)
        Affine geotransform coefficients, e.g. the ``affine.Affine`` returned by
        ``image.rio.transform()``. ``a``/``e`` are the pixel width/height and
        ``b``/``d`` the rotation/shear terms.

    Returns
    -------
    float
        ``|a*e - b*d|``, the absolute determinant of the linear part. This
        equals ``|a| * |e|`` for north-up rasters and stays correct for rotated
        or sheared ones.

        The unit is the CRS unit squared: m² only for a metre-based projected
        CRS, degree² for a geographic CRS.
    """
    a, b = transform[0], transform[1]
    d, e = transform[3], transform[4]
    return abs(a * e - b * d)

def predict_by_block(model, image, out_path, block_size=BLOCK_SIZE):
    """Classify the whole image in horizontal strips and write a GeoTIFF.

    Parameters
    ----------
    model : fitted estimator with ``predict``
    image : rioxarray DataArray laid out (band, y, x)
    out_path : str or Path
        Output single-band uint16 GeoTIFF (LZW, tiled, nodata=0).
    block_size : int
        Number of rows classified per strip.

    Pixels whose bands are all non-finite (NaN nodata when the image was opened
    with ``masked=True``) are written as 0, the declared nodata value. Other
    NaN/inf values are replaced by ``np.nan_to_num`` before prediction.

    Returns
    -------
    out_path
    """
    import rasterio
    from rasterio.windows import Window

    height, width = image.shape[1], image.shape[2]

    # GeoTIFF tile dimensions must be multiples of 16; tiles may exceed the raster size.
    tile_size = max(16, (block_size // 16) * 16)
    profile = {
        'driver': 'GTiff',
        'dtype': 'uint16',
        'nodata': 0,
        'width': width,
        'height': height,
        'count': 1,
        'crs': image.rio.crs,
        'transform': image.rio.transform(),
        'compress': 'lzw',
        'tiled': True,
        'blockxsize': tile_size,
        'blockysize': tile_size
    }

    with rasterio.open(out_path, "w", **profile) as dst:
        for row_off in tqdm(range(0, height, block_size), desc="Block predicting"):
            h = min(block_size, height - row_off)

            # (bands, h, width) -> (h * width, bands)
            block = np.moveaxis(image.isel(y=slice(row_off, row_off + h)).values, 0, -1)
            pixels = block.reshape(-1, block.shape[-1])

            # Only predict pixels that carry data; all-nodata pixels stay 0.
            has_data = np.isfinite(pixels).any(axis=1)
            preds = np.zeros(pixels.shape[0], dtype="uint16")
            if has_data.any():
                preds[has_data] = model.predict(np.nan_to_num(pixels[has_data])).astype("uint16")

            dst.write(preds.reshape(block.shape[0], block.shape[1]), 1,
                      window=Window(0, row_off, width, h))

    return out_path

def postprocess_classification(classified_data, min_patch_size=10, morphology_op="opening", morphology_size=3):
    """
    Clean up a classification map one class at a time.

    For every class id > 0, in ascending id order:

    1. 4-connected patches of that class smaller than ``min_patch_size``
       pixels are dropped.
    2. The remaining mask is optionally smoothed with a square
       ``morphology_size`` x ``morphology_size`` structuring element.

    Pixels that belonged to the class but are not in its cleaned mask are set
    to 0 (background/nodata). They are not reassigned to a neighbouring class.

    Parameters
    ----------
    classified_data : ndarray (rows, cols)
        Class ids, 0 = background.
    min_patch_size : int
        Minimum patch size in pixels; <= 0 disables patch removal.
    morphology_op : str
        "opening", "closing", "both" (opening then closing) or "none". Any
        other value logs a warning and skips the morphology step.
    morphology_size : int
        Structuring-element side length; <= 0 disables morphology.

    Returns
    -------
    ndarray
        A new array with the input's dtype; ``classified_data`` is not modified.

    NOTE(review): each class's mask is derived from the *original* map. When
    closing grows a class into a neighbour's pixels, a later class can
    overwrite or zero those pixels, so "closing"/"both" results depend on
    class-id order.
    """
    logger.info("开始后处理分类结果...")
    processed_data = classified_data.copy()

    # All class ids except background 0 (np.unique returns them sorted)
    classes = np.unique(classified_data)
    classes = classes[classes > 0]

    apply_morphology = morphology_op != "none" and morphology_size > 0
    if apply_morphology and morphology_op not in ("opening", "closing", "both"):
        logger.warning(f"未知的形态学操作 '{morphology_op}'，将跳过形态学处理")
    # The structuring element does not depend on the class; build it once.
    structure = np.ones((morphology_size, morphology_size), dtype=np.uint8) if apply_morphology else None

    for class_id in classes:
        class_mask = classified_data == class_id
        binary_mask = class_mask

        # Remove small patches via connected-component labelling
        if min_patch_size > 0:
            labeled_array, _ = ndimage.label(binary_mask)
            component_sizes = np.bincount(labeled_array.ravel())
            keep = component_sizes >= min_patch_size
            keep[0] = False  # label 0 is everything outside this class
            binary_mask = keep[labeled_array]

        # Morphological smoothing
        if apply_morphology:
            if morphology_op == "opening":
                binary_mask = morphology.binary_opening(binary_mask, structure)
            elif morphology_op == "closing":
                binary_mask = morphology.binary_closing(binary_mask, structure)
            elif morphology_op == "both":
                binary_mask = morphology.binary_opening(binary_mask, structure)
                binary_mask = morphology.binary_closing(binary_mask, structure)

        binary_mask = np.asarray(binary_mask, dtype=bool)
        processed_data[binary_mask] = class_id
        # Pixels of this class that did not survive cleaning become background (0)
        processed_data[class_mask & ~binary_mask] = 0

    # Summary of how much the post-processing changed
    original_nonzero = np.count_nonzero(classified_data)
    processed_nonzero = np.count_nonzero(processed_data)
    removed = original_nonzero - processed_nonzero
    # An all-background input would otherwise divide by zero.
    change_percent = removed / original_nonzero * 100 if original_nonzero else 0.0

    logger.info(f"后处理完成: 原始非零像元数 {original_nonzero}, 处理后非零像元数 {processed_nonzero}")
    logger.info(f"后处理去除了 {removed} 个像元 ({change_percent:.2f}%)")

    return processed_data

def save_classification_result(data, transform, crs, out_path):
    """
    Write a 2-D classification array to a single-band GeoTIFF with rasterio.

    Parameters
    ----------
    data : numpy.ndarray
        (rows, cols) class map; it is cast to uint16 before writing.
    transform : affine.Affine
        Pixel-to-CRS transform of the output grid.
    crs : CRS-like
        Coordinate reference system handed to rasterio.
    out_path : str | Path
        Destination file (opened in "w" mode, so an existing file is replaced).

    Returns
    -------
    The ``out_path`` argument, unchanged.

    The output is LZW-compressed, tiled, and declares 0 as its nodata value.
    """
    import rasterio

    n_cols = data.shape[1]
    n_rows = data.shape[0]
    profile = dict(
        driver='GTiff',
        dtype='uint16',
        nodata=0,
        width=n_cols,
        height=n_rows,
        count=1,
        crs=crs,
        transform=transform,
        compress='lzw',
        tiled=True,
    )

    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(data.astype("uint16"), 1)

    return out_path

def calculate_area_statistics(classified_data, class_names, class_colors, pixel_area_km2, suffix=""):
    """
    Compute per-class area statistics (value 0 = background is excluded) and
    write them to OUT_DIR as a CSV table, a pie chart, a bar chart and a text
    summary.

    Parameters
    ----------
    classified_data : numpy.ndarray
        Classification map; pixels equal to 0 are ignored.
    class_names : dict
        Class id -> display name (unknown ids are shown as "未知类别_<id>").
    class_colors : dict
        Class id -> matplotlib color used for the bar chart (default 'gray').
    pixel_area_km2 : float
        Area of one pixel in km².
    suffix : str
        Appended to every output file name and plot title.

    Returns
    -------
    stats_df : pandas.DataFrame
        One row per class. When the map contains no non-zero pixel, an empty
        frame with the usual columns is returned and no file is written.
    total_area_km2 : float
        Total classified area in km².
    """
    columns = ['类别编号', '类别名称', '像元数量', '面积(km²)', '面积占比 (%)']

    # Class ids and pixel counts, background (0) excluded.
    unique, counts = np.unique(classified_data[classified_data > 0], return_counts=True)

    if len(unique) == 0:
        # Nothing to chart; keep the columns so callers can still index them.
        logger.warning("分类结果中没有有效类别（非背景）！")
        return pd.DataFrame(columns=columns), 0.0

    total_pixels = np.sum(counts)

    # Per-class area in km².
    areas_km2 = [count * pixel_area_km2 for count in counts]
    total_area_km2 = np.sum(areas_km2)

    stats_df = pd.DataFrame({
        '类别编号': unique,
        '类别名称': [class_names.get(c, f'未知类别_{c}') for c in unique],
        '像元数量': counts,
        '面积(km²)': [round(area, 4) for area in areas_km2],
        '面积占比 (%)': (counts / total_pixels * 100).round(2)
    })

    stats_filename = f"classification_statistics{suffix}.csv"
    stats_df.to_csv(OUT_DIR / stats_filename, index=False, encoding='utf-8-sig')

    # Pie chart of area shares.
    plt.figure(figsize=(10, 8))
    plt.pie(stats_df['面积占比 (%)'], labels=stats_df['类别名称'], autopct='%1.1f%%', startangle=90)
    plt.title(f'分类结果面积占比分布{suffix}', fontsize=14, fontweight='bold')
    plt.axis('equal')
    plt.tight_layout()
    plt.savefig(OUT_DIR / f"area_distribution{suffix}.png", dpi=300, bbox_inches='tight')
    plt.close()

    # Bar chart of absolute areas.
    plt.figure(figsize=(12, 6))
    plt.bar(stats_df['类别名称'], stats_df['面积(km²)'],
            color=[class_colors.get(c, 'gray') for c in unique])
    plt.xlabel('地物类别')
    plt.ylabel('面积 (km²)')
    plt.title(f'各类别面积统计{suffix}', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)

    # Value labels use the table's 4-decimal precision: with sub-metre pixels
    # class areas are far below 0.01 km² and 2 decimals printed only 0.00.
    for i, v in enumerate(stats_df['面积(km²)']):
        plt.text(i, v + max(areas_km2)*0.01, f'{v:.4f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(OUT_DIR / f"area_bar_chart{suffix}.png", dpi=300, bbox_inches='tight')
    plt.close()

    # Plain-text summary.
    with open(OUT_DIR / f"area_summary{suffix}.txt", "w", encoding="utf-8") as f:
        f.write(f"分类面积统计摘要{suffix}\n")
        f.write("="*50 + "\n")
        f.write(f"总分类面积: {total_area_km2:.4f} km²\n")
        f.write(f"像元大小: {pixel_area_km2 * 1e6:.2f} 平方米\n")
        f.write(f"总像元数: {total_pixels}\n\n")

        f.write("各类别面积统计:\n")
        f.write("-"*50 + "\n")
        for _, row in stats_df.iterrows():
            f.write(f"{row['类别名称']}: {row['面积(km²)']:.4f} km² ({row['面积占比 (%)']}%)\n")

    return stats_df, total_area_km2

# ------------------ 主流程 ------------------
def _names_for_labels(y_true, y_pred, class_names):
    """Return display names for the sorted union of labels in y_true and y_pred.

    scikit-learn's confusion_matrix / classification_report index classes by
    this union, so tick labels and target names must follow the same order.
    Label 0 (unclassified) is shown as "背景".
    """
    labels = np.unique(np.concatenate([np.asarray(y_true).ravel(),
                                       np.asarray(y_pred).ravel()]))
    return [class_names.get(c, '背景' if c == 0 else f'未知类别_{c}') for c in labels]


def main():
    """Run the complete supervised-classification workflow.

    The steps match the numbered comments below:
    0. read class ids/names from TRAIN_SHP;
    1. open IMAGE_PATH;
    2. rasterize and extract the training samples;
    3. fit the classifier selected by CLASSIFIER;
    4. evaluate it on the training pixels;
    5. plot feature importances;
    6. predict the whole image block by block;
    7-8. plot the raw result and compute its area statistics;
    9. optionally post-process it (POSTPROCESSING);
    10. validate against VAL_SHP when that file exists;
    11. save the class table.
    All outputs are written to OUT_DIR.
    """
    t0 = time.time()
    logger.info("开始监督分类任务...")

    # 0. Class ids / names / colors from the training shapefile
    logger.info("正在读取类别信息...")
    class_names, class_colors, train_classes = get_class_info_from_shp(TRAIN_SHP, CLASS_ATTRIBUTE, NAME_ATTRIBUTE)
    logger.info(f"检测到类别: {list(class_names.values())}")

    # 1. Read the image (masked=True: rioxarray masks nodata as NaN)
    img = rxr.open_rasterio(IMAGE_PATH, masked=True)
    logger.info(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")

    transform = img.rio.transform()
    crs = img.rio.crs
    logger.info(f"影像坐标系: {crs}")
    logger.info(f"影像变换参数: {transform}")

    # Pixel area. Sub-metre pixels give tiny km² values, so use a general
    # format; a fixed 6-decimal format printed 0.000000.
    pixel_area_m2 = calculate_pixel_area(transform)
    pixel_area_km2 = pixel_area_m2 / 1e6
    logger.info(f"单个像元面积: {pixel_area_m2:.2f} 平方米 ({pixel_area_km2:.6g} 平方千米)")

    # 2. Rasterize the training samples and extract features/labels
    logger.info("正在处理训练样本...")
    train_mask = rasterize_samples(TRAIN_SHP, img, CLASS_ATTRIBUTE)
    X_train, y_train = extract_samples(img, train_mask)
    logger.info(f"训练样本数: {len(y_train)}")

    # 3. Train the classifier
    clf = get_classifier(CLASSIFIER)
    logger.info(f"使用分类器: {clf.__class__.__name__}")
    clf.fit(X_train, y_train)
    logger.info("模型训练完成。")

    # 4. Accuracy on the training pixels
    y_pred_train = clf.predict(X_train)
    train_class_names = _names_for_labels(y_train, y_pred_train, class_names)

    overall_acc, kappa, eval_df = comprehensive_evaluation(
        y_train, y_pred_train, train_class_names, OUT_DIR / "train_evaluation.txt"
    )
    logger.info(f"训练集总体精度: {overall_acc:.4f}, Kappa: {kappa:.4f}")

    plot_confusion_matrix(y_train, y_pred_train, train_class_names, OUT_DIR / "train_cm.png")

    # 5. Feature importance (only for estimators that expose it)
    if hasattr(clf, 'feature_importances_'):
        feature_names = [f'波段{i+1}' for i in range(X_train.shape[1])]
        plot_feature_importance(clf, feature_names, OUT_DIR / "feature_importance.png")

    # 6. Block-wise prediction of the whole image
    logger.info("开始分块预测...")
    classified_path = OUT_DIR / "classified_result.tif"
    predict_by_block(clf, img, classified_path)
    logger.info(f"分类结果保存至: {classified_path}")

    # 7. Plot the raw result; the file is closed once its array is loaded.
    logger.info("生成原始分类结果可视化...")
    with rxr.open_rasterio(classified_path) as classified_img:
        plot_classification_results(img, classified_img, class_names, class_colors,
                                    OUT_DIR / "classification_visualization.png", " (原始)")
        original_classified_data = classified_img.values.squeeze()

    # 8. Area statistics of the raw result
    logger.info("计算原始分类结果面积统计...")
    original_stats_df, original_total_area = calculate_area_statistics(
        original_classified_data, class_names, class_colors, pixel_area_km2, "_original"
    )
    logger.info(f"原始分类总面积: {original_total_area:.4f} 平方千米")

    # 9. Optional post-processing
    if POSTPROCESSING:
        logger.info("开始后处理...")
        logger.info(f"后处理参数: 最小图斑大小={MIN_PATCH_SIZE}, 形态学操作={MORPHOLOGY_OPERATION}, 核大小={MORPHOLOGY_SIZE}")

        processed_data = postprocess_classification(
            original_classified_data,
            min_patch_size=MIN_PATCH_SIZE,
            morphology_op=MORPHOLOGY_OPERATION,
            morphology_size=MORPHOLOGY_SIZE
        )

        processed_path = OUT_DIR / "classified_result_postprocessed.tif"
        save_classification_result(processed_data, transform, crs, processed_path)
        logger.info(f"后处理结果保存至: {processed_path}")

        logger.info("生成后处理分类结果可视化...")
        with rxr.open_rasterio(processed_path) as processed_img:
            plot_classification_results(img, processed_img, class_names, class_colors,
                                        OUT_DIR / "classification_visualization_postprocessed.png", " (后处理)")

        logger.info("计算后处理分类结果面积统计...")
        processed_stats_df, processed_total_area = calculate_area_statistics(
            processed_data, class_names, class_colors, pixel_area_km2, "_postprocessed"
        )
        logger.info(f"后处理分类总面积: {processed_total_area:.4f} 平方千米")

        # Relative change; guarded because an all-background raw map has area 0.
        area_change = processed_total_area - original_total_area
        area_change_percent = (area_change / original_total_area) * 100 if original_total_area > 0 else 0.0

        with open(OUT_DIR / "postprocessing_report.txt", "w", encoding="utf-8") as f:
            f.write("后处理变化报告\n")
            f.write("="*50 + "\n")
            f.write(f"后处理参数:\n")
            f.write(f"  最小图斑大小: {MIN_PATCH_SIZE} 像元\n")
            f.write(f"  形态学操作: {MORPHOLOGY_OPERATION}\n")
            f.write(f"  核大小: {MORPHOLOGY_SIZE}\n\n")

            f.write(f"面积变化:\n")
            f.write(f"  原始总面积: {original_total_area:.4f} km²\n")
            f.write(f"  后处理总面积: {processed_total_area:.4f} km²\n")
            f.write(f"  面积变化: {area_change:+.4f} km² ({area_change_percent:+.2f}%)\n\n")

            f.write("各类别面积变化:\n")
            f.write("-"*50 + "\n")
            for class_id in class_names.keys():
                if class_id in original_stats_df['类别编号'].values and class_id in processed_stats_df['类别编号'].values:
                    orig_area = original_stats_df[original_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    proc_area = processed_stats_df[processed_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    change = proc_area - orig_area
                    change_percent = (change / orig_area) * 100 if orig_area > 0 else 0
                    f.write(f"{class_names[class_id]}: {orig_area:.4f} → {proc_area:.4f} km² ({change:+.4f}, {change_percent:+.2f}%)\n")
                elif class_id in original_stats_df['类别编号'].values:
                    orig_area = original_stats_df[original_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    f.write(f"{class_names[class_id]}: {orig_area:.4f} → 0.0000 km² (完全去除)\n")
                elif class_id in processed_stats_df['类别编号'].values:
                    proc_area = processed_stats_df[processed_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    f.write(f"{class_names[class_id]}: 0.0000 → {proc_area:.4f} km² (新增)\n")

    # 10. Validation (on the raw, not post-processed, result)
    if os.path.exists(VAL_SHP):
        logger.info("正在进行验证...")
        val_mask = rasterize_samples(VAL_SHP, img, CLASS_ATTRIBUTE)

        # Same grid as the raster written in step 6, already loaded in step 7.
        pred_arr = original_classified_data

        Xv = pred_arr[val_mask > 0]
        yv = val_mask[val_mask > 0]

        # Names must cover every label present in either array (e.g. a class
        # predicted at validation pixels but absent from the validation set);
        # otherwise classification_report raises and heatmap ticks misalign.
        val_class_names = _names_for_labels(yv, Xv, class_names)

        val_overall_acc, val_kappa, val_eval_df = comprehensive_evaluation(
            yv, Xv, val_class_names, OUT_DIR / "validation_evaluation.txt"
        )
        logger.info(f"验证集总体精度: {val_overall_acc:.4f}, Kappa: {val_kappa:.4f}")

        plot_confusion_matrix(yv, Xv, val_class_names, OUT_DIR / "val_cm.png")

        # Combined report
        with open(OUT_DIR / "comprehensive_report.txt", "w", encoding="utf-8") as f:
            f.write("遥感影像分类综合报告\n")
            f.write("="*50 + "\n")
            f.write(f"分类器: {clf.__class__.__name__}\n")
            f.write(f"训练样本数: {len(y_train)}\n")
            f.write(f"验证样本数: {len(yv)}\n")
            f.write(f"类别编号字段: {CLASS_ATTRIBUTE}\n")
            f.write(f"类别名称字段: {NAME_ATTRIBUTE}\n")
            f.write(f"检测到的类别: {list(class_names.values())}\n")
            f.write(f"像元面积: {pixel_area_m2:.2f} 平方米\n")
            f.write(f"后处理: {'是' if POSTPROCESSING else '否'}\n\n")

            f.write("精度评价汇总:\n")
            f.write("-"*30 + "\n")
            f.write(f"训练集总体精度: {overall_acc:.4f}\n")
            f.write(f"训练集Kappa系数: {kappa:.4f}\n")
            f.write(f"验证集总体精度: {val_overall_acc:.4f}\n")
            f.write(f"验证集Kappa系数: {val_kappa:.4f}\n\n")

            f.write("各类别验证精度:\n")
            f.write("-"*30 + "\n")
            f.write(val_eval_df.to_string(index=False, float_format='%.4f'))

    # 11. Class table (id, name, display color)
    class_info_df = pd.DataFrame({
        '类别编号': list(class_names.keys()),
        '类别名称': list(class_names.values()),
        '显示颜色': [class_colors.get(c, 'black') for c in class_names.keys()]
    })
    class_info_df.to_csv(OUT_DIR / "class_information.csv", index=False, encoding='utf-8-sig')

    logger.info(f"全部任务完成，用时 {time.time()-t0:.1f} 秒。")
    logger.info(f"所有结果已保存至: {OUT_DIR.absolute()}")

if __name__ == "__main__":
    main()

2025-10-16 07:58:08,220 [INFO] 开始监督分类任务...
2025-10-16 07:58:08,222 [INFO] 正在读取类别信息...
2025-10-16 07:58:08,377 [INFO] 检测到类别: ['类1', '类2', '类3', '类4', '类5', '类6', '类7', '类8', '类9', '类10', '类11']
2025-10-16 07:58:08,443 [INFO] 影像尺寸: (14, 1024, 2098), 波段数: 14
2025-10-16 07:58:08,446 [INFO] 影像坐标系: EPSG:32633
2025-10-16 07:58:08,447 [INFO] 影像变换参数: | 0.20, 0.00, 351916.64|
| 0.00,-0.20, 5997247.36|
| 0.00, 0.00, 1.00|
2025-10-16 07:58:08,448 [INFO] 单个像元面积: 0.04 平方米 (0.000000 平方千米)
2025-10-16 07:58:08,448 [INFO] 正在处理训练样本...
2025-10-16 07:58:08,717 [INFO] 训练样本数: 15041
2025-10-16 07:58:08,719 [INFO] 使用分类器: RandomForestClassifier
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 48 concurrent workers.
[Parallel(n_jobs=-1)]: Done 104 tasks      | elapsed:    0.7s
[Parallel(n_jobs=-1)]: Done 300 out of 300 | elapsed:    1.4s finished
2025-10-16 07:58:10,887 [INFO] 模型训练完成。
[Parallel(n_jobs=48)]: Using backend ThreadingBackend with 48 concurrent workers.
[Parallel(n_jobs=48)]: Done 104 tasks

## 主要改进内容：

### 1. **新增名称字段配置**
- 添加了 `NAME_ATTRIBUTE = "name"` 参数，用于指定包含用地类型名称的字段
- 保留了 `CLASS_ATTRIBUTE = "class"` 用于类别编号

### 2. **智能颜色分配系统**
- 创建了 `LANDUSE_COLORS` 字典，包含常见用地类型的预定义颜色
- 系统会尝试根据类别名称中的关键词匹配预定义颜色
- 如果无法匹配，则使用自动分配的颜色

### 3. **改进的类别信息获取**
- `get_class_info_from_shp()` 函数现在同时读取类别编号和名称
- 如果shp文件中没有名称字段，会使用类别编号作为名称并发出警告
- 每个类别编号只保留一个名称（若同一编号在shp中对应多个名称，以最后出现的名称为准）

### 4. **预定义用地类型颜色**
系统包含以下用地类型的预定义颜色：
- **水体相关**：水体(浅蓝)、河流(蓝色)、湖泊(深天蓝)、水库(道奇蓝)、海洋(海军蓝)
- **植被相关**：植被(森林绿)、森林(深绿)、草地(酸橙绿)、农田(黄绿)、耕地(橄榄绿)
- **建筑相关**：建筑(灰色)、城市(暗灰)、居民地(石板灰)、工业区(深灰)
- **其他地物**：裸地(棕褐色)、沙地(小麦色)、岩石(赭色)、雪(白色)、云(幽灵白)

### 5. **增强的日志和报告**
- 在综合报告中显示类别名称字段信息
- 更清晰的类别信息输出

## 使用说明：

1. **配置参数**：确保 `CLASS_ATTRIBUTE` 和 `NAME_ATTRIBUTE` 与您的shp文件字段匹配
2. **自定义颜色**：如果需要，可以扩展 `LANDUSE_COLORS` 字典添加更多用地类型和颜色
3. **运行脚本**：系统会自动从shp文件的"name"字段读取用地类型名称

现在系统会根据您的shp文件中的"name"字段显示用地类型名称，并为每种用地类型分配语义化的颜色，使分类结果更加直观易懂。

# 用rioxarray读写栅格

## Claude

### 程序1：可切换3种分类方法（rf / svm / xgb）

In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
通用遥感影像监督分类系统 (纯rioxarray版本 - 支持背景值处理)
-------------------------------------------------
功能：
1. 自动读取多波段遥感影像；
2. 从矢量样本中提取训练/验证数据；
3. 支持随机森林 / SVM / XGBoost 分类；
4. 采用分块预测模式；
5. 输出分类结果 GeoTIFF；
6. 自动生成分类报告与混淆矩阵；
7. 显示分类影像和精度评价结果；
8. 分类面积统计（平方千米）；
9. 后处理功能（去除小图斑、形态学操作）；
10. 忽略背景值（所有波段值为0的像元）。
"""

import os
import time
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features  # 仅用于矢量栅格化
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from sklearn.inspection import permutation_importance
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology

# Matplotlib font settings so Chinese labels render in all figures.
plt.rcParams["font.sans-serif"] = ["SimHei", "DejaVu Sans"]  # SimHei first for CJK glyphs
plt.rcParams["axes.unicode_minus"] = False  # draw the minus sign correctly with CJK fonts

# ------------------ Configuration ------------------
# Input image stack and training / validation sample files.
IMAGE_PATH = r"D:\code313\Geo_programe\rasterio\RF\data\2017_09_05_stack.tif"
TRAIN_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\cal.shp"
VAL_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\val.shp"
CLASS_ATTRIBUTE = "class"  # attribute holding the class id
NAME_ATTRIBUTE = "name"    # attribute holding the class name
OUT_DIR = Path("./results_v2")  # output directory (also holds the log file)

CLASSIFIER = "rf"  # one of: "rf", "svm", "xgb" (see get_classifier)
N_ESTIMATORS = 300  # number of trees / boosting rounds for "rf" and "xgb"
BLOCK_SIZE = 512  # number of image rows predicted per strip in predict_by_block
USE_GPU = False  # NOTE(review): not referenced in this part of the file — confirm whether it is used

# Post-processing parameters (see postprocess_classification)
POSTPROCESSING = True  # enable post-processing
MIN_PATCH_SIZE = 10    # minimum patch size in pixels; smaller patches are removed
MORPHOLOGY_OPERATION = "opening"  # "opening", "closing", "both" (opening then closing) or "none"
MORPHOLOGY_SIZE = 3     # side length of the square structuring element

# Background handling
BACKGROUND_VALUE = 0  # background / nodata value written to classification rasters
IGNORE_BACKGROUND = True  # whether to skip pixels whose bands are all 0

# Land-use name -> color. get_class_info_from_shp uses the first key (in dict
# order) that occurs as a substring of a class name. Extend as needed.
LANDUSE_COLORS = {
    # Water
    "水体": "lightblue",
    "河流": "blue",
    "湖泊": "deepskyblue",
    "水库": "dodgerblue",
    "海洋": "navy",
    
    # Vegetation
    "植被": "forestgreen",
    "森林": "darkgreen",
    "草地": "limegreen",
    "农田": "yellowgreen",
    "耕地": "olivedrab",
    
    # Built-up
    "建筑": "gray",
    "城市": "dimgray",
    "居民地": "slategray",
    "工业区": "darkgray",
    
    # Other land cover
    "裸地": "tan",
    "沙地": "wheat",
    "岩石": "sienna",
    "雪": "white",
    "云": "ghostwhite",
    
    # Only used for names containing "其他" ("other"); names matching no key
    # at all get a color from COLOR_PALETTE instead.
    "其他": "darkred"
}

# Fallback colors for classes whose names match no LANDUSE_COLORS key,
# assigned cyclically by the class's position.
COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow', 
                'darkred', 'purple', 'orange', 'pink', 'brown', 
                'cyan', 'magenta', 'lime', 'navy', 'teal']

# Create the output directory (its parent must already exist).
OUT_DIR.mkdir(exist_ok=True)

# ------------------ Logging ------------------
# Log both to OUT_DIR/classification_log.txt and to the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(OUT_DIR / "classification_log.txt", encoding="utf-8"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# ------------------ 辅助函数 ------------------
def get_background_mask(image):
    """
    Return a boolean (rows, cols) mask of background pixels.

    A pixel is background when every band is 0 or NaN. NaN is included
    because an image opened with ``rxr.open_rasterio(..., masked=True)`` has
    its nodata value (often 0) replaced by NaN, and ``NaN == 0`` is False, so
    a pure "== 0" test misses exactly those pixels.

    Parameters
    ----------
    image : xarray.DataArray
        Array of shape (bands, rows, cols); only ``image.values`` is read.

    Returns
    -------
    numpy.ndarray of bool, shape (rows, cols); True marks background.
    """
    data = image.values  # (bands, rows, cols)
    is_empty = data == 0
    if np.issubdtype(data.dtype, np.floating):
        # np.isnan is only defined for floating dtypes.
        is_empty |= np.isnan(data)
    return np.all(is_empty, axis=0)

def get_class_info_from_shp(shp_path, class_attr, name_attr):
    """
    Read class ids and names from a sample vector file and assign colors.

    Parameters
    ----------
    shp_path : str | Path
        Vector sample file, read with geopandas.
    class_attr : str
        Column holding the class id.
    name_attr : str
        Column holding the class name. If it is missing, names are generated
        as ``Class_<id>`` and a warning is logged.

    Returns
    -------
    class_names : dict
        Class id -> class name. If one id occurs with several names, the last
        one in the file wins and a warning is logged.
    class_colors : dict
        Class id -> color: the value of the first LANDUSE_COLORS key found as
        a substring of the name (names are converted with ``str`` first, so
        numeric or missing names do not raise), otherwise
        ``COLOR_PALETTE[position % len(COLOR_PALETTE)]``.
    unique_classes : list
        Sorted class ids.
    """
    gdf = gpd.read_file(shp_path)

    if name_attr not in gdf.columns:
        logger.warning(f"shp文件中没有找到 '{name_attr}' 字段，将使用类别编号作为名称")
        gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")

    # Unique (id, name) pairs; the dict keeps the last name seen per id.
    class_info = gdf[[class_attr, name_attr]].drop_duplicates()
    class_names = dict(zip(class_info[class_attr], class_info[name_attr]))
    if len(class_info) > len(class_names):
        logger.warning("部分类别编号对应多个名称，将使用最后出现的名称")

    class_colors = {}
    for i, (class_id, class_name) in enumerate(class_names.items()):
        name_text = str(class_name)
        class_colors[class_id] = next(
            (color for key, color in LANDUSE_COLORS.items() if key in name_text),
            COLOR_PALETTE[i % len(COLOR_PALETTE)],
        )

    unique_classes = sorted(class_names.keys())

    return class_names, class_colors, unique_classes

def rasterize_samples(shp, ref_img, attr):
    """
    Burn the sample geometries of ``shp`` onto the pixel grid of ``ref_img``.

    The layer is reprojected to ``ref_img.rio.crs``. Each geometry is burned
    with its ``attr`` value, every touched pixel is included
    (``all_touched=True``), and pixels not covered stay 0.

    Returns
    -------
    numpy.ndarray of uint16 with shape ``ref_img.shape[1:]`` (rows, cols).
    """
    samples = gpd.read_file(shp)
    samples = samples.to_crs(ref_img.rio.crs)
    burn_pairs = zip(samples.geometry, samples[attr])

    return features.rasterize(
        shapes=burn_pairs,
        out_shape=ref_img.shape[1:],
        transform=ref_img.rio.transform(),
        fill=0,
        all_touched=True,
        dtype="uint16",
    )

def extract_samples(image, mask, ignore_background=True):
    """
    Extract per-pixel feature vectors and labels for every labelled pixel.

    Parameters
    ----------
    image : xarray.DataArray
        Image of shape (bands, rows, cols); only ``image.values`` is read.
    mask : numpy.ndarray
        Label raster (rows, cols); pixels with a value > 0 are samples.
    ignore_background : bool
        If True, drop sample pixels flagged by get_background_mask and log
        how many were dropped.

    Returns
    -------
    X : numpy.ndarray
        (n_samples, bands) features. NaN values (e.g. nodata masked by
        rioxarray) are replaced with 0 exactly as predict_by_block does, so
        training and prediction use the same encoding and estimators that
        reject NaN (such as SVC) can be fitted.
    y : numpy.ndarray
        (n_samples,) labels taken from ``mask``.
    """
    data = np.moveaxis(image.values, 0, -1)  # (bands, rows, cols) -> (rows, cols, bands)
    valid = mask > 0

    if ignore_background:
        background_mask = get_background_mask(image)
        valid = valid & (~background_mask)
        n_background = np.sum(mask > 0) - np.sum(valid)
        if n_background > 0:
            logger.info(f"排除了 {n_background} 个背景像元")

    X = np.nan_to_num(data[valid])
    y = mask[valid]

    return X, y

def get_classifier(name):
    """
    Build an unfitted classifier from its short name.

    Parameters
    ----------
    name : str
        "rf"  -> RandomForestClassifier(N_ESTIMATORS trees, all cores, OOB score)
        "svm" -> SVC with an RBF kernel (probability=True)
        "xgb" -> XGBClassifier wrapped with label encoding (see below)

    Returns
    -------
    An estimator with ``fit`` / ``predict``.

    Raises
    ------
    ImportError
        If "xgb" is requested and xgboost is not installed.
    ValueError
        For an unknown name.
    """
    if name == "rf":
        return RandomForestClassifier(
            n_estimators=N_ESTIMATORS, n_jobs=-1, oob_score=True, verbose=1
        )
    if name == "svm":
        return SVC(kernel="rbf", probability=True)
    if name == "xgb":
        try:
            from xgboost import XGBClassifier
        except ImportError as err:
            raise ImportError("未安装 xgboost，请先运行 pip install xgboost") from err
        from sklearn.preprocessing import LabelEncoder

        class LabelEncodedXGBClassifier(XGBClassifier):
            """XGBClassifier that accepts arbitrary class ids (e.g. 1..K).

            Recent xgboost versions reject labels that are not 0..K-1, while
            the sample ids here start at 1. ``fit`` encodes the labels and
            ``predict`` maps the predictions back to the original ids.
            NOTE: being a local class, instances cannot be pickled.
            """

            def fit(self, X, y, **kwargs):
                self._label_encoder = LabelEncoder().fit(y)
                return super().fit(X, self._label_encoder.transform(y), **kwargs)

            def predict(self, X, **kwargs):
                encoded = np.asarray(super().predict(X, **kwargs)).astype(int)
                return self._label_encoder.inverse_transform(encoded)

        return LabelEncodedXGBClassifier(
            n_estimators=N_ESTIMATORS, learning_rate=0.1, max_depth=8, n_jobs=-1
        )
    raise ValueError(f"未知分类器类型: {name}")

def _save_cm_heatmap(matrix, fmt, class_names, cbar_label, title, path):
    """Render one confusion-matrix heatmap to ``path`` and close the figure."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': cbar_label})
    plt.xlabel('预测类别', fontsize=12)
    plt.ylabel('真实类别', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_confusion_matrix(y_true, y_pred, class_names, save_path, labels=None):
    """
    Save the confusion matrix twice: raw counts and row-normalised percentages.

    Parameters
    ----------
    y_true, y_pred : array-like
        Reference and predicted labels.
    class_names : list of str
        Tick labels, one per matrix row/column, in the order of ``labels`` if
        given, otherwise the sorted union of labels in y_true and y_pred
        (scikit-learn's default order).
    save_path : str | Path
        PNG path for the count matrix. The percentage matrix is written next
        to it as ``<stem>_percent<suffix>``.
    labels : array-like, optional
        Explicit label order forwarded to ``confusion_matrix``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Row-normalise to percentages. A row with no reference samples (a class
    # that only occurs in y_pred) is reported as 0 % instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(cm.astype(float), row_totals,
                           out=np.zeros(cm.shape, dtype=float),
                           where=row_totals > 0) * 100

    save_path = Path(save_path)
    percent_path = save_path.with_name(f"{save_path.stem}_percent{save_path.suffix}")

    _save_cm_heatmap(cm, 'd', class_names, '样本数量', '混淆矩阵', save_path)
    _save_cm_heatmap(cm_percent, '.1f', class_names, '百分比 (%)', '混淆矩阵 (百分比)', percent_path)

def comprehensive_evaluation(y_true, y_pred, class_names, save_path):
    """
    Compute accuracy metrics, write a text report and a per-class bar chart.

    Parameters
    ----------
    y_true, y_pred : array-like of int
        Reference and predicted labels.
    class_names : list of str
        One name per label in the sorted union of y_true and y_pred, the
        order scikit-learn uses for per-class metrics.
    save_path : str | Path
        Text report path. The bar chart is saved next to it as
        ``<stem>_metrics.png``.

    Returns
    -------
    overall_accuracy : float
    kappa : float
        Cohen's kappa.
    eval_df : pandas.DataFrame
        Per-class precision, recall, F1 and number of reference samples.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    overall_accuracy = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average=None, zero_division=0)
    recall = recall_score(y_true, y_pred, average=None, zero_division=0)
    f1 = f1_score(y_true, y_pred, average=None, zero_division=0)

    report = classification_report(y_true, y_pred, target_names=class_names, digits=4)

    # Reference-sample counts in the same (sorted union) label order as the
    # metric arrays. Slicing np.bincount(y_true)[1:K+1] would assume the
    # labels are exactly 1..K and misalign counts whenever an id is missing.
    labels = np.unique(np.concatenate([y_true.ravel(), y_pred.ravel()]))
    support = np.array([np.count_nonzero(y_true == label) for label in labels])

    eval_df = pd.DataFrame({
        '类别': class_names,
        '精确率 (Precision)': precision,
        '召回率 (Recall)': recall,
        'F1分数': f1,
        '样本数量': support
    })

    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("="*60 + "\n")
        f.write("           遥感影像分类精度评价报告\n")
        f.write("="*60 + "\n\n")

        f.write(f"总体精度 (Overall Accuracy): {overall_accuracy:.4f}\n")
        f.write(f"Kappa系数: {kappa:.4f}\n\n")

        f.write("各类别精度评价:\n")
        f.write("-"*60 + "\n")
        f.write(eval_df.to_string(index=False, float_format='%.4f'))
        f.write("\n\n")

        f.write("详细分类报告:\n")
        f.write("-"*60 + "\n")
        f.write(report)

    # Grouped bar chart of precision / recall / F1 per class.
    plt.figure(figsize=(12, 6))
    x = np.arange(len(class_names))
    width = 0.25

    plt.bar(x - width, precision, width, label='精确率', alpha=0.8)
    plt.bar(x, recall, width, label='召回率', alpha=0.8)
    plt.bar(x + width, f1, width, label='F1分数', alpha=0.8)

    plt.xlabel('地物类别')
    plt.ylabel('分数')
    plt.title('各类别分类精度指标')
    plt.xticks(x, class_names, rotation=45)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    # Derive the chart name from the stem so it can never overwrite the report.
    report_path = Path(save_path)
    plt.savefig(report_path.with_name(f"{report_path.stem}_metrics.png"), dpi=300, bbox_inches='tight')
    plt.close()

    return overall_accuracy, kappa, eval_df

def plot_classification_results(original_img, classified_img, class_names, class_colors, save_path, title_suffix=""):
    """
    Save a side-by-side figure of the input image and the classification map.

    Parameters
    ----------
    original_img : xarray.DataArray
        Image (bands, rows, cols). With >= 3 bands the first three are shown
        as RGB after a 2-98 percentile stretch of the positive values;
        otherwise the first band is shown in grey.
    classified_img : xarray.DataArray
        Classification raster; value 0 is background and drawn transparent.
    class_names, class_colors : dict
        Class id -> legend label / color ('未知类别_<id>' / 'black' if absent).
    save_path : str | Path
        Output image path.
    title_suffix : str
        Appended to the classification panel title.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    if original_img.shape[0] >= 3:
        rgb_data = np.moveaxis(original_img.values[:3], 0, -1)
        # Stretch on positive values only (0 / NaN are background). Fall back
        # to a unit span when there are no positive values or the percentiles
        # coincide, instead of failing or dividing by zero.
        positive = rgb_data[rgb_data > 0]
        p2, p98 = np.percentile(positive, (2, 98)) if positive.size else (0.0, 1.0)
        if p98 <= p2:
            p98 = p2 + 1.0
        rgb_display = np.clip((rgb_data - p2) / (p98 - p2), 0, 1)
        ax1.imshow(rgb_display)
    else:
        ax1.imshow(original_img.values[0], cmap='gray')

    ax1.set_title('原始遥感影像', fontsize=14, fontweight='bold')
    ax1.axis('off')

    classified_data = classified_img.values.squeeze()
    classes = np.unique(classified_data)
    classes = classes[classes > 0]  # drop background

    colors = [class_colors.get(c, 'black') for c in classes]
    labels = [class_names.get(c, f'未知类别_{c}') for c in classes]

    if classes.size:
        # One color bin per class id, centred on the integer id.
        cmap = mcolors.ListedColormap(colors)
        bounds = np.append(classes, classes[-1] + 1) - 0.5
        norm = mcolors.BoundaryNorm(bounds, cmap.N)

        # Background as NaN so it is drawn transparent.
        display_data = classified_data.astype(float)
        display_data[classified_data == 0] = np.nan
        ax2.imshow(display_data, cmap=cmap, norm=norm)
    else:
        logger.warning("分类结果中没有有效类别（非背景），分类图为空")

    title = '分类结果' + title_suffix
    ax2.set_title(title, fontsize=14, fontweight='bold')
    ax2.axis('off')

    if labels:
        from matplotlib.patches import Patch
        legend_elements = [Patch(facecolor=color, label=label)
                           for color, label in zip(colors, labels)]
        ax2.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.35, 1))

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_feature_importance(clf, feature_names, save_path):
    """
    Save a bar chart of ``clf.feature_importances_`` sorted in descending order.

    Estimators without ``feature_importances_`` (e.g. SVC) are ignored and
    nothing is written. ``feature_names[k]`` labels feature ``k``.
    Always returns None.
    """
    if not hasattr(clf, 'feature_importances_'):
        return

    scores = clf.feature_importances_
    order = np.argsort(scores)[::-1]
    positions = range(len(scores))
    tick_labels = [feature_names[k] for k in order]

    plt.figure(figsize=(10, 6))
    plt.title('特征重要性排序', fontsize=14, fontweight='bold')
    plt.bar(positions, scores[order])
    plt.xticks(positions, tick_labels, rotation=45)
    plt.xlabel('特征波段')
    plt.ylabel('重要性')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def calculate_pixel_area(transform):
    """
    Return the area of one pixel in squared CRS units.

    Uses the absolute determinant |a*e - b*d| of the affine transform
    (indices 0, 1, 3, 4). For north-up grids this equals |a| * |e|. Unlike
    that product, it stays correct for rotated or sheared grids.

    Parameters
    ----------
    transform : affine.Affine or a sequence (a, b, c, d, e, f, ...)
        Pixel-to-CRS transform, e.g. ``img.rio.transform()``.

    Returns
    -------
    float
        Pixel area. It is in square metres only when the CRS unit is the
        metre (e.g. a projected UTM CRS); for a geographic CRS it is in
        square degrees.
    """
    a, b, d, e = transform[0], transform[1], transform[3], transform[4]
    return abs(a * e - b * d)

def predict_by_block(model, image, out_path, block_size=BLOCK_SIZE, ignore_background=True):
    """
    Classify the whole image in horizontal strips and write a GeoTIFF.

    Each strip holds up to ``block_size`` rows across the full width. NaN
    features are replaced with 0 before ``model.predict``. When
    ``ignore_background`` is True, pixels flagged by get_background_mask are
    not sent to the model and keep the value 0.

    Parameters
    ----------
    model : fitted estimator with ``predict``
    image : xarray.DataArray
        Input image (bands, rows, cols) with CRS / transform set.
    out_path : str | Path
        Output GeoTIFF (uint16, LZW, tiled, nodata = BACKGROUND_VALUE).
    block_size : int
        Number of rows per strip.
    ignore_background : bool
        Skip background pixels.

    Returns
    -------
    The ``out_path`` argument.
    """
    n_rows, n_cols = image.shape[1], image.shape[2]
    label_map = np.zeros((n_rows, n_cols), dtype='uint16')

    bg_mask = None
    if ignore_background:
        logger.info("计算背景掩膜...")
        bg_mask = get_background_mask(image)
        bg_count = np.sum(bg_mask)
        logger.info(f"背景像元数: {bg_count} ({bg_count/(n_rows * n_cols)*100:.2f}%)")

    for row0 in tqdm(range(0, n_rows, block_size), desc="分块预测"):
        rows = min(block_size, n_rows - row0)

        # (bands, rows, width) -> (rows, width, bands) -> (pixels, bands)
        strip = np.moveaxis(image.isel(y=slice(row0, row0 + rows)).values, 0, -1)
        strip_rows, strip_cols = strip.shape[0], strip.shape[1]
        pixels = strip.reshape(-1, strip.shape[-1])

        if not ignore_background:
            strip_pred = model.predict(np.nan_to_num(pixels)).reshape(strip_rows, strip_cols).astype("uint16")
        else:
            keep = ~bg_mask[row0:row0 + rows, :].flatten()
            flat_pred = np.zeros(len(pixels), dtype='uint16')
            if np.any(keep):
                flat_pred[keep] = model.predict(np.nan_to_num(pixels[keep]))
            strip_pred = flat_pred.reshape(strip_rows, strip_cols)

        label_map[row0:row0 + rows, :] = strip_pred

    result = xr.DataArray(
        label_map,
        dims=['y', 'x'],
        coords={'y': image.coords['y'], 'x': image.coords['x']},
    )
    result.rio.write_crs(image.rio.crs, inplace=True)
    result.rio.write_transform(image.rio.transform(), inplace=True)
    result.rio.write_nodata(BACKGROUND_VALUE, inplace=True)

    result.rio.to_raster(
        out_path,
        driver='GTiff',
        dtype='uint16',
        compress='lzw',
        tiled=True,
    )

    logger.info(f"预测结果已保存至: {out_path}")
    return out_path

def save_classification_result(data, ref_image, out_path):
    """
    Write a 2-D classification array as a GeoTIFF georeferenced like ``ref_image``.

    Parameters
    ----------
    data : numpy.ndarray
        (rows, cols) class map on the same grid as ``ref_image``.
    ref_image : xarray.DataArray
        Supplies the y/x coordinates, CRS and transform.
    out_path : str | Path
        Output file (uint16, LZW, tiled, nodata = BACKGROUND_VALUE).

    Returns
    -------
    The ``out_path`` argument.
    """
    grid_coords = {'y': ref_image.coords['y'], 'x': ref_image.coords['x']}
    out_da = xr.DataArray(data, dims=['y', 'x'], coords=grid_coords)

    out_da.rio.write_crs(ref_image.rio.crs, inplace=True)
    out_da.rio.write_transform(ref_image.rio.transform(), inplace=True)
    out_da.rio.write_nodata(BACKGROUND_VALUE, inplace=True)

    tiff_options = dict(driver='GTiff', dtype='uint16', compress='lzw', tiled=True)
    out_da.rio.to_raster(out_path, **tiff_options)

    logger.info(f"分类结果已保存至: {out_path}")
    return out_path

def postprocess_classification(classified_data, min_patch_size=10, morphology_op="opening", morphology_size=3):
    """
    Clean a classification map class by class; background stays background.

    For every class id > 0 (ascending), a binary mask is built from the
    input map. Then:
      1. if ``min_patch_size`` > 0, connected patches (scipy.ndimage.label,
         default 4-connectivity) smaller than ``min_patch_size`` pixels are
         dropped;
      2. unless ``morphology_op`` is "none" or ``morphology_size`` <= 0, a
         binary "opening", "closing" or "both" (opening then closing) with a
         square ``morphology_size`` x ``morphology_size`` element is applied.
    Pixels in the final mask are set to the class. Pixels of that class that
    left the mask become BACKGROUND_VALUE. Finally, every pixel that was
    BACKGROUND_VALUE in the input is reset to it, so a closing cannot grow a
    class into background / nodata areas.

    NOTE: classes write into one shared output in ascending id order. Where
    a closing grows one class into pixels of another class, the result
    depends on that order.

    Parameters
    ----------
    classified_data : numpy.ndarray
        Classification map (not modified).
    min_patch_size : int
        Minimum patch size in pixels.
    morphology_op : str
        "opening", "closing", "both" or "none".
    morphology_size : int
        Side length of the structuring element.

    Returns
    -------
    numpy.ndarray
        Post-processed copy of ``classified_data``.
    """
    logger.info("开始后处理分类结果...")
    processed_data = classified_data.copy()

    classes = np.unique(classified_data)
    classes = classes[classes > 0]

    # The structuring element does not depend on the class: build it once.
    structure = None
    if morphology_op != "none" and morphology_size > 0:
        structure = np.ones((morphology_size, morphology_size), dtype=np.uint8)

    for class_id in classes:
        binary_mask = (classified_data == class_id).astype(np.uint8)

        # Remove small patches.
        if min_patch_size > 0:
            labeled_array, _ = ndimage.label(binary_mask)
            component_sizes = np.bincount(labeled_array.ravel())
            size_mask = component_sizes >= min_patch_size
            size_mask[0] = 0  # label 0 = pixels outside this class
            binary_mask = size_mask[labeled_array]

        # Morphological smoothing.
        if structure is not None:
            if morphology_op == "opening":
                binary_mask = morphology.binary_opening(binary_mask, structure)
            elif morphology_op == "closing":
                binary_mask = morphology.binary_closing(binary_mask, structure)
            elif morphology_op == "both":
                binary_mask = morphology.binary_opening(binary_mask, structure)
                binary_mask = morphology.binary_closing(binary_mask, structure)

        processed_data[binary_mask > 0] = class_id
        # Pixels of this class removed above become background.
        processed_data[(classified_data == class_id) & (binary_mask == 0)] = BACKGROUND_VALUE

    # Background of the input is never reassigned to a class.
    processed_data[classified_data == BACKGROUND_VALUE] = BACKGROUND_VALUE

    original_nonzero = np.count_nonzero(classified_data)
    processed_nonzero = np.count_nonzero(processed_data)
    change_percent = (original_nonzero - processed_nonzero) / original_nonzero * 100 if original_nonzero > 0 else 0

    logger.info(f"后处理完成: 原始非零像元数 {original_nonzero}, 处理后非零像元数 {processed_nonzero}")
    logger.info(f"后处理去除了 {original_nonzero - processed_nonzero} 个像元 ({change_percent:.2f}%)")

    return processed_data

def calculate_area_statistics(classified_data, class_names, class_colors, pixel_area_km2, suffix=""):
    """
    计算分类面积统计（排除背景值0）
    
    参数:
    - classified_data: 分类结果数组
    - class_names: 类别名称字典
    - class_colors: 类别颜色字典
    - pixel_area_km2: 单个像元面积（平方千米）
    - suffix: 文件名后缀
    
    返回:
    - stats_df: 统计DataFrame
    - total_area_km2: 总面积
    """
    # 获取类别和数量（排除背景0）
    unique, counts = np.unique(classified_data[classified_data > 0], return_counts=True)
    
    if len(unique) == 0:
        logger.warning("分类结果中没有有效类别（非背景）！")
        return pd.DataFrame(), 0.0
    
    total_pixels = np.sum(counts)
    
    # 计算各类别面积
    areas_km2 = [count * pixel_area_km2 for count in counts]
    total_area_km2 = np.sum(areas_km2)
    
    # 创建统计表格
    stats_df = pd.DataFrame({
        '类别编号': unique,
        '类别名称': [class_names.get(c, f'未知类别_{c}') for c in unique],
        '像元数量': counts,
        '面积(km²)': [round(area, 4) for area in areas_km2],
        '面积占比 (%)': (counts / total_pixels * 100).round(2)
    })
    
    # 保存统计表格
    stats_filename = f"classification_statistics{suffix}.csv"
    stats_df.to_csv(OUT_DIR / stats_filename, index=False, encoding='utf-8-sig')
    
    # 绘制面积占比饼图
    plt.figure(figsize=(10, 8))
    plt.pie(stats_df['面积占比 (%)'], labels=stats_df['类别名称'], autopct='%1.1f%%', startangle=90)
    plt.title(f'分类结果面积占比分布{suffix}', fontsize=14, fontweight='bold')
    plt.axis('equal')
    plt.tight_layout()
    plt.savefig(OUT_DIR / f"area_distribution{suffix}.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 绘制面积柱状图
    plt.figure(figsize=(12, 6))
    plt.bar(stats_df['类别名称'], stats_df['面积(km²)'], 
            color=[class_colors.get(c, 'gray') for c in unique])
    plt.xlabel('地物类别')
    plt.ylabel('面积 (km²)')
    plt.title(f'各类别面积统计{suffix}', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    # 在柱状图上添加数值标签
    for i, v in enumerate(stats_df['面积(km²)']):
        plt.text(i, v + max(areas_km2)*0.01, f'{v:.2f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig(OUT_DIR / f"area_bar_chart{suffix}.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 统计背景信息
    n_background = np.sum(classified_data == 0)
    n_total = classified_data.size
    background_area_km2 = n_background * pixel_area_km2
    
    # 生成面积统计摘要
    with open(OUT_DIR / f"area_summary{suffix}.txt", "w", encoding="utf-8") as f:
        f.write(f"分类面积统计摘要{suffix}\n")
        f.write("="*50 + "\n")
        f.write(f"总分类面积: {total_area_km2:.4f} km²\n")
        f.write(f"背景面积: {background_area_km2:.4f} km²\n")
        f.write(f"总影像面积: {(total_area_km2 + background_area_km2):.4f} km²\n")
        f.write(f"像元大小: {pixel_area_km2 * 1e6:.2f} 平方米\n")
        f.write(f"分类像元数: {total_pixels}\n")
        f.write(f"背景像元数: {n_background}\n")
        f.write(f"总像元数: {n_total}\n\n")
        
        f.write("各类别面积统计:\n")
        f.write("-"*50 + "\n")
        for _, row in stats_df.iterrows():
            f.write(f"{row['类别名称']}: {row['面积(km²)']:.4f} km² ({row['面积占比 (%)']}%)\n")
    
    return stats_df, total_area_km2

# ------------------ 主流程 ------------------
def main():
    t0 = time.time()
    logger.info("开始监督分类任务...")
    logger.info(f"背景值处理: {'启用' if IGNORE_BACKGROUND else '禁用'}")

    # 0. 从训练样本shp文件中获取类别信息
    logger.info("正在读取类别信息...")
    class_names, class_colors, train_classes = get_class_info_from_shp(TRAIN_SHP, CLASS_ATTRIBUTE, NAME_ATTRIBUTE)
    logger.info(f"检测到类别: {list(class_names.values())}")

    # 1. 读取影像（使用rioxarray）
    logger.info("正在读取遥感影像...")
    img = rxr.open_rasterio(IMAGE_PATH, masked=True)
    logger.info(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")
    
    # 获取影像的空间参考信息
    transform = img.rio.transform()
    crs = img.rio.crs
    logger.info(f"影像坐标系: {crs}")
    logger.info(f"影像变换参数: {transform}")

    # 计算像元面积
    pixel_area_m2 = calculate_pixel_area(transform)
    pixel_area_km2 = pixel_area_m2 / 1e6  # 转换为平方千米
    logger.info(f"单个像元面积: {pixel_area_m2:.2f} 平方米 ({pixel_area_km2:.6f} 平方千米)")

    # 统计背景信息
    if IGNORE_BACKGROUND:
        background_mask = get_background_mask(img)
        n_background = np.sum(background_mask)
        n_total = img.shape[1] * img.shape[2]
        logger.info(f"影像背景像元数: {n_background} / {n_total} ({n_background/n_total*100:.2f}%)")

    # 2. 训练样本栅格化与提取
    logger.info("正在处理训练样本...")
    train_mask = rasterize_samples(TRAIN_SHP, img, CLASS_ATTRIBUTE)
    X_train, y_train = extract_samples(img, train_mask, ignore_background=IGNORE_BACKGROUND)
    logger.info(f"训练样本数: {len(y_train)}")

    # 3. 训练分类器
    clf = get_classifier(CLASSIFIER)
    logger.info(f"使用分类器: {clf.__class__.__name__}")
    logger.info("开始训练模型...")
    clf.fit(X_train, y_train)
    logger.info("模型训练完成。")

    # 4. 精度评估（训练集）
    logger.info("正在评估训练集精度...")
    y_pred_train = clf.predict(X_train)
    
    # 获取实际存在的类别
    actual_train_classes = sorted(np.unique(y_train))
    train_class_names = [class_names.get(c, f'未知类别_{c}') for c in actual_train_classes if c > 0]
    
    # 全方位精度评价
    overall_acc, kappa, eval_df = comprehensive_evaluation(
        y_train, y_pred_train, train_class_names, OUT_DIR / "train_evaluation.txt"
    )
    logger.info(f"训练集总体精度: {overall_acc:.4f}, Kappa: {kappa:.4f}")
    
    # 绘制训练集混淆矩阵
    plot_confusion_matrix(y_train, y_pred_train, train_class_names, OUT_DIR / "train_cm.png")

    # 5. 特征重要性分析（如果适用）
    if hasattr(clf, 'feature_importances_'):
        logger.info("正在分析特征重要性...")
        feature_names = [f'波段{i+1}' for i in range(X_train.shape[1])]
        plot_feature_importance(clf, feature_names, OUT_DIR / "feature_importance.png")

    # 6. 分块预测整幅影像
    logger.info("开始分块预测整幅影像...")
    classified_path = OUT_DIR / "classified_result.tif"
    predict_by_block(clf, img, classified_path, ignore_background=IGNORE_BACKGROUND)

    # 7. 显示原始分类结果
    logger.info("生成原始分类结果可视化...")
    classified_img = rxr.open_rasterio(classified_path)
    plot_classification_results(img, classified_img, class_names, class_colors, 
                               OUT_DIR / "classification_visualization.png", " (原始)")

    # 8. 原始分类结果面积统计
    logger.info("计算原始分类结果面积统计...")
    original_classified_data = classified_img.values.squeeze()
    original_stats_df, original_total_area = calculate_area_statistics(
        original_classified_data, class_names, class_colors, pixel_area_km2, "_original"
    )
    logger.info(f"原始分类总面积: {original_total_area:.4f} 平方千米")

    # 9. 后处理
    if POSTPROCESSING:
        logger.info("="*60)
        logger.info("开始后处理...")
        logger.info(f"后处理参数: 最小图斑大小={MIN_PATCH_SIZE}, 形态学操作={MORPHOLOGY_OPERATION}, 核大小={MORPHOLOGY_SIZE}")
        
        # 进行后处理
        processed_data = postprocess_classification(
            original_classified_data, 
            min_patch_size=MIN_PATCH_SIZE,
            morphology_op=MORPHOLOGY_OPERATION,
            morphology_size=MORPHOLOGY_SIZE
        )
        
        # 保存后处理结果（使用rioxarray）
        processed_path = OUT_DIR / "classified_result_postprocessed.tif"
        save_classification_result(processed_data, img, processed_path)
        
        # 显示后处理分类结果
        logger.info("生成后处理分类结果可视化...")
        processed_img = rxr.open_rasterio(processed_path)
        plot_classification_results(img, processed_img, class_names, class_colors,
                                   OUT_DIR / "classification_visualization_postprocessed.png", " (后处理)")
        
        # 后处理分类结果面积统计
        logger.info("计算后处理分类结果面积统计...")
        processed_stats_df, processed_total_area = calculate_area_statistics(
            processed_data, class_names, class_colors, pixel_area_km2, "_postprocessed"
        )
        logger.info(f"后处理分类总面积: {processed_total_area:.4f} 平方千米")
        
        # 生成后处理变化报告
        area_change = processed_total_area - original_total_area
        area_change_percent = (area_change / original_total_area) * 100 if original_total_area > 0 else 0
        
        with open(OUT_DIR / "postprocessing_report.txt", "w", encoding="utf-8") as f:
            f.write("后处理变化报告\n")
            f.write("="*50 + "\n")
            f.write(f"后处理参数:\n")
            f.write(f"  最小图斑大小: {MIN_PATCH_SIZE} 像元\n")
            f.write(f"  形态学操作: {MORPHOLOGY_OPERATION}\n")
            f.write(f"  核大小: {MORPHOLOGY_SIZE}\n\n")
            
            f.write(f"面积变化:\n")
            f.write(f"  原始总面积: {original_total_area:.4f} km²\n")
            f.write(f"  后处理总面积: {processed_total_area:.4f} km²\n")
            f.write(f"  面积变化: {area_change:+.4f} km² ({area_change_percent:+.2f}%)\n\n")
            
            f.write("各类别面积变化:\n")
            f.write("-"*50 + "\n")
            for class_id in class_names.keys():
                if class_id in original_stats_df['类别编号'].values and class_id in processed_stats_df['类别编号'].values:
                    orig_area = original_stats_df[original_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    proc_area = processed_stats_df[processed_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    change = proc_area - orig_area
                    change_percent = (change / orig_area) * 100 if orig_area > 0 else 0
                    f.write(f"{class_names[class_id]}: {orig_area:.4f} → {proc_area:.4f} km² ({change:+.4f}, {change_percent:+.2f}%)\n")
                elif class_id in original_stats_df['类别编号'].values:
                    orig_area = original_stats_df[original_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    f.write(f"{class_names[class_id]}: {orig_area:.4f} → 0.0000 km² (完全去除)\n")
                elif class_id in processed_stats_df['类别编号'].values:
                    proc_area = processed_stats_df[processed_stats_df['类别编号'] == class_id]['面积(km²)'].values[0]
                    f.write(f"{class_names[class_id]}: 0.0000 → {proc_area:.4f} km² (新增)\n")

    # 10. 验证阶段
    if os.path.exists(VAL_SHP):
        logger.info("="*60)
        logger.info("开始验证阶段...")
        val_mask = rasterize_samples(VAL_SHP, img, CLASS_ATTRIBUTE)
        
        # 使用原始分类结果进行验证
        with rxr.open_rasterio(classified_path) as pred_img:
            pred_arr = pred_img.values.squeeze()
        
        # 提取验证样本（忽略背景）
        if IGNORE_BACKGROUND:
            background_mask = get_background_mask(img)
            valid_val = (val_mask > 0) & (~background_mask)
            n_excluded = np.sum(val_mask > 0) - np.sum(valid_val)
            if n_excluded > 0:
                logger.info(f"验证集中排除了 {n_excluded} 个背景像元")
        else:
            valid_val = val_mask > 0
        
        Xv = pred_arr[valid_val]
        yv = val_mask[valid_val]
        
        logger.info(f"验证样本数: {len(yv)}")
        
        # 验证集全方位精度评价
        val_classes = sorted(np.unique(yv))
        val_class_names = [class_names.get(c, f'未知类别_{c}') for c in val_classes if c > 0]
        
        val_overall_acc, val_kappa, val_eval_df = comprehensive_evaluation(
            yv, Xv, val_class_names, OUT_DIR / "validation_evaluation.txt"
        )
        logger.info(f"验证集总体精度: {val_overall_acc:.4f}, Kappa: {val_kappa:.4f}")
        
        # 绘制验证集混淆矩阵
        plot_confusion_matrix(yv, Xv, val_class_names, OUT_DIR / "val_cm.png")

        # 生成综合报告
        logger.info("生成综合报告...")
        with open(OUT_DIR / "comprehensive_report.txt", "w", encoding="utf-8") as f:
            f.write("遥感影像分类综合报告\n")
            f.write("="*50 + "\n")
            f.write(f"分类器: {clf.__class__.__name__}\n")
            f.write(f"训练样本数: {len(y_train)}\n")
            f.write(f"验证样本数: {len(yv)}\n")
            f.write(f"类别编号字段: {CLASS_ATTRIBUTE}\n")
            f.write(f"类别名称字段: {NAME_ATTRIBUTE}\n")
            f.write(f"检测到的类别: {list(class_names.values())}\n")
            f.write(f"像元面积: {pixel_area_m2:.2f} 平方米\n")
            f.write(f"背景值处理: {'启用' if IGNORE_BACKGROUND else '禁用'}\n")
            f.write(f"后处理: {'是' if POSTPROCESSING else '否'}\n\n")
            
            f.write("精度评价汇总:\n")
            f.write("-"*30 + "\n")
            f.write(f"训练集总体精度: {overall_acc:.4f}\n")
            f.write(f"训练集Kappa系数: {kappa:.4f}\n")
            f.write(f"验证集总体精度: {val_overall_acc:.4f}\n")
            f.write(f"验证集Kappa系数: {val_kappa:.4f}\n\n")
            
            f.write("各类别验证精度:\n")
            f.write("-"*30 + "\n")
            f.write(val_eval_df.to_string(index=False, float_format='%.4f'))
    else:
        logger.warning(f"验证集文件不存在: {VAL_SHP}")

    # 11. 保存类别信息
    logger.info("保存类别信息...")
    class_info_df = pd.DataFrame({
        '类别编号': list(class_names.keys()),
        '类别名称': list(class_names.values()),
        '显示颜色': [class_colors.get(c, 'black') for c in class_names.keys()]
    })
    class_info_df.to_csv(OUT_DIR / "class_information.csv", index=False, encoding='utf-8-sig')
    
    logger.info("="*60)
    logger.info(f"全部任务完成，用时 {time.time()-t0:.1f} 秒。")
    logger.info(f"所有结果已保存至: {OUT_DIR.absolute()}")

if __name__ == "__main__":
    main()

2025-10-16 08:47:41,251 [INFO] 开始监督分类任务...
2025-10-16 08:47:41,252 [INFO] 背景值处理: 启用
2025-10-16 08:47:41,253 [INFO] 正在读取类别信息...
2025-10-16 08:47:41,263 [INFO] 检测到类别: ['类1', '类2', '类3', '类4', '类5', '类6', '类7', '类8', '类9', '类10', '类11']
2025-10-16 08:47:41,263 [INFO] 正在读取遥感影像...
2025-10-16 08:47:41,273 [INFO] 影像尺寸: (14, 1024, 2098), 波段数: 14
2025-10-16 08:47:41,276 [INFO] 影像坐标系: EPSG:32633
2025-10-16 08:47:41,278 [INFO] 影像变换参数: | 0.20, 0.00, 351916.64|
| 0.00,-0.20, 5997247.36|
| 0.00, 0.00, 1.00|
2025-10-16 08:47:41,279 [INFO] 单个像元面积: 0.04 平方米 (0.000000 平方千米)
2025-10-16 08:47:41,546 [INFO] 影像背景像元数: 614034 / 2148352 (28.58%)
2025-10-16 08:47:41,547 [INFO] 正在处理训练样本...
2025-10-16 08:47:41,613 [INFO] 训练样本数: 15041
2025-10-16 08:47:41,615 [INFO] 使用分类器: RandomForestClassifier
2025-10-16 08:47:41,616 [INFO] 开始训练模型...
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 48 concurrent workers.
[Parallel(n_jobs=-1)]: Done 104 tasks      | elapsed:    0.5s
[Parallel(n_jobs=-1)]: Done 300 out of

### 程序2 增加更多分类方法

In [6]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
通用遥感影像监督分类系统 (多分类器对比版本)
-------------------------------------------------
功能：
1. 自动读取多波段遥感影像；
2. 从矢量样本中提取训练/验证数据；
3. 支持12种分类器对比；
4. 采用分块预测模式；
5. 输出分类结果 GeoTIFF；
6. 自动生成分类报告与混淆矩阵；
7. 显示分类影像和精度评价结果；
8. 分类面积统计（平方千米）；
9. 后处理功能（去除小图斑、形态学操作）；
10. 忽略背景值（所有波段值为0的像元）；
11. 多分类器性能对比分析。
"""

import os
import time
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier, 
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology
import warnings
# Silence every Python warning for the whole session (this also hides
# scikit-learn convergence / deprecation notices).
warnings.filterwarnings('ignore')

# Matplotlib fonts: SimHei renders the Chinese labels, DejaVu Sans is the
# fallback; the Unicode minus glyph is disabled so negative numbers render.
plt.rcParams.update({
    "font.sans-serif": ["SimHei", "DejaVu Sans"],
    "axes.unicode_minus": False,
})

# ------------------ Configuration ------------------
IMAGE_PATH = r"D:\code313\Geo_programe\rasterio\RF\data\2017_09_05_stack.tif"
TRAIN_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\cal.shp"
VAL_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\val.shp"
CLASS_ATTRIBUTE = "class"   # sample attribute burned into the label raster
NAME_ATTRIBUTE = "name"     # sample attribute holding the class display name
OUT_DIR = Path("./results_comparison")

# Classifiers to compare (any subset of these codes)
CLASSIFIERS_TO_RUN = [
    "rf",           # Random Forest
    "svm",          # Support Vector Machine
    "xgb",          # XGBoost
    "dt",           # Decision Tree
    "knn",          # K-Nearest Neighbors
    "nb",           # Naive Bayes
    "gb",           # Gradient Boosting
    "ada",          # AdaBoost
    "et",           # Extra Trees
    "lr",           # Logistic Regression
    "mlp",          # Multi-layer Perceptron
    "lgb"           # LightGBM
]

# Classifier parameters
N_ESTIMATORS = 100  # reduced to 100 to speed up the comparison
BLOCK_SIZE = 512    # rows per prediction strip in predict_by_block
RANDOM_STATE = 42

# Post-processing parameters
POSTPROCESSING = False  # can stay off during the comparison to save time
MIN_PATCH_SIZE = 10
MORPHOLOGY_OPERATION = "opening"
MORPHOLOGY_SIZE = 3

# Background handling
BACKGROUND_VALUE = 0
IGNORE_BACKGROUND = True

# Colours for class names containing one of these keywords
LANDUSE_COLORS = {
    "水体": "lightblue", "河流": "blue", "湖泊": "deepskyblue",
    "植被": "forestgreen", "森林": "darkgreen", "草地": "limegreen",
    "农田": "yellowgreen", "耕地": "olivedrab",
    "建筑": "gray", "城市": "dimgray", "居民地": "slategray",
    "裸地": "tan", "沙地": "wheat", "其他": "darkred"
}

# Fallback colours, assigned cyclically to classes without a keyword match
COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow',
                'darkred', 'purple', 'orange', 'pink', 'brown',
                'cyan', 'magenta', 'lime', 'navy', 'teal']

OUT_DIR.mkdir(parents=True, exist_ok=True)

# ------------------ Logging ------------------
# force=True (Python 3.8+): without it basicConfig is a silent no-op whenever
# the root logger already has handlers -- e.g. after another notebook cell
# configured logging -- and this run would log into that other file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(OUT_DIR / "classification_comparison_log.txt", encoding="utf-8"),
        logging.StreamHandler()
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# ------------------ 分类器配置 ------------------
def get_all_classifiers():
    """
    Build the catalogue of candidate classifiers.

    Returns
    -------
    dict
        ``{code: (estimator, display name, description)}``. The ten
        scikit-learn models are always present, inserted in the order
        rf, svm, dt, knn, nb, gb, ada, et, lr, mlp. "xgb" and then "lgb" are
        appended only when xgboost / lightgbm can be imported; otherwise a
        warning is logged and that entry is skipped.
    """
    catalogue = {}

    def register(code, estimator, display_name, description):
        catalogue[code] = (estimator, display_name, description)

    register("rf",
             RandomForestClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1,
                                    random_state=RANDOM_STATE, verbose=0),
             "随机森林", "Random Forest - 集成学习方法，适合高维数据")
    register("svm",
             SVC(kernel="rbf", probability=True,
                 random_state=RANDOM_STATE, verbose=False),
             "支持向量机", "Support Vector Machine - 基于核函数的分类器")
    register("dt",
             DecisionTreeClassifier(random_state=RANDOM_STATE, max_depth=20),
             "决策树", "Decision Tree - 简单直观的树形分类器")
    register("knn",
             KNeighborsClassifier(n_neighbors=5, n_jobs=-1),
             "K近邻", "K-Nearest Neighbors - 基于距离的分类器")
    register("nb",
             GaussianNB(),
             "朴素贝叶斯", "Naive Bayes - 基于概率的快速分类器")
    register("gb",
             GradientBoostingClassifier(n_estimators=N_ESTIMATORS,
                                        random_state=RANDOM_STATE, verbose=0),
             "梯度提升", "Gradient Boosting - 强大的集成学习方法")
    register("ada",
             AdaBoostClassifier(n_estimators=N_ESTIMATORS, random_state=RANDOM_STATE),
             "AdaBoost", "AdaBoost - 自适应提升集成方法")
    register("et",
             ExtraTreesClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1,
                                  random_state=RANDOM_STATE, verbose=0),
             "极端随机树", "Extra Trees - 极端随机化的森林方法")
    register("lr",
             LogisticRegression(max_iter=1000, n_jobs=-1,
                                random_state=RANDOM_STATE, verbose=0),
             "逻辑回归", "Logistic Regression - 经典的线性分类器")
    register("mlp",
             MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=300,
                           random_state=RANDOM_STATE, verbose=False,
                           early_stopping=True),
             "神经网络", "Multi-layer Perceptron - 前馈神经网络")

    # Optional third-party boosters.
    # NOTE(review): recent xgboost releases require class labels 0..n-1,
    # while labels here are the shapefile class ids (likely starting at 1);
    # confirm the caller label-encodes y before fitting "xgb".
    try:
        from xgboost import XGBClassifier
        register("xgb",
                 XGBClassifier(n_estimators=N_ESTIMATORS, learning_rate=0.1,
                               max_depth=8, n_jobs=-1,
                               random_state=RANDOM_STATE, verbosity=0),
                 "XGBoost", "XGBoost - 高性能梯度提升框架")
    except ImportError:
        logger.warning("未安装 xgboost，跳过XGBoost分类器")

    try:
        from lightgbm import LGBMClassifier
        register("lgb",
                 LGBMClassifier(n_estimators=N_ESTIMATORS, learning_rate=0.1,
                                n_jobs=-1, random_state=RANDOM_STATE, verbose=-1),
                 "LightGBM", "LightGBM - 轻量级梯度提升框架")
    except ImportError:
        logger.warning("未安装 lightgbm，跳过LightGBM分类器")

    return catalogue

# ------------------ 辅助函数 ------------------
def get_background_mask(image):
    """
    Return a 2-D boolean mask of background pixels.

    A pixel is background when every band is 0 or NaN. NaN counts as well
    because callers open the image with ``masked=True``, where rioxarray
    replaces the file's nodata value with NaN. Without this, a nodata value
    of 0 (now NaN) would never be recognised as background, and those pixels
    would be classified after ``np.nan_to_num``.

    Parameters
    ----------
    image : object whose ``.values`` is a (band, rows, cols) array.

    Returns
    -------
    np.ndarray of bool, shape (rows, cols).
    """
    data = np.asarray(image.values)
    is_empty = data == 0
    if np.issubdtype(data.dtype, np.floating):
        is_empty |= np.isnan(data)
    return np.all(is_empty, axis=0)

def get_class_info_from_shp(shp_path, class_attr, name_attr):
    """从shp文件中获取类别信息和自动生成的颜色"""
    gdf = gpd.read_file(shp_path)
    
    if name_attr not in gdf.columns:
        logger.warning(f"shp文件中没有找到 '{name_attr}' 字段，将使用类别编号作为名称")
        gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")
    
    class_info = gdf[[class_attr, name_attr]].drop_duplicates()
    class_names = dict(zip(class_info[class_attr], class_info[name_attr]))
    
    class_colors = {}
    for i, (class_id, class_name) in enumerate(class_names.items()):
        color_found = False
        for key, color in LANDUSE_COLORS.items():
            if key in class_name:
                class_colors[class_id] = color
                color_found = True
                break
        if not color_found:
            class_colors[class_id] = COLOR_PALETTE[i % len(COLOR_PALETTE)]
    
    unique_classes = sorted(class_names.keys())
    return class_names, class_colors, unique_classes

def rasterize_samples(shp, ref_img, attr, all_touched=True):
    """
    Burn the polygons of a sample file into a label array aligned with ``ref_img``.

    Parameters
    ----------
    shp : path of the vector sample file.
    ref_img : rioxarray DataArray providing CRS, transform and output shape
        (assumed (band, y, x); ``shape[1:]`` is used as the output shape).
    attr : attribute whose value is burned into each polygon's pixels.
    all_touched : forwarded to ``rasterio.features.rasterize``. True (the
        default and the original behaviour) labels every pixel a polygon
        touches, including mixed boundary pixels.

    Returns
    -------
    uint16 array of shape ``ref_img.shape[1:]``; 0 where no sample falls.
    """
    gdf = gpd.read_file(shp)
    if gdf.crs is None:
        # to_crs() cannot reproject without a source CRS; assume the samples
        # were digitised in the image CRS instead of failing.
        logger.warning(f"样本文件缺少坐标系信息，假定与影像坐标系一致: {shp}")
        gdf = gdf.set_crs(ref_img.rio.crs)
    else:
        gdf = gdf.to_crs(ref_img.rio.crs)

    # Rows without geometry, with an empty geometry or without a label cannot
    # be burned (a missing label cannot be stored in uint16).
    usable = gdf.geometry.notna() & ~gdf.geometry.is_empty & gdf[attr].notna()
    gdf = gdf[usable]

    out_shape = ref_img.shape[1:]
    if gdf.empty:
        logger.warning(f"样本文件中没有可栅格化的要素: {shp}")
        return np.zeros(out_shape, dtype="uint16")

    shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[attr]))
    return features.rasterize(
        shapes=shapes,
        out_shape=out_shape,
        transform=ref_img.rio.transform(),
        fill=0,
        all_touched=all_touched,
        dtype="uint16"
    )

def extract_samples(image, mask, ignore_background=True):
    """
    Extract feature vectors and labels for the labelled pixels of ``mask``.

    Parameters
    ----------
    image : object whose ``.values`` is a (band, rows, cols) array.
    mask : (rows, cols) label array; pixels > 0 are samples.
    ignore_background : drop sample pixels flagged by ``get_background_mask``.

    Returns
    -------
    X : (n_samples, n_bands) features, NaN replaced via ``np.nan_to_num``.
        This matches what ``predict_by_block`` feeds the model at prediction
        time; many estimators (e.g. SVC) reject NaN in ``fit``.
    y : (n_samples,) labels taken from ``mask``.
    """
    data = np.moveaxis(image.values, 0, -1)
    valid = mask > 0

    if ignore_background:
        background_mask = get_background_mask(image)
        valid = valid & (~background_mask)
        n_background = np.sum(mask > 0) - np.sum(valid)
        if n_background > 0:
            logger.debug(f"排除了 {n_background} 个背景像元")

    X = np.nan_to_num(data[valid])
    y = mask[valid]
    return X, y

def calculate_pixel_area(transform):
    """
    Area of one pixel, in the squared linear unit of the CRS.

    The result is square metres for a metric projected CRS such as UTM.
    NOTE: for a geographic CRS (degrees) the result is in square degrees;
    callers should not label it m² in that case.

    Uses ``|a*e - b*d|``, the determinant of the affine transform's linear
    part. For the usual north-up grid (b = d = 0) this equals
    ``|a| * |e|``; unlike the width*height product it stays correct for
    rotated or sheared grids.

    Parameters
    ----------
    transform : affine transform indexable as (a, b, c, d, e, f, ...),
        e.g. ``affine.Affine`` or a 6/9-tuple.
    """
    a, b, d, e = transform[0], transform[1], transform[3], transform[4]
    return abs(a * e - b * d)

def predict_by_block(model, image, out_path, block_size=BLOCK_SIZE, ignore_background=True):
    """
    Classify the whole image strip by strip and write the label raster.

    The image is processed in full-width horizontal strips of ``block_size``
    rows. Each strip is reshaped to (n_pixels, n_bands), NaNs are replaced
    via ``np.nan_to_num``, and the pixels are passed to ``model.predict``.
    With ``ignore_background`` the pixels flagged by ``get_background_mask``
    are not predicted and keep label 0.

    The result is written to ``out_path`` as an LZW-compressed, tiled uint16
    GeoTIFF carrying the image's CRS and transform, with nodata set to
    ``BACKGROUND_VALUE``. Returns ``out_path``.
    """
    n_rows, n_cols = image.shape[1], image.shape[2]
    label_map = np.zeros((n_rows, n_cols), dtype='uint16')
    bg_mask = get_background_mask(image) if ignore_background else None

    for top in tqdm(range(0, n_rows, block_size), desc="分块预测", leave=False):
        bottom = top + min(block_size, n_rows - top)
        strip = np.moveaxis(image.isel(y=slice(top, bottom)).values, 0, -1)
        strip_rows, strip_cols = strip.shape[0], strip.shape[1]
        pixels = strip.reshape(-1, strip.shape[-1])

        if not ignore_background:
            strip_labels = (
                model.predict(np.nan_to_num(pixels))
                .reshape(strip_rows, strip_cols)
                .astype("uint16")
            )
        else:
            keep = ~bg_mask[top:bottom, :].flatten()
            flat_labels = np.zeros(len(pixels), dtype='uint16')
            if np.any(keep):
                flat_labels[keep] = model.predict(np.nan_to_num(pixels[keep]))
            strip_labels = flat_labels.reshape(strip_rows, strip_cols)

        label_map[top:bottom, :] = strip_labels

    # Wrap as a georeferenced DataArray sharing the image's pixel grid
    label_da = xr.DataArray(
        label_map,
        dims=['y', 'x'],
        coords={'y': image.coords['y'], 'x': image.coords['x']}
    )
    label_da.rio.write_crs(image.rio.crs, inplace=True)
    label_da.rio.write_transform(image.rio.transform(), inplace=True)
    label_da.rio.write_nodata(BACKGROUND_VALUE, inplace=True)

    label_da.rio.to_raster(out_path, driver='GTiff', dtype='uint16', compress='lzw', tiled=True)
    return out_path

def calculate_metrics(y_true, y_pred):
    """
    Compute the summary accuracy metrics used in the classifier comparison.

    Returns a dict with, in order: overall_accuracy, kappa, then
    precision/recall/f1 with macro averaging, then the same three with
    weighted averaging. Classes without predictions score 0 instead of
    raising a warning (``zero_division=0``).
    """
    scores = {
        'overall_accuracy': accuracy_score(y_true, y_pred),
        'kappa': cohen_kappa_score(y_true, y_pred),
    }
    per_class_scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    for average in ('macro', 'weighted'):
        for prefix, scorer in per_class_scorers:
            scores[f'{prefix}_{average}'] = scorer(y_true, y_pred, average=average, zero_division=0)
    return scores

def plot_confusion_matrix(y_true, y_pred, class_names, save_path):
    """绘制混淆矩阵"""
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(10, 8))
    
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
    
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': '样本数量'})
    
    plt.xlabel('预测类别', fontsize=12)
    plt.ylabel('真实类别', fontsize=12)
    plt.title('混淆矩阵', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_classification_result(original_img, classified_img, class_names, class_colors, 
                               save_path, title_suffix=""):
    """显示分类结果"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    
    # 原始影像
    if original_img.shape[0] >= 3:
        rgb_data = np.moveaxis(original_img.values[:3], 0, -1)
        p2, p98 = np.percentile(rgb_data[rgb_data > 0], (2, 98))
        rgb_display = np.clip((rgb_data - p2) / (p98 - p2), 0, 1)
        ax1.imshow(rgb_display)
    else:
        ax1.imshow(original_img.values[0], cmap='gray')
    
    ax1.set_title('原始遥感影像', fontsize=14, fontweight='bold')
    ax1.axis('off')
    
    # 分类结果
    classified_data = classified_img.values.squeeze()
    classes = np.unique(classified_data)
    classes = classes[classes > 0]
    
    colors = [class_colors.get(c, 'black') for c in classes]
    labels = [class_names.get(c, f'未知类别_{c}') for c in classes]
    
    cmap = mcolors.ListedColormap(colors)
    bounds = np.append(classes, classes[-1] + 1) - 0.5
    norm = mcolors.BoundaryNorm(bounds, cmap.N)
    
    display_data = classified_data.astype(float)
    display_data[classified_data == 0] = np.nan
    
    ax2.imshow(display_data, cmap=cmap, norm=norm)
    ax2.set_title(f'分类结果{title_suffix}', fontsize=14, fontweight='bold')
    ax2.axis('off')
    
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor=color, label=label) 
                      for color, label in zip(colors, labels)]
    ax2.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.35, 1))
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

# ------------------ 对比分析函数 ------------------
def plot_classifier_comparison(comparison_df, save_dir):
    """绘制分类器性能对比图表"""
    
    # 1. 总体精度和Kappa系数对比
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # 总体精度对比
    colors = plt.cm.viridis(np.linspace(0, 1, len(comparison_df)))
    bars1 = ax1.barh(comparison_df['分类器名称'], comparison_df['训练集精度'], 
                     color=colors, alpha=0.7, label='训练集')
    bars2 = ax1.barh(comparison_df['分类器名称'], comparison_df['验证集精度'], 
                     color=colors, alpha=0.4, label='验证集')
    
    ax1.set_xlabel('总体精度', fontsize=12)
    ax1.set_title('各分类器总体精度对比', fontsize=14, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3, axis='x')
    ax1.set_xlim([0, 1])
    
    # 添加数值标签
    for i, (train_acc, val_acc) in enumerate(zip(comparison_df['训练集精度'], 
                                                   comparison_df['验证集精度'])):
        ax1.text(train_acc + 0.01, i, f'{train_acc:.4f}', va='center', fontsize=9)
        ax1.text(val_acc + 0.01, i, f'{val_acc:.4f}', va='center', fontsize=9)
    
    # Kappa系数对比
    bars3 = ax2.barh(comparison_df['分类器名称'], comparison_df['训练集Kappa'], 
                     color=colors, alpha=0.7, label='训练集')
    bars4 = ax2.barh(comparison_df['分类器名称'], comparison_df['验证集Kappa'], 
                     color=colors, alpha=0.4, label='验证集')
    
    ax2.set_xlabel('Kappa系数', fontsize=12)
    ax2.set_title('各分类器Kappa系数对比', fontsize=14, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='x')
    ax2.set_xlim([0, 1])
    
    for i, (train_kappa, val_kappa) in enumerate(zip(comparison_df['训练集Kappa'], 
                                                       comparison_df['验证集Kappa'])):
        ax2.text(train_kappa + 0.01, i, f'{train_kappa:.4f}', va='center', fontsize=9)
        ax2.text(val_kappa + 0.01, i, f'{val_kappa:.4f}', va='center', fontsize=9)
    
    plt.tight_layout()
    plt.savefig(save_dir / "comparison_accuracy_kappa.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 2. F1分数对比
    plt.figure(figsize=(12, 6))
    x = np.arange(len(comparison_df))
    width = 0.35
    
    plt.bar(x - width/2, comparison_df['训练集F1'], width, label='训练集F1', alpha=0.8)
    plt.bar(x + width/2, comparison_df['验证集F1'], width, label='验证集F1', alpha=0.8)
    
    plt.xlabel('分类器')
    plt.ylabel('F1分数')
    plt.title('各分类器F1分数对比', fontsize=14, fontweight='bold')
    plt.xticks(x, comparison_df['分类器名称'], rotation=45, ha='right')
    plt.legend()
    plt.grid(True, alpha=0.3, axis='y')
    plt.ylim([0, 1])
    plt.tight_layout()
    plt.savefig(save_dir / "comparison_f1.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 3. 训练时间对比
    plt.figure(figsize=(12, 6))
    colors_time = plt.cm.plasma(np.linspace(0, 1, len(comparison_df)))
    bars = plt.bar(comparison_df['分类器名称'], comparison_df['训练时间(秒)'], 
                   color=colors_time, alpha=0.7)
    
    plt.xlabel('分类器')
    plt.ylabel('训练时间 (秒)')
    plt.title('各分类器训练时间对比', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3, axis='y')
    
    # 添加数值标签
    for i, (bar, time_val) in enumerate(zip(bars, comparison_df['训练时间(秒)'])):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                f'{time_val:.2f}s', ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    plt.savefig(save_dir / "comparison_training_time.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 4. 预测时间对比
    plt.figure(figsize=(12, 6))
    bars = plt.bar(comparison_df['分类器名称'], comparison_df['预测时间(秒)'], 
                   color=colors_time, alpha=0.7)
    
    plt.xlabel('分类器')
    plt.ylabel('预测时间 (秒)')
    plt.title('各分类器预测时间对比', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3, axis='y')
    
    for i, (bar, time_val) in enumerate(zip(bars, comparison_df['预测时间(秒)'])):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                f'{time_val:.2f}s', ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    plt.savefig(save_dir / "comparison_prediction_time.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # 5. 综合性能雷达图
    plot_radar_chart(comparison_df, save_dir)
    
    # 6. 散点图：精度 vs 速度
    plot_accuracy_speed_scatter(comparison_df, save_dir)

def plot_radar_chart(comparison_df, save_dir):
    """Draw a radar chart for the most accurate classifiers and save it.

    Up to five classifiers with the largest '验证集精度' are drawn. Each
    polygon has five axes:

    * validation accuracy, Kappa and macro F1, used as-is;
    * training and prediction speed scores, computed as ``1 - t / max(t)``.

    The max is taken over *all* rows of ``comparison_df``, so the slowest
    classifier scores 0 and larger is better.

    Parameters
    ----------
    comparison_df : pandas.DataFrame
        One row per classifier with the columns '分类器名称', '验证集精度',
        '验证集Kappa', '验证集F1', '训练时间(秒)' and '预测时间(秒)'.
    save_dir : pathlib.Path
        Directory that receives ``comparison_radar.png``.
    """
    from math import pi

    # Up to five classifiers, ranked by validation accuracy.
    top_n = min(5, len(comparison_df))
    top_df = comparison_df.nlargest(top_n, '验证集精度')

    categories = ['验证集精度', '验证集Kappa', '验证集F1',
                  '训练速度*', '预测速度*']
    n_axes = len(categories)

    # The normalisers are the same for every plotted row, so compute them once.
    max_train_time = comparison_df['训练时间(秒)'].max()
    max_pred_time = comparison_df['预测时间(秒)'].max()

    def speed_score(seconds, longest):
        """Return 1 - seconds / longest (faster -> larger); 0.0 if longest is not > 0."""
        if not longest > 0:  # also true for NaN; avoids a 0/0 division
            return 0.0
        return 1 - seconds / longest

    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

    # One angle per axis, plus the first angle again to close each polygon.
    angles = [k / float(n_axes) * 2 * pi for k in range(n_axes)]
    angles += angles[:1]

    ax.set_theta_offset(pi / 2)   # first axis at the top
    ax.set_theta_direction(-1)    # axes proceed clockwise
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories)

    colors_radar = plt.cm.Set2(np.linspace(0, 1, top_n))

    for idx, (_, row) in enumerate(top_df.iterrows()):
        values = [
            row['验证集精度'],
            row['验证集Kappa'],
            row['验证集F1'],
            speed_score(row['训练时间(秒)'], max_train_time),
            speed_score(row['预测时间(秒)'], max_pred_time),
        ]
        values += values[:1]

        ax.plot(angles, values, 'o-', linewidth=2, label=row['分类器名称'],
                color=colors_radar[idx])
        ax.fill(angles, values, alpha=0.15, color=colors_radar[idx])

    ax.set_ylim(0, 1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    ax.set_title('分类器综合性能雷达图\n(*速度指标已归一化，越大越好)',
                 fontsize=14, fontweight='bold', pad=20)
    ax.grid(True)

    plt.tight_layout()
    plt.savefig(save_dir / "comparison_radar.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

def plot_accuracy_speed_scatter(comparison_df, save_dir):
    """Save side-by-side scatter plots of validation accuracy versus run time.

    The left panel plots '验证集精度' against '训练时间(秒)'. The right panel
    plots it against '预测时间(秒)'. Each classifier keeps the same colour in
    both panels and is annotated with its '分类器名称'. The figure is written
    to ``save_dir / "comparison_accuracy_speed.png"``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    palette = plt.cm.rainbow(np.linspace(0, 1, len(comparison_df)))

    # (axis, time column, x-axis label, panel title)
    panels = (
        (axes[0], '训练时间(秒)', '训练时间 (秒)', '精度 vs 训练时间'),
        (axes[1], '预测时间(秒)', '预测时间 (秒)', '精度 vs 预测时间'),
    )

    for ax, time_col, x_label, title in panels:
        for colour, (_, record) in zip(palette, comparison_df.iterrows()):
            point = (record[time_col], record['验证集精度'])
            ax.scatter(*point, s=200, c=[colour], alpha=0.7,
                       edgecolors='black', linewidth=1)
            ax.annotate(record['分类器名称'], point,
                        xytext=(5, 5), textcoords='offset points', fontsize=9)

        ax.set_xlabel(x_label, fontsize=12)
        ax.set_ylabel('验证集精度', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_dir / "comparison_accuracy_speed.png", dpi=300, bbox_inches='tight')
    plt.close()

def generate_comparison_report(comparison_df, save_path):
    """Write a plain-text report comparing the classifiers' performance.

    Parameters
    ----------
    comparison_df : pandas.DataFrame
        One row per classifier. The columns read here are '分类器名称',
        '验证集精度', '验证集Kappa', '验证集F1', '训练时间(秒)' and
        '预测时间(秒)'.
    save_path : str or pathlib.Path
        Output text file, written as UTF-8 (overwritten if it exists).

    Notes
    -----
    * The caller's DataFrame is not modified. The composite score column
      '综合得分' is added to a private copy only. It still appears in the
      detailed table of the report.
    * A "best" entry whose metric is NaN for every classifier is reported as
      unavailable instead of raising. This happens, for example, when no
      validation set was provided.
    * An empty DataFrame produces a short report stating there are no results.
    """
    df = comparison_df.copy()
    major_rule = "=" * 80 + "\n"
    minor_rule = "-" * 80 + "\n"

    def best_row(column, largest=True):
        """Return the row with the largest/smallest non-NaN *column* value, or None."""
        values = df[column].dropna()
        if values.empty:
            return None
        return df.loc[values.idxmax() if largest else values.idxmin()]

    def speed_score(column):
        """Map run times to 1 - t / max(t); all zeros if max(t) is not > 0 (or NaN)."""
        longest = df[column].max()
        if not longest > 0:
            return pd.Series(0.0, index=df.index)
        return 1 - df[column] / longest

    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(major_rule)
        f.write("          遥感影像分类器性能对比分析报告\n")
        f.write(major_rule + "\n")

        f.write(f"分类器数量: {len(df)}\n")
        f.write(f"测试日期: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        if df.empty:
            # Nothing to rank (e.g. every classifier failed); the column
            # lookups below would raise KeyError on a column-less frame.
            f.write("没有可用的分类器结果。\n")
            return

        f.write(minor_rule)
        f.write("一、验证集精度排名 (Total Accuracy)\n")
        f.write(minor_rule)
        ranked = df.sort_values('验证集精度', ascending=False)
        for rank, (_, row) in enumerate(ranked.iterrows(), 1):
            f.write(f"{rank}. {row['分类器名称']:12s} - "
                    f"精度: {row['验证集精度']:.4f}, "
                    f"Kappa: {row['验证集Kappa']:.4f}\n")

        f.write("\n" + minor_rule)
        f.write("二、Kappa系数排名\n")
        f.write(minor_rule)
        ranked = df.sort_values('验证集Kappa', ascending=False)
        for rank, (_, row) in enumerate(ranked.iterrows(), 1):
            f.write(f"{rank}. {row['分类器名称']:12s} - "
                    f"Kappa: {row['验证集Kappa']:.4f}\n")

        f.write("\n" + minor_rule)
        f.write("三、F1分数排名\n")
        f.write(minor_rule)
        ranked = df.sort_values('验证集F1', ascending=False)
        for rank, (_, row) in enumerate(ranked.iterrows(), 1):
            f.write(f"{rank}. {row['分类器名称']:12s} - "
                    f"F1: {row['验证集F1']:.4f}\n")

        f.write("\n" + minor_rule)
        f.write("四、训练速度排名 (从快到慢)\n")
        f.write(minor_rule)
        ranked = df.sort_values('训练时间(秒)')
        for rank, (_, row) in enumerate(ranked.iterrows(), 1):
            f.write(f"{rank}. {row['分类器名称']:12s} - "
                    f"训练时间: {row['训练时间(秒)']:.2f} 秒\n")

        f.write("\n" + minor_rule)
        f.write("五、预测速度排名 (从快到慢)\n")
        f.write(minor_rule)
        ranked = df.sort_values('预测时间(秒)')
        for rank, (_, row) in enumerate(ranked.iterrows(), 1):
            f.write(f"{rank}. {row['分类器名称']:12s} - "
                    f"预测时间: {row['预测时间(秒)']:.2f} 秒\n")

        f.write("\n" + minor_rule)
        f.write("六、综合评价\n")
        f.write(minor_rule)

        # Best accuracy (unavailable when no validation metrics exist).
        best_acc_row = best_row('验证集精度')
        if best_acc_row is None:
            f.write("\n【最佳精度】 无可用的验证集精度\n")
        else:
            f.write(f"\n【最佳精度】 {best_acc_row['分类器名称']}\n")
            f.write(f"  验证集精度: {best_acc_row['验证集精度']:.4f}\n")
            f.write(f"  Kappa系数: {best_acc_row['验证集Kappa']:.4f}\n")
            f.write(f"  训练时间: {best_acc_row['训练时间(秒)']:.2f} 秒\n")

        # Fastest training.
        best_speed_row = best_row('训练时间(秒)', largest=False)
        if best_speed_row is None:
            f.write("\n【最快训练】 无可用的训练时间\n")
        else:
            f.write(f"\n【最快训练】 {best_speed_row['分类器名称']}\n")
            f.write(f"  训练时间: {best_speed_row['训练时间(秒)']:.2f} 秒\n")
            f.write(f"  验证集精度: {best_speed_row['验证集精度']:.4f}\n")

        # Composite score balancing accuracy (60 %) and both speeds (20 % each).
        df['综合得分'] = (
            df['验证集精度'] * 0.6
            + speed_score('训练时间(秒)') * 0.2
            + speed_score('预测时间(秒)') * 0.2
        )
        best_overall_row = best_row('综合得分')
        if best_overall_row is None:
            f.write("\n【综合最佳】 无法计算（缺少验证集精度）\n")
        else:
            f.write(f"\n【综合最佳】 {best_overall_row['分类器名称']}\n")
            f.write(f"  综合得分: {best_overall_row['综合得分']:.4f}\n")
            f.write(f"  验证集精度: {best_overall_row['验证集精度']:.4f}\n")
            f.write(f"  训练时间: {best_overall_row['训练时间(秒)']:.2f} 秒\n")
            f.write(f"  预测时间: {best_overall_row['预测时间(秒)']:.2f} 秒\n")

        f.write("\n" + minor_rule)
        f.write("七、详细对比表格\n")
        f.write(minor_rule + "\n")
        f.write(df.to_string(index=False))

        f.write("\n\n" + major_rule)
        f.write("注：综合得分 = 验证集精度×0.6 + 训练速度得分×0.2 + 预测速度得分×0.2\n")
        f.write(major_rule)

# ------------------ 主流程 ------------------
def main():
    """Run the multi-classifier comparison experiment end to end.

    Only classifier codes that appear in both ``CLASSIFIERS_TO_RUN`` and
    ``get_all_classifiers()`` are run. For each such classifier:

    * train it on samples rasterised from ``TRAIN_SHP``;
    * evaluate it on the training samples, and on the validation samples when
      ``VAL_SHP`` exists;
    * classify the whole image block by block;
    * write its confusion matrices and result figure under
      ``OUT_DIR / <code>``.

    Afterwards the per-classifier metrics are saved to
    ``classifier_comparison.csv``, the comparison charts and text report are
    generated, and the top three classifiers by validation accuracy are
    logged.

    A classifier that raises is logged with its traceback and skipped. If no
    classifier succeeds, the comparison outputs are not generated.
    """
    t0 = time.time()
    logger.info("="*80)
    logger.info("开始多分类器对比实验...")
    logger.info("="*80)

    # 0. Collect every available classifier, then keep the configured subset.
    all_classifiers = get_all_classifiers()
    classifiers_to_run = {k: v for k, v in all_classifiers.items()
                          if k in CLASSIFIERS_TO_RUN}

    if not classifiers_to_run:
        logger.error("没有可用的分类器！请检查配置。")
        return

    logger.info(f"将对比 {len(classifiers_to_run)} 个分类器:")
    for code, (_, name, desc) in classifiers_to_run.items():
        logger.info(f"  - {name} ({code}): {desc}")

    # 1. Class id -> name / colour mappings from the training shapefile.
    logger.info("\n正在读取类别信息...")
    class_names, class_colors, train_classes = get_class_info_from_shp(
        TRAIN_SHP, CLASS_ATTRIBUTE, NAME_ATTRIBUTE
    )
    logger.info(f"检测到类别: {list(class_names.values())}")

    # 2. Read the image.
    logger.info("\n正在读取遥感影像...")
    img = rxr.open_rasterio(IMAGE_PATH, masked=True)
    logger.info(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")

    # NOTE(review): crs / pixel_area_km2 are not used further in main.
    transform = img.rio.transform()
    crs = img.rio.crs
    pixel_area_m2 = calculate_pixel_area(transform)
    pixel_area_km2 = pixel_area_m2 / 1e6

    # 3. Training samples.
    logger.info("\n正在处理训练样本...")
    train_mask = rasterize_samples(TRAIN_SHP, img, CLASS_ATTRIBUTE)
    X_train, y_train = extract_samples(img, train_mask, ignore_background=IGNORE_BACKGROUND)
    logger.info(f"训练样本数: {len(y_train)}")

    # 4. Validation samples (optional). valid_val / yv_true exist only when
    #    val_exists is True and are only used under that condition.
    val_exists = os.path.exists(VAL_SHP)
    if val_exists:
        logger.info("正在处理验证样本...")
        val_mask = rasterize_samples(VAL_SHP, img, CLASS_ATTRIBUTE)

        if IGNORE_BACKGROUND:
            background_mask = get_background_mask(img)
            valid_val = (val_mask > 0) & (~background_mask)
        else:
            valid_val = val_mask > 0

        yv_true = val_mask[valid_val]
        logger.info(f"验证样本数: {len(yv_true)}")
    else:
        logger.warning(f"验证集文件不存在: {VAL_SHP}")

    # 5. Train, evaluate and predict with each classifier.
    comparison_results = []

    for clf_code, (clf, clf_name, clf_desc) in classifiers_to_run.items():
        logger.info("\n" + "="*80)
        logger.info(f"正在测试分类器: {clf_name} ({clf_code})")
        logger.info("="*80)

        # One output sub-directory per classifier.
        clf_dir = OUT_DIR / clf_code
        clf_dir.mkdir(exist_ok=True)

        try:
            # Training
            logger.info("开始训练...")
            train_start = time.time()
            clf.fit(X_train, y_train)
            train_time = time.time() - train_start
            logger.info(f"训练完成，耗时: {train_time:.2f} 秒")

            # Training-set accuracy
            logger.info("评估训练集...")
            y_train_pred = clf.predict(X_train)
            train_metrics = calculate_metrics(y_train, y_train_pred)
            logger.info(f"训练集精度: {train_metrics['overall_accuracy']:.4f}, "
                        f"Kappa: {train_metrics['kappa']:.4f}")

            train_class_names = [class_names.get(c, f'Class_{c}')
                                 for c in sorted(np.unique(y_train))]
            plot_confusion_matrix(y_train, y_train_pred, train_class_names,
                                  clf_dir / "train_confusion_matrix.png")

            # Whole-image prediction (timed separately from training).
            logger.info("开始预测整幅影像...")
            pred_start = time.time()
            classified_path = clf_dir / f"classified_{clf_code}.tif"
            predict_by_block(clf, img, classified_path, ignore_background=IGNORE_BACKGROUND)
            pred_time = time.time() - pred_start
            logger.info(f"预测完成，耗时: {pred_time:.2f} 秒")

            # Validation-set accuracy; NaN placeholders when there is no validation set.
            val_metrics = {'overall_accuracy': np.nan, 'kappa': np.nan, 'f1_macro': np.nan}
            if val_exists:
                logger.info("评估验证集...")
                with rxr.open_rasterio(classified_path) as pred_img:
                    pred_arr = pred_img.values.squeeze()

                yv_pred = pred_arr[valid_val]
                val_metrics = calculate_metrics(yv_true, yv_pred)
                logger.info(f"验证集精度: {val_metrics['overall_accuracy']:.4f}, "
                            f"Kappa: {val_metrics['kappa']:.4f}")

                # NOTE(review): names come from yv_true only; if the model
                # predicts a class absent from the validation samples, confirm
                # plot_confusion_matrix still aligns labels with the matrix.
                val_class_names = [class_names.get(c, f'Class_{c}')
                                   for c in sorted(np.unique(yv_true))]
                plot_confusion_matrix(yv_true, yv_pred, val_class_names,
                                      clf_dir / "val_confusion_matrix.png")

            # Visualise the classified raster. The `with` block closes the
            # file after plotting instead of leaking one handle per classifier.
            logger.info("生成可视化结果...")
            with rxr.open_rasterio(classified_path) as classified_img:
                plot_classification_result(img, classified_img, class_names, class_colors,
                                           clf_dir / f"result_{clf_code}.png",
                                           f" ({clf_name})")

            comparison_results.append({
                '分类器代码': clf_code,
                '分类器名称': clf_name,
                '描述': clf_desc,
                '训练集精度': train_metrics['overall_accuracy'],
                '训练集Kappa': train_metrics['kappa'],
                '训练集F1': train_metrics['f1_macro'],
                '验证集精度': val_metrics['overall_accuracy'],
                '验证集Kappa': val_metrics['kappa'],
                '验证集F1': val_metrics['f1_macro'],
                '训练时间(秒)': train_time,
                '预测时间(秒)': pred_time,
                '总时间(秒)': train_time + pred_time
            })

            logger.info(f"✓ {clf_name} 测试完成")

        except Exception as e:
            # One failing classifier must not abort the whole comparison;
            # logger.exception also records the traceback.
            logger.exception(f"✗ {clf_name} 测试失败: {e}")

    # 6. Comparison outputs.
    logger.info("\n" + "="*80)
    logger.info("生成对比分析报告...")
    logger.info("="*80)

    if not comparison_results:
        # Every classifier failed: an empty DataFrame has none of the expected
        # columns, so the charts and the report would raise KeyError.
        logger.error("所有分类器均测试失败，无法生成对比分析。")
        return

    comparison_df = pd.DataFrame(comparison_results)

    comparison_df.to_csv(OUT_DIR / "classifier_comparison.csv",
                         index=False, encoding='utf-8-sig')
    logger.info("对比表格已保存")

    logger.info("生成对比图表...")
    plot_classifier_comparison(comparison_df, OUT_DIR)

    logger.info("生成文字报告...")
    generate_comparison_report(comparison_df, OUT_DIR / "comparison_report.txt")

    # 7. Summary.
    logger.info("\n" + "="*80)
    logger.info("对比实验完成！")
    logger.info("="*80)
    logger.info(f"\n总耗时: {time.time() - t0:.1f} 秒")
    logger.info(f"结果保存路径: {OUT_DIR.absolute()}\n")

    logger.info("【验证集精度 Top 3】")
    if comparison_df['验证集精度'].notna().any():
        top3 = comparison_df.nlargest(3, '验证集精度')
        for idx, (_, row) in enumerate(top3.iterrows(), 1):
            logger.info(f"  {idx}. {row['分类器名称']:12s} - "
                        f"精度: {row['验证集精度']:.4f}, "
                        f"Kappa: {row['验证集Kappa']:.4f}, "
                        f"训练时间: {row['训练时间(秒)']:.2f}s")
    else:
        logger.info("  无验证集精度（未提供验证集），跳过排名。")

    logger.info("\n所有对比图表和报告已生成！")

if __name__ == "__main__":
    main()


2025-10-16 08:58:59,118 [INFO] 开始多分类器对比实验...
2025-10-16 08:59:01,247 [INFO] 将对比 12 个分类器:
2025-10-16 08:59:01,249 [INFO]   - 随机森林 (rf): Random Forest - 集成学习方法，适合高维数据
2025-10-16 08:59:01,251 [INFO]   - 支持向量机 (svm): Support Vector Machine - 基于核函数的分类器
2025-10-16 08:59:01,252 [INFO]   - 决策树 (dt): Decision Tree - 简单直观的树形分类器
2025-10-16 08:59:01,252 [INFO]   - K近邻 (knn): K-Nearest Neighbors - 基于距离的分类器
2025-10-16 08:59:01,253 [INFO]   - 朴素贝叶斯 (nb): Naive Bayes - 基于概率的快速分类器
2025-10-16 08:59:01,254 [INFO]   - 梯度提升 (gb): Gradient Boosting - 强大的集成学习方法
2025-10-16 08:59:01,255 [INFO]   - AdaBoost (ada): AdaBoost - 自适应提升集成方法
2025-10-16 08:59:01,255 [INFO]   - 极端随机树 (et): Extra Trees - 极端随机化的森林方法
2025-10-16 08:59:01,256 [INFO]   - 逻辑回归 (lr): Logistic Regression - 经典的线性分类器
2025-10-16 08:59:01,256 [INFO]   - 神经网络 (mlp): Multi-layer Perceptron - 前馈神经网络
2025-10-16 08:59:01,257 [INFO]   - XGBoost (xgb): XGBoost - 高性能梯度提升框架
2025-10-16 08:59:01,257 [INFO]   - LightGBM (lgb): LightGBM - 轻量级梯度提升框架
2025-10-16 08

## 主要特性

### 1. **支持的分类器（12种）**
- ✅ Random Forest (随机森林)
- ✅ SVM (支持向量机)
- ✅ XGBoost (极端梯度提升)
- ✅ Decision Tree (决策树)
- ✅ K-Nearest Neighbors (K近邻)
- ✅ Naive Bayes (朴素贝叶斯)
- ✅ Gradient Boosting (梯度提升)
- ✅ AdaBoost (自适应提升)
- ✅ Extra Trees (极端随机树)
- ✅ Logistic Regression (逻辑回归)
- ✅ Multi-layer Perceptron (神经网络)
- ✅ LightGBM (轻量级梯度提升)

### 2. **对比分析功能**

#### 生成的对比图表：
1. **总体精度和Kappa系数对比图** - 横向条形图
2. **F1分数对比图** - 柱状图
3. **训练时间对比图** - 柱状图
4. **预测时间对比图** - 柱状图
5. **综合性能雷达图** - 显示验证集精度排名前5的分类器
6. **精度-速度散点图** - 展示性能权衡

#### 生成的报告：
- **comparison_report.txt** - 详细的文字对比报告
- **classifier_comparison.csv** - CSV格式对比表格（UTF-8 BOM编码，可直接用Excel打开）
- 包含排名、最佳分类器推荐等

### 3. **评价指标**
- 总体精度 (Overall Accuracy)
- Kappa系数
- F1分数（宏平均和加权平均）
- 精确率和召回率
- 训练时间
- 预测时间
- 综合得分

### 4. **使用方法**

```python
# 修改配置选择要对比的分类器
CLASSIFIERS_TO_RUN = [
    "rf",    # 随机森林
    "svm",   # SVM
    "xgb",   # XGBoost
    "dt",    # 决策树
    "knn",   # K近邻
    "nb",    # 朴素贝叶斯
    # ... 可以选择任意组合
]

```

```bash
# 运行程序
python script.py
```

### 5. **输出结构**
```
results_comparison/
├── rf/
│   ├── classified_rf.tif
│   ├── train_confusion_matrix.png
│   ├── val_confusion_matrix.png
│   └── result_rf.png
├── svm/
│   └── ...
├── xgb/
│   └── ...
├── classifier_comparison.csv
├── comparison_report.txt
├── comparison_accuracy_kappa.png
├── comparison_f1.png
├── comparison_training_time.png
├── comparison_prediction_time.png
├── comparison_radar.png
└── comparison_accuracy_speed.png
```

### 6. **性能优化建议**

如果数据量很大，可以：
1. 减少 `N_ESTIMATORS` (如设为50)
2. 只选择部分分类器对比
3. 使用样本采样加快训练
4. 临时关闭后处理 `POSTPROCESSING = False`

### 7. **报告示例内容**

```
一、验证集精度排名
1. 随机森林     - 精度: 0.9245, Kappa: 0.9012
2. XGBoost      - 精度: 0.9198, Kappa: 0.8956
3. Extra Trees  - 精度: 0.9102, Kappa: 0.8845
...

【最佳精度】 随机森林
  验证集精度: 0.9245
  Kappa系数: 0.9012
  训练时间: 45.32 秒

【最快训练】 朴素贝叶斯
  训练时间: 2.15 秒
  验证集精度: 0.8234

【综合最佳】 XGBoost
  综合得分: 0.8956
  验证集精度: 0.9198
  训练时间: 52.18 秒
```

这个版本可以全面对比不同分类器的性能，帮助选择最适合您数据的方法！

## 交互界面


### 基础版本

In [10]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
遥感影像监督分类系统 - GUI版本 (修复版)
支持多分类器对比和可视化
"""

import os
import sys
import time
import threading
import queue
from pathlib import Path
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier, 
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology
import warnings
warnings.filterwarnings('ignore')

plt.rcParams["font.sans-serif"] = ["SimHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

# ==================== 后端处理类 ====================
class ClassificationBackend:
    """Non-GUI processing backend.

    Provides classifiers, sample extraction, metrics and block-wise
    prediction.

    Attributes
    ----------
    BACKGROUND_VALUE : int
        Label written for background pixels; also declared as the output
        raster's nodata value.
    RANDOM_STATE : int
        Seed passed to every classifier that accepts ``random_state``.
    LANDUSE_COLORS : dict[str, str]
        Keyword -> matplotlib colour. A class whose name contains the keyword
        gets that colour.
    COLOR_PALETTE : list[str]
        Fallback colours, cycled for classes that match no keyword.
    """

    def __init__(self):
        self.BACKGROUND_VALUE = 0
        self.RANDOM_STATE = 42

        # Colours keyed by land-use keyword (matched as substrings of class names).
        self.LANDUSE_COLORS = {
            "水体": "lightblue", "河流": "blue", "湖泊": "deepskyblue",
            "植被": "forestgreen", "森林": "darkgreen", "草地": "limegreen",
            "农田": "yellowgreen", "耕地": "olivedrab",
            "建筑": "gray", "城市": "dimgray", "居民地": "slategray",
            "裸地": "tan", "沙地": "wheat", "其他": "darkred"
        }

        self.COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow',
                              'darkred', 'purple', 'orange', 'pink', 'brown']

    def get_all_classifiers(self, n_estimators=100):
        """Return the available classifiers keyed by short code.

        Each value is ``(estimator, display_name, description,
        needs_label_encoding)``. The last flag is True only for XGBoost, whose
        labels must be encoded before fitting. XGBoost and LightGBM are
        included only when their packages can be imported.

        Args:
            n_estimators: Number of estimators for the ensemble models.
        """
        classifiers = {
            "rf": (
                RandomForestClassifier(n_estimators=n_estimators, n_jobs=-1,
                                       random_state=self.RANDOM_STATE, verbose=0),
                "随机森林", "Random Forest - 集成学习方法", False
            ),
            "svm": (
                SVC(kernel="rbf", probability=True, random_state=self.RANDOM_STATE),
                "支持向量机", "Support Vector Machine", False
            ),
            "dt": (
                DecisionTreeClassifier(random_state=self.RANDOM_STATE, max_depth=20),
                "决策树", "Decision Tree", False
            ),
            "knn": (
                KNeighborsClassifier(n_neighbors=5, n_jobs=-1),
                "K近邻", "K-Nearest Neighbors", False
            ),
            "nb": (
                GaussianNB(),
                "朴素贝叶斯", "Naive Bayes", False
            ),
            "gb": (
                GradientBoostingClassifier(n_estimators=n_estimators,
                                           random_state=self.RANDOM_STATE, verbose=0),
                "梯度提升", "Gradient Boosting", False
            ),
            "ada": (
                AdaBoostClassifier(n_estimators=n_estimators, random_state=self.RANDOM_STATE),
                "AdaBoost", "AdaBoost", False
            ),
            "et": (
                ExtraTreesClassifier(n_estimators=n_estimators, n_jobs=-1,
                                     random_state=self.RANDOM_STATE, verbose=0),
                "极端随机树", "Extra Trees", False
            ),
            "lr": (
                LogisticRegression(max_iter=1000, n_jobs=-1,
                                   random_state=self.RANDOM_STATE, verbose=0),
                "逻辑回归", "Logistic Regression", False
            ),
            "mlp": (
                MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=300,
                              random_state=self.RANDOM_STATE, verbose=False, early_stopping=True),
                "神经网络", "Neural Network", False
            ),
        }

        try:
            from xgboost import XGBClassifier
            classifiers["xgb"] = (
                XGBClassifier(n_estimators=n_estimators, learning_rate=0.1, max_depth=8,
                              n_jobs=-1, random_state=self.RANDOM_STATE, verbosity=0),
                "XGBoost", "XGBoost", True  # labels must be encoded before fitting
            )
        except ImportError:
            pass

        try:
            from lightgbm import LGBMClassifier
            classifiers["lgb"] = (
                LGBMClassifier(n_estimators=n_estimators, learning_rate=0.1,
                               n_jobs=-1, random_state=self.RANDOM_STATE, verbose=-1),
                "LightGBM", "LightGBM", False
            )
        except ImportError:
            pass

        return classifiers

    def get_background_mask(self, image):
        """Return a (y, x) boolean mask of background pixels.

        A pixel is background when every band equals 0. For floating-point
        data, a pixel is also background when every band is NaN (e.g. nodata
        read as NaN).

        Args:
            image: Object whose ``.values`` is a (band, y, x) array.
        """
        data = image.values
        is_zero = np.all(data == 0, axis=0)
        if np.issubdtype(data.dtype, np.floating):
            # Without this, all-NaN pixels are passed to the model (as 0 after
            # nan_to_num in predict_by_block) and receive a spurious class.
            return is_zero | np.all(np.isnan(data), axis=0)
        return is_zero

    def get_class_info_from_shp(self, shp_path, class_attr, name_attr):
        """Read class ids, names and display colours from a sample vector file.

        Args:
            shp_path: Path of the vector sample file.
            class_attr: Column holding the class id.
            name_attr: Column holding the class name. If it is absent, names
                ``Class_<id>`` are generated.

        Returns:
            ``(class_names, class_colors, sorted_ids)``:

            * ``class_names`` maps id -> name;
            * ``class_colors`` maps id -> matplotlib colour;
            * ``sorted_ids`` is the sorted list of ids.

            Rows whose class id is missing are ignored.
        """
        gdf = gpd.read_file(shp_path)

        if name_attr not in gdf.columns:
            gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")

        class_info = (gdf[[class_attr, name_attr]]
                      .dropna(subset=[class_attr])
                      .drop_duplicates())
        class_names = dict(zip(class_info[class_attr], class_info[name_attr]))

        class_colors = {}
        for i, (class_id, class_name) in enumerate(class_names.items()):
            # str(): the name column may hold numbers or NaN, and a substring
            # test (`key in name`) on a non-string raises TypeError.
            label = str(class_name)
            class_colors[class_id] = next(
                (color for key, color in self.LANDUSE_COLORS.items() if key in label),
                self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)],
            )

        return class_names, class_colors, sorted(class_names.keys())

    def rasterize_samples(self, shp, ref_img, attr):
        """Burn sample features into a uint16 label raster aligned with *ref_img*.

        Features are reprojected to the image CRS. Pixels not covered by any
        feature are 0. ``all_touched=True`` labels every pixel a geometry
        touches. Features with a missing or empty geometry, or a missing
        *attr* value, are skipped: they cannot be rasterised or cast to
        uint16.

        Raises:
            ValueError: If no feature with a valid geometry and value remains.
        """
        gdf = gpd.read_file(shp)
        keep = gdf.geometry.notna() & ~gdf.geometry.is_empty & gdf[attr].notna()
        gdf = gdf[keep]
        if gdf.empty:
            raise ValueError(f"样本文件中没有有效的几何与属性值: {shp}")

        gdf = gdf.to_crs(ref_img.rio.crs)
        shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[attr]))

        arr = features.rasterize(
            shapes=shapes,
            out_shape=ref_img.shape[1:],
            transform=ref_img.rio.transform(),
            fill=0,
            all_touched=True,
            dtype="uint16"
        )
        return arr

    def extract_samples(self, image, mask, ignore_background=True):
        """Collect (pixel features, label) pairs where *mask* is non-zero.

        Args:
            image: Object whose ``.values`` is a (band, y, x) array.
            mask: (y, x) label array; 0 means "no sample".
            ignore_background: Also drop pixels flagged by get_background_mask.

        Returns:
            ``(X, y, n_nan, n_inf)``:

            * ``X`` is (n_samples, n_bands);
            * ``y`` holds the labels;
            * ``n_nan`` counts rows removed for containing NaN;
            * ``n_inf`` counts rows removed afterwards for containing +/-inf.
        """
        data = np.moveaxis(image.values, 0, -1)
        valid = mask > 0

        if ignore_background:
            background_mask = self.get_background_mask(image)
            valid = valid & (~background_mask)

        X = data[valid]
        y = mask[valid]

        # Drop rows containing NaN (alternative: np.nan_to_num(X, nan=0.0)).
        nan_mask = np.isnan(X).any(axis=1)
        n_nan = np.sum(nan_mask)

        if n_nan > 0:
            X = X[~nan_mask]
            y = y[~nan_mask]

        # Then drop rows containing +/-inf.
        inf_mask = np.isinf(X).any(axis=1)
        n_inf = np.sum(inf_mask)

        if n_inf > 0:
            X = X[~inf_mask]
            y = y[~inf_mask]

        return X, y, n_nan, n_inf

    def calculate_metrics(self, y_true, y_pred):
        """Return accuracy, Cohen's kappa and macro precision/recall/F1."""
        return {
            'overall_accuracy': accuracy_score(y_true, y_pred),
            'kappa': cohen_kappa_score(y_true, y_pred),
            'precision_macro': precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall_macro': recall_score(y_true, y_pred, average='macro', zero_division=0),
            'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
        }

    def predict_by_block(self, model, image, out_path, block_size=512,
                         ignore_background=True, progress_callback=None,
                         label_encoder=None):
        """Classify *image* in horizontal strips and save a uint16 GeoTIFF.

        Args:
            model: Fitted estimator exposing ``predict``.
            image: (band, y, x) DataArray with rioxarray CRS/transform.
            out_path: Destination GeoTIFF (LZW-compressed, tiled).
            block_size: Number of rows predicted per strip.
            ignore_background: Write BACKGROUND_VALUE's 0 for background
                pixels instead of predicting them.
            progress_callback: Optional callable receiving percent done (0-100).
            label_encoder: If given, predictions are mapped back with
                ``inverse_transform``.

        Returns:
            *out_path*.
        """
        height, width = image.shape[1], image.shape[2]
        prediction = np.zeros((height, width), dtype='uint16')

        if ignore_background:
            background_mask = self.get_background_mask(image)

        total_blocks = int(np.ceil(height / block_size))

        for i, y in enumerate(range(0, height, block_size)):
            h = min(block_size, height - y)
            block_data = image.isel(y=slice(y, y+h)).values
            data = np.moveaxis(block_data, 0, -1)
            original_shape = data.shape
            data_flat = data.reshape(-1, data.shape[-1])

            if ignore_background:
                block_bg_mask = background_mask[y:y+h, :].flatten()
                non_bg_indices = ~block_bg_mask

                if np.any(non_bg_indices):
                    # Replace NaN/inf with 0 before predicting.
                    data_to_predict = np.nan_to_num(data_flat[non_bg_indices], nan=0.0,
                                                    posinf=0.0, neginf=0.0)
                    preds_non_bg = model.predict(data_to_predict)

                    if label_encoder is not None:
                        preds_non_bg = label_encoder.inverse_transform(preds_non_bg)

                    preds_flat = np.zeros(len(data_flat), dtype='uint16')
                    preds_flat[non_bg_indices] = preds_non_bg
                    preds = preds_flat.reshape(original_shape[0], original_shape[1])
                else:
                    preds = np.zeros((original_shape[0], original_shape[1]), dtype='uint16')
            else:
                data_flat = np.nan_to_num(data_flat, nan=0.0, posinf=0.0, neginf=0.0)
                preds = model.predict(data_flat)

                if label_encoder is not None:
                    preds = label_encoder.inverse_transform(preds)

                preds = preds.reshape(original_shape[0], original_shape[1]).astype("uint16")

            prediction[y:y+h, :] = preds

            if progress_callback:
                progress_callback((i + 1) / total_blocks * 100)

        # Save the result with the source image's georeferencing.
        prediction_da = xr.DataArray(
            prediction,
            dims=['y', 'x'],
            coords={'y': image.coords['y'], 'x': image.coords['x']}
        )

        prediction_da.rio.write_crs(image.rio.crs, inplace=True)
        prediction_da.rio.write_transform(image.rio.transform(), inplace=True)
        prediction_da.rio.write_nodata(self.BACKGROUND_VALUE, inplace=True)

        prediction_da.rio.to_raster(out_path, driver='GTiff', dtype='uint16',
                                    compress='lzw', tiled=True)
        return out_path

# ==================== GUI主类 ====================
class ClassificationGUI:
    """Main window of the supervised remote-sensing classification tool.

    Collects input paths and parameters, lets the user tick one or more
    classifiers, then trains / predicts / validates each of them on a
    background thread and writes a comparison report.

    Threading model: Tk widgets and Tk variables are only touched on the
    thread running ``mainloop``. The worker thread talks to the UI through two
    queues that ``update_log`` drains every 100 ms on the Tk thread:
    ``log_queue`` (log lines) and ``ui_queue`` (``(callable, args)`` actions).
    """

    def __init__(self, root):
        """Build the window on *root* and start the periodic queue polling."""
        self.root = root
        self.root.title("遥感影像监督分类系统 v2.1")
        self.root.geometry("1400x900")

        # Backend that does the raster/vector work and builds the classifiers.
        self.backend = ClassificationBackend()

        # Input/output paths bound to the entry widgets.
        self.image_path = tk.StringVar()
        self.train_shp_path = tk.StringVar()
        self.val_shp_path = tk.StringVar()
        self.output_dir = tk.StringVar(value=str(Path("./results_gui")))

        # Processing parameters bound to the parameter panel.
        self.class_attr = tk.StringVar(value="class")
        self.name_attr = tk.StringVar(value="name")
        self.n_estimators = tk.IntVar(value=100)
        self.block_size = tk.IntVar(value=512)
        self.ignore_background = tk.BooleanVar(value=True)

        # One checkbox variable per classifier code offered by the backend.
        self.classifier_vars = {
            code: tk.BooleanVar(value=False)
            for code in self.backend.get_all_classifiers()
        }

        # Run state shared with the worker thread.
        self.is_running = False
        self.log_queue = queue.Queue()  # log lines (str), may be put from any thread
        self.ui_queue = queue.Queue()   # (callable, args) to execute on the Tk thread

        self.build_ui()

        # Start the 100 ms polling loop that flushes both queues.
        self.update_log()

    def build_ui(self):
        """Create and lay out all widgets of the main window."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)

        # ===== 1. Input files =====
        file_frame = ttk.LabelFrame(main_frame, text="1. 数据输入", padding="10")
        file_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=5)

        # (label, bound variable, browse callback) for each path row.
        path_rows = [
            ("影像文件:", self.image_path, self.browse_image),
            ("训练样本:", self.train_shp_path, self.browse_train_shp),
            ("验证样本:", self.val_shp_path, self.browse_val_shp),
            ("输出目录:", self.output_dir, self.browse_output),
        ]
        for row, (label, var, command) in enumerate(path_rows):
            ttk.Label(file_frame, text=label).grid(row=row, column=0, sticky=tk.W, pady=2)
            ttk.Entry(file_frame, textvariable=var, width=60).grid(
                row=row, column=1, sticky=(tk.W, tk.E), padx=5
            )
            ttk.Button(file_frame, text="浏览", command=command).grid(
                row=row, column=2, padx=5
            )

        file_frame.columnconfigure(1, weight=1)

        # ===== 2. Parameters =====
        param_frame = ttk.LabelFrame(main_frame, text="2. 参数配置", padding="10")
        param_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N), pady=5, padx=(0, 5))

        # Attribute field names in the sample shapefiles.
        ttk.Label(param_frame, text="类别编号字段:").grid(row=0, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.class_attr, width=15).grid(
            row=0, column=1, sticky=tk.W, padx=5
        )

        ttk.Label(param_frame, text="类别名称字段:").grid(row=1, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.name_attr, width=15).grid(
            row=1, column=1, sticky=tk.W, padx=5
        )

        # Numeric processing parameters.
        ttk.Label(param_frame, text="树模型数量:").grid(row=2, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=10, to=500, textvariable=self.n_estimators,
                    width=13).grid(row=2, column=1, sticky=tk.W, padx=5)

        ttk.Label(param_frame, text="分块大小:").grid(row=3, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=256, to=2048, increment=256,
                    textvariable=self.block_size, width=13).grid(
            row=3, column=1, sticky=tk.W, padx=5
        )

        ttk.Checkbutton(param_frame, text="忽略背景值（所有波段为0）",
                        variable=self.ignore_background).grid(
            row=4, column=0, columnspan=2, sticky=tk.W, pady=5
        )

        # ===== 3. Classifier selection =====
        clf_frame = ttk.LabelFrame(main_frame, text="3. 分类器选择", padding="10")
        clf_frame.grid(row=1, column=1, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        # Shortcut buttons
        btn_frame = ttk.Frame(clf_frame)
        btn_frame.grid(row=0, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(0, 5))
        ttk.Button(btn_frame, text="全选", command=self.select_all_classifiers).pack(
            side=tk.LEFT, padx=2
        )
        ttk.Button(btn_frame, text="全不选", command=self.deselect_all_classifiers).pack(
            side=tk.LEFT, padx=2
        )
        ttk.Button(btn_frame, text="推荐选择", command=self.select_recommended).pack(
            side=tk.LEFT, padx=2
        )

        # One checkbox per classifier, three per row, starting below the buttons.
        all_classifiers = self.backend.get_all_classifiers()
        for idx, (code, (_, name, _, _)) in enumerate(all_classifiers.items()):
            row, col = divmod(idx, 3)
            cb = ttk.Checkbutton(clf_frame, text=f"{name} ({code})",
                                 variable=self.classifier_vars[code])
            cb.grid(row=row + 1, column=col, sticky=tk.W, pady=2, padx=5)

        # ===== 4. Run controls =====
        control_frame = ttk.LabelFrame(main_frame, text="4. 运行控制", padding="10")
        control_frame.grid(row=2, column=0, sticky=(tk.W, tk.E), pady=5, padx=(0, 5))

        self.start_btn = ttk.Button(control_frame, text="开始分类",
                                    command=self.start_classification, width=15)
        self.start_btn.grid(row=0, column=0, padx=5, pady=5)

        self.stop_btn = ttk.Button(control_frame, text="停止",
                                   command=self.stop_classification,
                                   state=tk.DISABLED, width=15)
        self.stop_btn.grid(row=0, column=1, padx=5, pady=5)

        ttk.Button(control_frame, text="打开结果目录",
                   command=self.open_result_dir, width=15).grid(
            row=0, column=2, padx=5, pady=5
        )

        ttk.Button(control_frame, text="查看对比报告",
                   command=self.view_report, width=15).grid(
            row=0, column=3, padx=5, pady=5
        )

        # Progress bar (0-100, overall progress across all selected classifiers)
        ttk.Label(control_frame, text="进度:").grid(row=1, column=0, sticky=tk.W, pady=2)
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(control_frame, variable=self.progress_var,
                                            maximum=100, length=400)
        self.progress_bar.grid(row=1, column=1, columnspan=3, sticky=(tk.W, tk.E),
                               padx=5, pady=2)

        control_frame.columnconfigure(3, weight=1)

        # ===== 5. Log output =====
        log_frame = ttk.LabelFrame(main_frame, text="5. 运行日志", padding="10")
        log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        self.log_text = scrolledtext.ScrolledText(log_frame, wrap=tk.WORD,
                                                  height=20, width=100)
        self.log_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        log_frame.columnconfigure(0, weight=1)
        log_frame.rowconfigure(0, weight=1)

        # Status bar
        self.status_var = tk.StringVar(value="就绪")
        status_bar = ttk.Label(main_frame, textvariable=self.status_var,
                               relief=tk.SUNKEN, anchor=tk.W)
        status_bar.grid(row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(5, 0))

    # ===== File browsing =====
    def browse_image(self):
        """Ask for the input raster and store the chosen path."""
        filename = filedialog.askopenfilename(
            title="选择影像文件",
            filetypes=[("GeoTIFF", "*.tif *.tiff"), ("所有文件", "*.*")]
        )
        if filename:
            self.image_path.set(filename)

    def browse_train_shp(self):
        """Ask for the training-sample shapefile and store the chosen path."""
        filename = filedialog.askopenfilename(
            title="选择训练样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")]
        )
        if filename:
            self.train_shp_path.set(filename)

    def browse_val_shp(self):
        """Ask for the validation-sample shapefile and store the chosen path."""
        filename = filedialog.askopenfilename(
            title="选择验证样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")]
        )
        if filename:
            self.val_shp_path.set(filename)

    def browse_output(self):
        """Ask for the output directory and store the chosen path."""
        dirname = filedialog.askdirectory(title="选择输出目录")
        if dirname:
            self.output_dir.set(dirname)

    # ===== Classifier selection =====
    def select_all_classifiers(self):
        """Tick every classifier checkbox."""
        for var in self.classifier_vars.values():
            var.set(True)

    def deselect_all_classifiers(self):
        """Untick every classifier checkbox."""
        for var in self.classifier_vars.values():
            var.set(False)

    def select_recommended(self):
        """Tick only the recommended classifiers (rf, xgb, svm, et, gb)."""
        recommended = ["rf", "xgb", "svm", "et", "gb"]
        for code, var in self.classifier_vars.items():
            var.set(code in recommended)

    # ===== Cross-thread plumbing =====
    def log(self, message):
        """Queue *message* for the log widget; safe to call from any thread."""
        self.log_queue.put(message)

    def _call_in_ui(self, func, *args):
        """Queue ``func(*args)`` to run on the Tk thread; safe from any thread."""
        self.ui_queue.put((func, args))

    def _drain_ui_queue(self):
        """Run every queued UI action in FIFO order (call on the Tk thread)."""
        while True:
            try:
                func, args = self.ui_queue.get_nowait()
            except queue.Empty:
                return
            func(*args)

    def _set_status(self, text):
        """Thread-safe update of the status bar text."""
        self._call_in_ui(self.status_var.set, text)

    def _set_progress(self, value):
        """Thread-safe update of the progress bar (0-100)."""
        self._call_in_ui(self.progress_var.set, value)

    @staticmethod
    def _overall_progress(index, total, block_progress):
        """Map progress of classifier *index* (0-100) onto the whole run (0-100)."""
        if total <= 0:
            return 0.0
        return (index + block_progress / 100.0) / total * 100.0

    def update_log(self):
        """Flush queued log lines and UI actions, then reschedule in 100 ms."""
        try:
            while True:
                try:
                    message = self.log_queue.get_nowait()
                except queue.Empty:
                    break
                self.log_text.insert(tk.END, message + "\n")
                self.log_text.see(tk.END)
            self._drain_ui_queue()
        finally:
            # Reschedule even if a UI action raised, so polling never stops.
            self.root.after(100, self.update_log)

    # ===== Main actions =====
    def _collect_params(self):
        """Return a plain-dict snapshot of every Tk-bound setting.

        Intended to run on the Tk thread so the worker never reads Tk variables.

        Raises:
            tk.TclError: a numeric spinbox contains non-numeric text.
            ValueError: the tree count or block size is not positive.
        """
        params = {
            "image_path": self.image_path.get(),
            "train_shp": self.train_shp_path.get(),
            "val_shp": self.val_shp_path.get(),
            "output_dir": self.output_dir.get(),
            "class_attr": self.class_attr.get(),
            "name_attr": self.name_attr.get(),
            "n_estimators": self.n_estimators.get(),
            "block_size": self.block_size.get(),
            "ignore_background": self.ignore_background.get(),
        }
        # A zero/negative block size would make block-wise prediction loop forever
        # or fail; a zero tree count is rejected by the ensembles anyway.
        if params["n_estimators"] <= 0 or params["block_size"] <= 0:
            raise ValueError("n_estimators and block_size must be positive")
        return params

    def start_classification(self):
        """Validate the inputs and launch ``run_classification`` on a daemon thread."""
        if not self.image_path.get():
            messagebox.showerror("错误", "请选择影像文件！")
            return

        if not self.train_shp_path.get():
            messagebox.showerror("错误", "请选择训练样本！")
            return

        selected_classifiers = [code for code, var in self.classifier_vars.items()
                                if var.get()]
        if not selected_classifiers:
            messagebox.showerror("错误", "请至少选择一个分类器！")
            return

        # Snapshot the settings here, on the Tk thread.
        try:
            params = self._collect_params()
        except (tk.TclError, ValueError):
            messagebox.showerror("错误", "树模型数量和分块大小必须为正整数！")
            return

        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.is_running = True

        self.log_text.delete("1.0", tk.END)
        self.log("="*80)
        self.log("开始分类任务...")
        self.log("="*80)

        thread = threading.Thread(target=self.run_classification,
                                  args=(selected_classifiers, params),
                                  daemon=True)
        thread.start()

    def stop_classification(self):
        """Ask the worker to stop at its next checkpoint (it cannot interrupt a running fit)."""
        self.is_running = False
        self.log("\n用户请求停止...")
        self.status_var.set("已停止")

    def run_classification(self, selected_classifiers, params=None):
        """Train, predict and evaluate every selected classifier (worker thread).

        Args:
            selected_classifiers: classifier codes (keys of
                ``backend.get_all_classifiers()``), processed in order.
            params: settings snapshot from ``_collect_params``. If omitted, the
                Tk variables are read here directly, as earlier versions did.

        For each code this writes ``<output_dir>/<code>/classified_<code>.tif``.
        If at least one classifier succeeds and the user did not stop the run,
        ``classifier_comparison.csv`` and ``comparison_summary.txt`` are also
        written to ``<output_dir>``. All UI updates are queued to the Tk thread.
        """
        import traceback

        img = None
        try:
            if params is None:
                params = self._collect_params()

            # parents=True: the chosen output folder may be nested and not exist yet.
            out_dir = Path(params["output_dir"])
            out_dir.mkdir(parents=True, exist_ok=True)

            # 1. Read the image
            image_path = params["image_path"]
            self.log(f"\n正在读取影像: {image_path}")
            self._set_status("读取影像...")
            img = rxr.open_rasterio(image_path, masked=True)
            self.log(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")

            if not self.is_running:
                return

            # 2. Class ids / names from the training shapefile
            self.log("\n正在读取类别信息...")
            class_names, _colors, _ = self.backend.get_class_info_from_shp(
                params["train_shp"], params["class_attr"], params["name_attr"]
            )
            self.log(f"检测到类别: {list(class_names.values())}")

            # 3. Training samples
            self.log("\n正在处理训练样本...")
            self._set_status("处理训练样本...")
            train_mask = self.backend.rasterize_samples(
                params["train_shp"], img, params["class_attr"]
            )
            X_train, y_train, n_nan, n_inf = self.backend.extract_samples(
                img, train_mask, ignore_background=params["ignore_background"]
            )
            self.log(f"训练样本数: {len(y_train)}")
            if n_nan > 0:
                self.log(f"  已移除 {n_nan} 个包含NaN的样本")
            if n_inf > 0:
                self.log(f"  已移除 {n_inf} 个包含Inf的样本")

            if not self.is_running:
                return

            # 4. Optional validation samples
            val_exists = os.path.exists(params["val_shp"])
            valid_val = yv_true = None
            if val_exists:
                self.log("\n正在处理验证样本...")
                val_mask = self.backend.rasterize_samples(
                    params["val_shp"], img, params["class_attr"]
                )
                valid_val = val_mask > 0
                if params["ignore_background"]:
                    valid_val &= ~self.backend.get_background_mask(img)

                yv_true = val_mask[valid_val]
                self.log(f"验证样本数: {len(yv_true)}")

            # 5. Train / predict / evaluate each selected classifier
            all_classifiers = self.backend.get_all_classifiers(params["n_estimators"])
            comparison_results = []
            n_selected = len(selected_classifiers)

            for i, clf_code in enumerate(selected_classifiers):
                if not self.is_running:
                    break

                clf, clf_name, clf_desc, needs_encoding = all_classifiers[clf_code]

                self.log(f"\n{'='*80}")
                self.log(f"[{i+1}/{n_selected}] 正在测试: {clf_name} ({clf_code})")
                self.log(f"{'='*80}")
                self._set_status(f"训练 {clf_name}...")

                clf_dir = out_dir / clf_code
                clf_dir.mkdir(exist_ok=True)

                try:
                    # Classifiers flagged by the backend get labels encoded to 0..k-1.
                    label_encoder = None
                    y_train_use = y_train
                    if needs_encoding:
                        self.log("  使用标签编码...")
                        label_encoder = LabelEncoder()
                        y_train_use = label_encoder.fit_transform(y_train)

                    self.log("开始训练...")
                    train_start = time.perf_counter()
                    clf.fit(X_train, y_train_use)
                    train_time = time.perf_counter() - train_start
                    self.log(f"训练完成，耗时: {train_time:.2f} 秒")

                    # Accuracy on the training pixels themselves, in original label space.
                    y_train_pred = clf.predict(X_train)
                    if label_encoder is not None:
                        y_train_pred = label_encoder.inverse_transform(y_train_pred)

                    train_metrics = self.backend.calculate_metrics(y_train, y_train_pred)
                    self.log(f"训练集精度: {train_metrics['overall_accuracy']:.4f}, "
                             f"Kappa: {train_metrics['kappa']:.4f}")

                    if not self.is_running:
                        break

                    # Whole-image prediction
                    self.log("开始预测整幅影像...")
                    self._set_status(f"预测 {clf_name}...")

                    pred_start = time.perf_counter()
                    classified_path = clf_dir / f"classified_{clf_code}.tif"

                    def update_progress(progress, _index=i):
                        # Block progress of this classifier -> overall run progress.
                        self._set_progress(
                            self._overall_progress(_index, n_selected, progress)
                        )

                    self.backend.predict_by_block(
                        clf, img, classified_path,
                        block_size=params["block_size"],
                        ignore_background=params["ignore_background"],
                        progress_callback=update_progress,
                        label_encoder=label_encoder
                    )

                    pred_time = time.perf_counter() - pred_start
                    self.log(f"预测完成，耗时: {pred_time:.2f} 秒")

                    # Validation accuracy (NaN when no validation file is given)
                    val_metrics = {'overall_accuracy': np.nan, 'kappa': np.nan, 'f1_macro': np.nan}
                    if val_exists:
                        self.log("评估验证集...")
                        # Close the GeoTIFF right away; an open handle can block
                        # overwriting this file on a later run.
                        with rxr.open_rasterio(classified_path) as pred_img:
                            pred_arr = pred_img.values.squeeze()

                        yv_pred = pred_arr[valid_val]
                        val_metrics = self.backend.calculate_metrics(yv_true, yv_pred)
                        self.log(f"验证集精度: {val_metrics['overall_accuracy']:.4f}, "
                                 f"Kappa: {val_metrics['kappa']:.4f}")

                    comparison_results.append({
                        '分类器代码': clf_code,
                        '分类器名称': clf_name,
                        '训练集精度': train_metrics['overall_accuracy'],
                        '训练集Kappa': train_metrics['kappa'],
                        '训练集F1': train_metrics['f1_macro'],
                        '验证集精度': val_metrics['overall_accuracy'],
                        '验证集Kappa': val_metrics['kappa'],
                        '验证集F1': val_metrics['f1_macro'],
                        '训练时间(秒)': train_time,
                        '预测时间(秒)': pred_time,
                    })

                    self.log(f"✓ {clf_name} 完成")

                except Exception as e:
                    # One failing classifier must not abort the others.
                    self.log(f"✗ {clf_name} 失败: {str(e)}")
                    self.log(traceback.format_exc())

                # Overall progress advances whether this classifier succeeded or failed.
                self._set_progress(self._overall_progress(i + 1, n_selected, 0))

            # 6. Comparison report
            if comparison_results and self.is_running:
                self.log(f"\n{'='*80}")
                self.log("生成对比报告...")
                self._set_status("生成报告...")

                self._write_comparison_report(out_dir, comparison_results, n_selected)

                self.log("\n✓ 所有任务完成！")
                self.log(f"结果保存至: {out_dir.absolute()}")
                self.log(f"成功: {len(comparison_results)}/{n_selected} 个分类器")
                self._set_status("完成")

                self._call_in_ui(
                    messagebox.showinfo, "完成",
                    f"分类任务完成！\n成功: {len(comparison_results)}/{n_selected}\n结果保存至: {out_dir}"
                )

        except Exception as e:
            self.log(f"\n错误: {str(e)}")
            self.log(traceback.format_exc())
            self._call_in_ui(messagebox.showerror, "错误", f"发生错误:\n{str(e)}")
            self._set_status("错误")

        finally:
            # Reset the run flag BEFORE the Start button is re-enabled, so this
            # worker can never clear the flag of a run the user starts next.
            stopped_by_user = not self.is_running
            self.is_running = False
            if img is not None:
                try:
                    img.close()
                except Exception as close_err:  # best effort; never block the UI reset
                    self.log(f"关闭影像失败: {close_err}")
            self._call_in_ui(self._on_run_finished, stopped_by_user)

    def _on_run_finished(self, stopped_by_user=False):
        """Restore the run controls after the worker exits (Tk thread)."""
        self.start_btn.config(state=tk.NORMAL)
        self.stop_btn.config(state=tk.DISABLED)
        self.progress_var.set(0)
        if stopped_by_user:
            # The worker may have overwritten the status after Stop was pressed.
            self.status_var.set("已停止")

    def _write_comparison_report(self, out_dir, comparison_results, n_selected):
        """Write classifier_comparison.csv and comparison_summary.txt to *out_dir*."""
        comparison_df = pd.DataFrame(comparison_results)
        # utf-8-sig adds a BOM so spreadsheet tools detect the Chinese headers.
        comparison_df.to_csv(out_dir / "classifier_comparison.csv",
                             index=False, encoding='utf-8-sig')

        with open(out_dir / "comparison_summary.txt", 'w', encoding='utf-8') as f:
            f.write("分类器性能对比摘要\n")
            f.write("="*60 + "\n\n")

            f.write(f"成功完成: {len(comparison_results)}/{n_selected} 个分类器\n\n")

            sorted_df = comparison_df.sort_values('验证集精度', ascending=False)
            f.write("验证集精度排名:\n")
            f.write("-"*60 + "\n")
            for idx, (_, row) in enumerate(sorted_df.iterrows(), 1):
                f.write(f"{idx}. {row['分类器名称']:12s} - "
                        f"精度: {row['验证集精度']:.4f}, "
                        f"Kappa: {row['验证集Kappa']:.4f}, "
                        f"F1: {row['验证集F1']:.4f}\n")

            f.write("\n" + "-"*60 + "\n")
            f.write("训练时间排名:\n")
            f.write("-"*60 + "\n")
            sorted_time = comparison_df.sort_values('训练时间(秒)')
            for idx, (_, row) in enumerate(sorted_time.iterrows(), 1):
                f.write(f"{idx}. {row['分类器名称']:12s} - "
                        f"{row['训练时间(秒)']:.2f} 秒\n")

    def open_result_dir(self):
        """Open the output directory in the platform's file manager."""
        out_dir = Path(self.output_dir.get())
        if not out_dir.exists():
            messagebox.showwarning("警告", "结果目录不存在！")
            return

        import subprocess
        import platform

        system = platform.system()
        if system == "Windows":
            os.startfile(out_dir)
        elif system == "Darwin":  # macOS
            subprocess.Popen(["open", str(out_dir)])
        else:  # Linux / other desktops
            subprocess.Popen(["xdg-open", str(out_dir)])

    def view_report(self):
        """Show comparison_summary.txt from the output directory in a read-only window."""
        report_file = Path(self.output_dir.get()) / "comparison_summary.txt"
        if not report_file.exists():
            messagebox.showwarning("警告", "报告文件不存在！请先运行分类。")
            return

        report_window = tk.Toplevel(self.root)
        report_window.title("分类器对比报告")
        report_window.geometry("800x600")

        text_widget = scrolledtext.ScrolledText(report_window, wrap=tk.WORD)
        text_widget.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        with open(report_file, 'r', encoding='utf-8') as f:
            text_widget.insert("1.0", f.read())

        text_widget.config(state=tk.DISABLED)

# ==================== 主程序入口 ====================
def main():
    """Open the Tk root window, attach the classification GUI and run the event loop.

    Returns only after the window has been closed.
    """
    root_window = tk.Tk()
    gui = ClassificationGUI(root_window)  # kept bound while the event loop runs
    root_window.mainloop()


# Run the GUI when this file/cell is executed as the main program.
if __name__ == "__main__":
    main()

2025-10-16 09:45:36,101 [INFO] GDAL signalled an error: err_no=1, msg='Deleting results_gui\\et\\classified_et.tif failed:\nPermission denied'


### 优化版本

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
遥感影像监督分类系统 - GUI版本 (性能优化版)
支持多分类器对比和可视化
优化：数据采样、特征缩放、参数调优、并行处理
"""

import os
import sys
import time
import threading
import queue
from pathlib import Path
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier, 
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology
import warnings
# Silence every warning process-wide (note: this also hides library
# deprecation notices).
warnings.filterwarnings("ignore")

# Prefer SimHei so Chinese labels render, falling back to DejaVu Sans; draw the
# minus sign as an ASCII hyphen (presumably because SimHei lacks U+2212).
plt.rcParams.update({
    "font.sans-serif": ["SimHei", "DejaVu Sans"],
    "axes.unicode_minus": False,
})

# ==================== 后端处理类 ====================
class ClassificationBackend:
    """分类处理后端（性能优化版）"""
    
    def __init__(self):
        self.BACKGROUND_VALUE = 0
        self.RANDOM_STATE = 42
        
        # 预定义颜色
        self.LANDUSE_COLORS = {
            "水体": "lightblue", "河流": "blue", "湖泊": "deepskyblue",
            "植被": "forestgreen", "森林": "darkgreen", "草地": "limegreen",
            "农田": "yellowgreen", "耕地": "olivedrab",
            "建筑": "gray", "城市": "dimgray", "居民地": "slategray",
            "裸地": "tan", "沙地": "wheat", "其他": "darkred"
        }
        
        self.COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow', 
                             'darkred', 'purple', 'orange', 'pink', 'brown']
    
    def get_all_classifiers(self, n_estimators=100, fast_mode=False):
        """
        Build a fresh set of configured classifier instances keyed by short code.

        Args:
            n_estimators: number of trees / boosting rounds for ensemble models.
            fast_mode: if True, cap ensembles at 50 estimators, use the shallower
                shared depth limit (10 instead of 20) and fewer iterations for
                iterative solvers (200 instead of 500).

        Returns:
            dict mapping code -> (estimator, Chinese name, English name,
            needs_label_encoding, needs_feature_scaling). "xgb" and "lgb" are
            only present when xgboost / lightgbm can be imported.
        """
        # Shared limits, tightened in fast mode.
        if fast_mode:
            n_est = min(50, n_estimators)
            max_depth = 10
            max_iter = 200
        else:
            n_est = n_estimators
            max_depth = 20
            max_iter = 500

        classifiers = {
            "rf": (
                RandomForestClassifier(
                    n_estimators=n_est,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2,
                    max_features='sqrt'  # fewer features considered per split
                ),
                "随机森林", "Random Forest", False, False
            ),
            "svm": (
                SVC(
                    kernel="rbf",
                    C=1.0,
                    gamma='scale',
                    cache_size=500,  # larger kernel cache (MB)
                    probability=True,
                    random_state=self.RANDOM_STATE
                ),
                "支持向量机", "SVM", False, True  # needs feature scaling
            ),
            "dt": (
                DecisionTreeClassifier(
                    random_state=self.RANDOM_STATE,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2
                ),
                "决策树", "Decision Tree", False, False
            ),
            "knn": (
                KNeighborsClassifier(
                    n_neighbors=5,
                    n_jobs=-1,
                    algorithm='ball_tree',
                    leaf_size=30
                ),
                "K近邻", "KNN", False, True  # needs feature scaling
            ),
            "nb": (
                GaussianNB(),
                "朴素贝叶斯", "Naive Bayes", False, False
            ),
            "gb": (
                GradientBoostingClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=5,  # shallow trees for speed
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    subsample=0.8  # stochastic gradient boosting
                ),
                "梯度提升", "Gradient Boosting", False, False
            ),
            "ada": (
                # `algorithm` is deliberately not passed: 'SAMME.R' was removed in
                # scikit-learn 1.6 (fit then fails) and the parameter itself is
                # deprecated. The default is 'SAMME.R' on older versions (same as
                # before) and 'SAMME' from 1.6 on.
                AdaBoostClassifier(
                    n_estimators=n_est,
                    learning_rate=1.0,
                    random_state=self.RANDOM_STATE,
                ),
                "AdaBoost", "AdaBoost", False, False
            ),
            "et": (
                ExtraTreesClassifier(
                    n_estimators=n_est,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    max_features='sqrt'
                ),
                "极端随机树", "Extra Trees", False, False
            ),
            "lr": (
                # `multi_class` is deprecated since scikit-learn 1.5; with lbfgs
                # the default already fits a multinomial model for >2 classes.
                LogisticRegression(
                    max_iter=max_iter,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    solver='lbfgs',
                ),
                "逻辑回归", "Logistic Regression", False, True  # needs feature scaling
            ),
            "mlp": (
                MLPClassifier(
                    hidden_layer_sizes=(100, 50),
                    max_iter=max_iter,
                    random_state=self.RANDOM_STATE,
                    verbose=False,
                    early_stopping=True,
                    validation_fraction=0.1,
                    n_iter_no_change=10,  # early-stopping patience
                    learning_rate='adaptive'
                ),
                "神经网络", "MLP", False, True  # needs feature scaling
            ),
        }

        # XGBoost (optional dependency)
        try:
            from xgboost import XGBClassifier
            classifiers["xgb"] = (
                XGBClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=6,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbosity=0,
                    tree_method='hist',  # histogram-based split finding
                    subsample=0.8,
                    colsample_bytree=0.8
                ),
                "XGBoost", "XGBoost", True, False  # needs labels encoded to 0..k-1
            )
        except ImportError:
            pass

        # LightGBM (optional dependency)
        try:
            from lightgbm import LGBMClassifier
            classifiers["lgb"] = (
                LGBMClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=max_depth,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=-1,
                    num_leaves=31,
                    # NOTE(review): LightGBM only applies `subsample` when
                    # `subsample_freq` > 0, which is not set here, so row
                    # subsampling is currently inactive.
                    subsample=0.8,
                    colsample_bytree=0.8
                ),
                "LightGBM", "LightGBM", False, False
            )
        except ImportError:
            pass

        return classifiers
    
    def get_background_mask(self, image):
        """获取背景掩膜"""
        data = image.values
        background_mask = np.all(data == 0, axis=0)
        return background_mask
    
    def get_class_info_from_shp(self, shp_path, class_attr, name_attr):
        """Read class ids, display names and colours from a sample vector file.

        Args:
            shp_path: path of the sample vector file.
            class_attr: column holding the class id.
            name_attr: column holding the class name; when the column is
                absent, names of the form ``Class_<id>`` are generated.

        Returns:
            (class_names, class_colors, sorted_ids): ``class_names`` maps
            id -> name, ``class_colors`` maps id -> colour (a keyword match in
            ``LANDUSE_COLORS`` wins, otherwise the ``COLOR_PALETTE`` is
            cycled), and ``sorted_ids`` is the sorted list of ids.

        Raises:
            KeyError: if ``class_attr`` is not a column of the file.
        """
        gdf = gpd.read_file(shp_path)

        if class_attr not in gdf.columns:
            raise KeyError(f"字段 '{class_attr}' 不存在于 {shp_path}")

        # Features without a class id cannot be used as samples; dropping them
        # keeps NaN out of the returned dictionaries.
        gdf = gdf.loc[gdf[class_attr].notna()].copy()

        if name_attr not in gdf.columns:
            gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")

        class_info = gdf[[class_attr, name_attr]].drop_duplicates()
        class_names = dict(zip(class_info[class_attr], class_info[name_attr]))

        class_colors = {}
        for i, (class_id, class_name) in enumerate(class_names.items()):
            # The name column may be numeric or contain NaN; match on the string
            # form so the substring test never raises TypeError.
            label = str(class_name)
            for key, color in self.LANDUSE_COLORS.items():
                if key in label:
                    class_colors[class_id] = color
                    break
            else:
                class_colors[class_id] = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)]

        return class_names, class_colors, sorted(class_names.keys())
    
    def rasterize_samples(self, shp, ref_img, attr):
        """Burn sample polygons into a class-id raster aligned with ``ref_img``.

        Args:
            shp: path of the sample vector file; it is reprojected to the
                reference raster's CRS.
            ref_img: reference raster (rioxarray DataArray, band first); its
                CRS, affine transform and (y, x) size define the output grid.
            attr: column whose value is burned into each touched pixel.

        Returns:
            uint16 array of shape ``ref_img.shape[1:]``; pixels touched by no
            feature are 0, so a class id of 0 is indistinguishable from
            "no sample".
        """
        gdf = gpd.read_file(shp)
        gdf = gdf.to_crs(ref_img.rio.crs)

        # rasterio rejects None geometries, and a missing class value cannot be
        # burned into an integer raster; skip such features instead of failing.
        keep = gdf.geometry.notna() & ~gdf.geometry.is_empty & gdf[attr].notna()
        gdf = gdf[keep]

        shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[attr]))

        arr = features.rasterize(
            shapes=shapes,
            out_shape=ref_img.shape[1:],
            transform=ref_img.rio.transform(),
            fill=0,
            all_touched=True,  # any pixel touched by a polygon becomes a sample
            dtype="uint16"
        )
        return arr
    
    def extract_samples(self, image, mask, ignore_background=True, max_samples=None):
        """Collect labelled pixels for training/validation and clean them.

        Args:
            image: object with a ``.values`` array shaped (band, y, x).
            mask: 2-D class-id array (y, x); pixels with value > 0 are samples.
            ignore_background: drop pixels flagged by ``get_background_mask``.
            max_samples: optional cap on the number of returned samples.  When
                exceeded, a stratified subsample (same class proportions) is
                drawn; if stratification is impossible, a plain random
                subsample is drawn instead.

        Returns:
            (X, y, n_nan, n_inf, n_sampled): feature matrix (n, bands), labels
            (n,), number of rows containing NaN, number of rows containing
            ±Inf (a row with both counts in both), and the number of rows
            removed by subsampling.
        """
        data = np.moveaxis(image.values, 0, -1)  # (band, y, x) -> (y, x, band)
        valid = mask > 0

        if ignore_background:
            background_mask = self.get_background_mask(image)
            valid = valid & (~background_mask)

        X = data[valid]
        y = mask[valid]

        # Drop rows with NaN or Inf in any band.
        nan_mask = np.isnan(X).any(axis=1)
        inf_mask = np.isinf(X).any(axis=1)
        bad_mask = nan_mask | inf_mask

        n_nan = int(np.sum(nan_mask))
        n_inf = int(np.sum(inf_mask))

        X = X[~bad_mask]
        y = y[~bad_mask]

        # ===== Performance: subsample large training sets =====
        n_sampled = 0
        if max_samples is not None and len(y) > max_samples:
            n_original = len(y)
            sample_idx = None

            if len(np.unique(y)) > 1:
                splitter = StratifiedShuffleSplit(
                    n_splits=1,
                    train_size=max_samples,
                    random_state=self.RANDOM_STATE
                )
                try:
                    sample_idx, _ = next(splitter.split(X, y))
                except ValueError:
                    # StratifiedShuffleSplit refuses classes with a single
                    # member and needs both the kept and the discarded part to
                    # hold at least one sample per class (it fails e.g. when
                    # len(y) exceeds max_samples by fewer than the number of
                    # classes).  Fall back to plain random sampling below.
                    sample_idx = None

            if sample_idx is None:
                # Local RNG: same draws as np.random.seed + np.random.choice,
                # without resetting the process-wide global RNG.
                rng = np.random.RandomState(self.RANDOM_STATE)
                sample_idx = rng.choice(n_original, max_samples, replace=False)

            X = X[sample_idx]
            y = y[sample_idx]
            n_sampled = n_original - len(y)

        return X, y, n_nan, n_inf, n_sampled
    
    def calculate_metrics(self, y_true, y_pred):
        """Compute the accuracy summary used in logs and reports.

        Returns a dict with ``overall_accuracy``, ``kappa`` (Cohen's kappa)
        and the macro-averaged ``precision_macro``, ``recall_macro`` and
        ``f1_macro``.  ``zero_division=0`` scores ill-defined per-class
        precision/recall as 0 instead of emitting a warning.
        """
        macro_kw = {'average': 'macro', 'zero_division': 0}

        metrics = {}
        metrics['overall_accuracy'] = accuracy_score(y_true, y_pred)
        metrics['kappa'] = cohen_kappa_score(y_true, y_pred)
        metrics['precision_macro'] = precision_score(y_true, y_pred, **macro_kw)
        metrics['recall_macro'] = recall_score(y_true, y_pred, **macro_kw)
        metrics['f1_macro'] = f1_score(y_true, y_pred, **macro_kw)
        return metrics
    
    def predict_by_block(self, model, image, out_path, block_size=512, 
                        ignore_background=True, progress_callback=None,
                        label_encoder=None, scaler=None):
        """Classify ``image`` strip by strip and write the result as a GeoTIFF.

        The image is processed in horizontal strips of ``block_size`` rows
        (full width) to bound memory use.  NaN/±Inf feature values are set
        to 0 before prediction.

        Args:
            model: fitted estimator exposing ``predict``.
            image: rioxarray DataArray shaped (band, y, x).
            out_path: destination GeoTIFF path.
            block_size: number of rows per strip.
            ignore_background: when true, pixels flagged by
                ``get_background_mask`` are not predicted and stay 0.
            progress_callback: optional callable receiving the completed
                percentage after each strip.
            label_encoder: optional fitted encoder; predictions are mapped
                back to the original class ids with ``inverse_transform``.
            scaler: optional fitted scaler applied before ``predict``.

        Returns:
            ``out_path``.
        """
        n_rows, n_cols = image.shape[1], image.shape[2]
        result = np.zeros((n_rows, n_cols), dtype='uint16')
        bg_mask = self.get_background_mask(image) if ignore_background else None
        n_strips = int(np.ceil(n_rows / block_size))

        def classify(pixels):
            # Shared pipeline: clean values -> optional scaling -> predict -> decode.
            pixels = np.nan_to_num(pixels, nan=0.0, posinf=0.0, neginf=0.0)
            if scaler is not None:
                pixels = scaler.transform(pixels)
            labels = model.predict(pixels)
            if label_encoder is not None:
                labels = label_encoder.inverse_transform(labels)
            return labels

        for strip_no, row0 in enumerate(range(0, n_rows, block_size)):
            rows = min(block_size, n_rows - row0)
            strip = np.moveaxis(image.isel(y=slice(row0, row0 + rows)).values, 0, -1)
            pixels = strip.reshape(-1, strip.shape[-1])

            if bg_mask is None:
                labels = classify(pixels).reshape(strip.shape[0], strip.shape[1]).astype("uint16")
            else:
                foreground = ~bg_mask[row0:row0 + rows, :].flatten()
                labels = np.zeros(len(pixels), dtype='uint16')
                if np.any(foreground):
                    labels[foreground] = classify(pixels[foreground])
                labels = labels.reshape(strip.shape[0], strip.shape[1])

            result[row0:row0 + rows, :] = labels

            if progress_callback:
                progress_callback((strip_no + 1) / n_strips * 100)

        # Wrap the class raster with the source georeferencing and save it.
        out_da = xr.DataArray(
            result,
            dims=['y', 'x'],
            coords={'y': image.coords['y'], 'x': image.coords['x']}
        )
        out_da.rio.write_crs(image.rio.crs, inplace=True)
        out_da.rio.write_transform(image.rio.transform(), inplace=True)
        out_da.rio.write_nodata(self.BACKGROUND_VALUE, inplace=True)

        out_da.rio.to_raster(out_path, driver='GTiff', dtype='uint16',
                             compress='lzw', tiled=True)
        return out_path

# ==================== GUI主类 ====================
class ClassificationGUI:
    """Tk main window of the supervised image-classification tool.

    Heavy work (reading the image, training, block-wise prediction, report
    writing) runs on a daemon worker thread.  Tk is not thread-safe, so the
    worker never touches widgets directly: log lines go through
    ``log_queue`` and every other UI change (status bar, progress bar,
    message boxes, button states) is posted as a callable to ``ui_queue``.
    Both queues are drained on the Tk main thread by ``update_log``.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("遥感影像监督分类系统 v2.2 (性能优化版)")
        self.root.geometry("1400x900")

        # Processing backend (sample extraction, training, prediction, metrics).
        self.backend = ClassificationBackend()

        # Input / output paths.
        self.image_path = tk.StringVar()
        self.train_shp_path = tk.StringVar()
        self.val_shp_path = tk.StringVar()
        self.output_dir = tk.StringVar(value=str(Path("./results_gui")))

        # Attribute fields and basic parameters.
        self.class_attr = tk.StringVar(value="class")
        self.name_attr = tk.StringVar(value="name")
        self.n_estimators = tk.IntVar(value=100)
        self.block_size = tk.IntVar(value=512)
        self.ignore_background = tk.BooleanVar(value=True)

        # ===== Performance options =====
        self.enable_sampling = tk.BooleanVar(value=True)  # subsample training pixels
        self.max_samples = tk.IntVar(value=50000)  # cap used when sampling is on
        self.fast_mode = tk.BooleanVar(value=False)  # lighter model settings

        # Classifier selection: display names and one checkbox variable per code.
        all_classifiers = self.backend.get_all_classifiers()
        self._classifier_names = {code: entry[1] for code, entry in all_classifiers.items()}
        self.classifier_vars = {code: tk.BooleanVar(value=False) for code in all_classifiers}

        # Run state and cross-thread queues.
        self.is_running = False
        self.log_queue = queue.Queue()  # str messages for the log widget
        self.ui_queue = queue.Queue()  # (func, args, kwargs) to run on the Tk thread

        self.build_ui()

        # Start the periodic queue-draining loop.
        self.update_log()

    def build_ui(self):
        """Create all widgets and lay them out with the grid geometry manager."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)  # the log area absorbs extra height

        # ===== 1. Input files =====
        file_frame = ttk.LabelFrame(main_frame, text="1. 数据输入", padding="10")
        file_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=5)

        file_rows = [
            ("影像文件:", self.image_path, self.browse_image),
            ("训练样本:", self.train_shp_path, self.browse_train_shp),
            ("验证样本:", self.val_shp_path, self.browse_val_shp),
            ("输出目录:", self.output_dir, self.browse_output),
        ]
        for row, (label, variable, command) in enumerate(file_rows):
            ttk.Label(file_frame, text=label).grid(row=row, column=0, sticky=tk.W, pady=2)
            ttk.Entry(file_frame, textvariable=variable, width=60).grid(
                row=row, column=1, sticky=(tk.W, tk.E), padx=5
            )
            ttk.Button(file_frame, text="浏览", command=command).grid(
                row=row, column=2, padx=5
            )

        file_frame.columnconfigure(1, weight=1)

        # ===== 2. Parameters =====
        param_frame = ttk.LabelFrame(main_frame, text="2. 参数配置", padding="10")
        param_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N), pady=5, padx=(0, 5))

        # Attribute fields
        ttk.Label(param_frame, text="类别编号字段:").grid(row=0, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.class_attr, width=15).grid(
            row=0, column=1, sticky=tk.W, padx=5
        )

        ttk.Label(param_frame, text="类别名称字段:").grid(row=1, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.name_attr, width=15).grid(
            row=1, column=1, sticky=tk.W, padx=5
        )

        # Model / processing parameters
        ttk.Label(param_frame, text="树模型数量:").grid(row=2, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=10, to=500, textvariable=self.n_estimators,
                    width=13).grid(row=2, column=1, sticky=tk.W, padx=5)

        ttk.Label(param_frame, text="分块大小:").grid(row=3, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=256, to=2048, increment=256,
                    textvariable=self.block_size, width=13).grid(
            row=3, column=1, sticky=tk.W, padx=5
        )

        # Performance options
        ttk.Separator(param_frame, orient='horizontal').grid(
            row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=5
        )

        ttk.Label(param_frame, text="性能优化:", font=('', 9, 'bold')).grid(
            row=5, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        sample_frame = ttk.Frame(param_frame)
        sample_frame.grid(row=6, column=0, columnspan=2, sticky=(tk.W, tk.E))

        ttk.Checkbutton(sample_frame, text="启用采样",
                        variable=self.enable_sampling,
                        command=self.toggle_sampling).pack(side=tk.LEFT)

        ttk.Label(sample_frame, text="  最大样本数:").pack(side=tk.LEFT, padx=(10, 0))
        self.max_samples_spinbox = ttk.Spinbox(
            sample_frame, from_=10000, to=200000, increment=10000,
            textvariable=self.max_samples, width=10
        )
        self.max_samples_spinbox.pack(side=tk.LEFT, padx=5)

        ttk.Checkbutton(param_frame, text="快速模式（减少模型复杂度）",
                        variable=self.fast_mode).grid(
            row=7, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        ttk.Checkbutton(param_frame, text="忽略背景值（所有波段为0）",
                        variable=self.ignore_background).grid(
            row=8, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        # ===== 3. Classifier selection =====
        clf_frame = ttk.LabelFrame(main_frame, text="3. 分类器选择", padding="10")
        clf_frame.grid(row=1, column=1, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        # Shortcut buttons
        btn_frame = ttk.Frame(clf_frame)
        btn_frame.grid(row=0, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(0, 5))
        for text, command in (("全选", self.select_all_classifiers),
                              ("全不选", self.deselect_all_classifiers),
                              ("推荐选择", self.select_recommended),
                              ("快速分类器", self.select_fast)):
            ttk.Button(btn_frame, text=text, command=command).pack(side=tk.LEFT, padx=2)

        # One checkbox per classifier, three per row starting at grid row 1.
        for idx, (code, name) in enumerate(self._classifier_names.items()):
            cb = ttk.Checkbutton(clf_frame, text=f"{name} ({code})",
                                 variable=self.classifier_vars[code])
            cb.grid(row=1 + idx // 3, column=idx % 3, sticky=tk.W, pady=2, padx=5)

        # ===== 4. Run controls =====
        control_frame = ttk.LabelFrame(main_frame, text="4. 运行控制", padding="10")
        control_frame.grid(row=2, column=0, sticky=(tk.W, tk.E), pady=5, padx=(0, 5))

        self.start_btn = ttk.Button(control_frame, text="开始分类",
                                    command=self.start_classification, width=15)
        self.start_btn.grid(row=0, column=0, padx=5, pady=5)

        self.stop_btn = ttk.Button(control_frame, text="停止",
                                   command=self.stop_classification,
                                   state=tk.DISABLED, width=15)
        self.stop_btn.grid(row=0, column=1, padx=5, pady=5)

        ttk.Button(control_frame, text="打开结果目录",
                   command=self.open_result_dir, width=15).grid(
            row=0, column=2, padx=5, pady=5
        )

        ttk.Button(control_frame, text="查看对比报告",
                   command=self.view_report, width=15).grid(
            row=0, column=3, padx=5, pady=5
        )

        # Progress bar (0-100 over the whole run)
        ttk.Label(control_frame, text="进度:").grid(row=1, column=0, sticky=tk.W, pady=2)
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(control_frame, variable=self.progress_var,
                                            maximum=100, length=400)
        self.progress_bar.grid(row=1, column=1, columnspan=3, sticky=(tk.W, tk.E),
                               padx=5, pady=2)

        control_frame.columnconfigure(3, weight=1)

        # ===== 5. Log output =====
        log_frame = ttk.LabelFrame(main_frame, text="5. 运行日志", padding="10")
        log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        self.log_text = scrolledtext.ScrolledText(log_frame, wrap=tk.WORD,
                                                  height=20, width=100)
        self.log_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        log_frame.columnconfigure(0, weight=1)
        log_frame.rowconfigure(0, weight=1)

        # Status bar
        self.status_var = tk.StringVar(value="就绪")
        status_bar = ttk.Label(main_frame, textvariable=self.status_var,
                               relief=tk.SUNKEN, anchor=tk.W)
        status_bar.grid(row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(5, 0))

    def toggle_sampling(self):
        """Enable the max-samples spinbox only while sampling is switched on."""
        if self.enable_sampling.get():
            self.max_samples_spinbox.config(state=tk.NORMAL)
        else:
            self.max_samples_spinbox.config(state=tk.DISABLED)

    # ===== File dialogs =====
    def browse_image(self):
        """Ask for the input raster and store its path."""
        filename = filedialog.askopenfilename(
            title="选择影像文件",
            filetypes=[("GeoTIFF", "*.tif *.tiff"), ("所有文件", "*.*")]
        )
        if filename:
            self.image_path.set(filename)

    def browse_train_shp(self):
        """Ask for the training-sample shapefile and store its path."""
        filename = filedialog.askopenfilename(
            title="选择训练样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")]
        )
        if filename:
            self.train_shp_path.set(filename)

    def browse_val_shp(self):
        """Ask for the (optional) validation-sample shapefile."""
        filename = filedialog.askopenfilename(
            title="选择验证样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")]
        )
        if filename:
            self.val_shp_path.set(filename)

    def browse_output(self):
        """Ask for the output directory."""
        dirname = filedialog.askdirectory(title="选择输出目录")
        if dirname:
            self.output_dir.set(dirname)

    # ===== Classifier selection shortcuts =====
    def select_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(True)

    def deselect_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(False)

    def select_recommended(self):
        """Select the classifiers recommended for accuracy."""
        recommended = ["rf", "xgb", "svm", "et", "gb", "lgb"]
        for code, var in self.classifier_vars.items():
            var.set(code in recommended)

    def select_fast(self):
        """Select the classifiers recommended for speed."""
        fast = ["rf", "dt", "nb", "et", "xgb", "lgb"]
        for code, var in self.classifier_vars.items():
            var.set(code in fast)

    # ===== Logging and cross-thread UI updates =====
    def log(self, message):
        """Queue a log line; safe to call from any thread."""
        self.log_queue.put(message)

    def _post_ui(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` to run on the Tk main thread."""
        self.ui_queue.put((func, args, kwargs))

    def _set_status(self, text):
        """Thread-safe update of the status bar."""
        self._post_ui(self.status_var.set, text)

    def _set_progress(self, value):
        """Thread-safe update of the progress bar (0-100)."""
        self._post_ui(self.progress_var.set, value)

    def _drain_log_queue(self):
        """Append every pending log message to the log widget."""
        while True:
            try:
                message = self.log_queue.get_nowait()
            except queue.Empty:
                return
            self.log_text.insert(tk.END, message + "\n")
            self.log_text.see(tk.END)

    def _drain_ui_queue(self):
        """Run every pending UI callback posted by the worker thread."""
        while True:
            try:
                func, args, kwargs = self.ui_queue.get_nowait()
            except queue.Empty:
                return
            try:
                func(*args, **kwargs)
            except Exception as exc:  # a failing callback must not stop the polling loop
                self.log_queue.put(f"UI 更新失败: {exc}")

    def update_log(self):
        """Drain the log and UI queues on the Tk thread, then reschedule itself."""
        try:
            self._drain_log_queue()
            self._drain_ui_queue()
        finally:
            self.root.after(100, self.update_log)

    # ===== Main actions =====
    def _read_params(self):
        """Snapshot every Tk variable needed by the worker into a plain dict.

        Must run on the Tk main thread.

        Raises:
            tk.TclError: a numeric field does not contain a number.
            ValueError: a numeric field is not a positive integer.
        """
        enable_sampling = self.enable_sampling.get()
        params = {
            'image_path': self.image_path.get(),
            'train_shp': self.train_shp_path.get(),
            'val_shp': self.val_shp_path.get(),
            'out_dir': Path(self.output_dir.get()),
            'class_attr': self.class_attr.get(),
            'name_attr': self.name_attr.get(),
            'n_estimators': self.n_estimators.get(),
            'block_size': self.block_size.get(),
            'ignore_background': self.ignore_background.get(),
            'enable_sampling': enable_sampling,
            'max_samples': self.max_samples.get() if enable_sampling else None,
            'fast_mode': self.fast_mode.get(),
        }
        for key in ('n_estimators', 'block_size', 'max_samples'):
            if params[key] is not None and params[key] <= 0:
                raise ValueError(f"参数 {key} 必须为正整数")
        return params

    def start_classification(self):
        """Validate the inputs on the Tk thread and launch the worker thread."""
        if not self.image_path.get():
            messagebox.showerror("错误", "请选择影像文件！")
            return

        if not self.train_shp_path.get():
            messagebox.showerror("错误", "请选择训练样本！")
            return

        selected_classifiers = [code for code, var in self.classifier_vars.items()
                                if var.get()]
        if not selected_classifiers:
            messagebox.showerror("错误", "请至少选择一个分类器！")
            return

        # Read all parameters *before* disabling the start button: an invalid
        # spinbox entry used to raise after the button was disabled, leaving
        # the GUI stuck.
        try:
            params = self._read_params()
        except (tk.TclError, ValueError) as exc:
            messagebox.showerror("错误", f"参数无效:\n{exc}")
            return

        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.is_running = True

        self.log_text.delete(1.0, tk.END)
        self.log("="*80)
        self.log("开始分类任务... (性能优化版)")
        self.log("="*80)

        if params['enable_sampling']:
            self.log(f"✓ 启用数据采样：最大 {params['max_samples']} 个样本")
        if params['fast_mode']:
            self.log("✓ 启用快速模式：减少模型复杂度")

        thread = threading.Thread(target=self.run_classification,
                                  args=(selected_classifiers, params),
                                  daemon=True)
        thread.start()

    def stop_classification(self):
        """Ask the worker to stop at its next checkpoint."""
        self.is_running = False
        self.log("\n用户请求停止...")
        self.status_var.set("已停止")

    def _on_run_finished(self, was_stopped=False):
        """Restore the controls after the worker has ended (runs on the Tk thread)."""
        self.start_btn.config(state=tk.NORMAL)
        self.stop_btn.config(state=tk.DISABLED)
        self.progress_var.set(0)
        if was_stopped:
            # Status updates queued before the stop took effect may have
            # overwritten the "stopped" message.
            self.status_var.set("已停止")

    @staticmethod
    def _metric_columns(comparison_df):
        """Choose which metric columns rank the classifiers.

        Validation metrics are used when at least one is available.  Without
        a validation set they are all NaN, so training metrics are used
        instead (NaN-only columns would break ``idxmax``).

        Returns:
            (prefix, accuracy_col, kappa_col, f1_col)
        """
        has_val = ('验证集精度' in comparison_df
                   and comparison_df['验证集精度'].notna().any())
        prefix = '验证集' if has_val else '训练集'
        return prefix, f"{prefix}精度", f"{prefix}Kappa", f"{prefix}F1"

    @staticmethod
    def _write_summary(summary_path, comparison_df, n_selected, n_train,
                       n_val=None, sampling=False, fast_mode=False):
        """Write the plain-text classifier comparison report.

        Args:
            summary_path: destination file (overwritten, UTF-8).
            comparison_df: one row per successfully run classifier with the
                columns built in ``run_classification``.  Not modified.
            n_selected: number of classifiers the user selected.
            n_train: number of training samples used.
            n_val: number of validation pixels, or None without validation.
            sampling, fast_mode: performance flags echoed into the report.

        Returns:
            The row (pandas Series) of the most accurate classifier, ranked on
            validation accuracy when available, otherwise training accuracy.
        """
        prefix, acc_col, kappa_col, f1_col = ClassificationGUI._metric_columns(comparison_df)

        best_acc = comparison_df.loc[comparison_df[acc_col].idxmax()]
        sorted_df = comparison_df.sort_values(acc_col, ascending=False)
        sorted_time = comparison_df.sort_values('总时间(秒)')
        best_speed = sorted_time.iloc[0]

        # Composite score: 70 % accuracy + 30 % relative speed.  Computed as a
        # separate Series so the caller's DataFrame is left untouched.
        total_time = comparison_df['总时间(秒)']
        max_time = total_time.max()
        speed_score = 1 - total_time / max_time if max_time > 0 else 0.0
        overall_score = comparison_df[acc_col] * 0.7 + speed_score * 0.3
        best_overall = comparison_df.loc[overall_score.idxmax()]

        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write("分类器性能对比摘要\n")
            f.write("="*60 + "\n\n")

            f.write(f"成功完成: {len(comparison_df)}/{n_selected} 个分类器\n")
            f.write(f"训练样本数: {n_train}\n")
            if n_val is not None:
                f.write(f"验证样本数: {n_val}\n")
            f.write(f"性能优化: 采样={sampling}, 快速模式={fast_mode}\n\n")

            if prefix != '验证集':
                f.write("注意: 未提供验证样本，以下排名基于训练集精度（可能偏高）\n\n")

            # Accuracy ranking
            f.write(f"{prefix}精度排名:\n")
            f.write("-"*60 + "\n")
            for idx, (_, row) in enumerate(sorted_df.iterrows(), 1):
                f.write(f"{idx}. {row['分类器名称']:12s} - "
                        f"精度: {row[acc_col]:.4f}, "
                        f"Kappa: {row[kappa_col]:.4f}, "
                        f"F1: {row[f1_col]:.4f}\n")

            # Speed ranking
            f.write("\n" + "-"*60 + "\n")
            f.write("总时间排名（训练+预测）:\n")
            f.write("-"*60 + "\n")
            for idx, (_, row) in enumerate(sorted_time.iterrows(), 1):
                f.write(f"{idx}. {row['分类器名称']:12s} - "
                        f"{row['总时间(秒)']:.2f} 秒 "
                        f"(训练: {row['训练时间(秒)']:.2f}s, 预测: {row['预测时间(秒)']:.2f}s)\n")

            # Recommendations
            f.write("\n" + "="*60 + "\n")
            f.write("推荐:\n")
            f.write("-"*60 + "\n")
            f.write(f"最高精度: {best_acc['分类器名称']} ({best_acc[acc_col]:.4f})\n")
            f.write(f"最快速度: {best_speed['分类器名称']} ({best_speed['总时间(秒)']:.2f}秒)\n")
            f.write(f"综合最佳: {best_overall['分类器名称']} "
                    f"(精度: {best_overall[acc_col]:.4f}, "
                    f"时间: {best_overall['总时间(秒)']:.2f}秒)\n")

        return best_acc

    def run_classification(self, selected_classifiers, params=None):
        """Run the whole pipeline for each selected classifier (worker thread).

        Steps: read the image, read class info, extract (optionally sampled)
        training pixels, optionally extract validation pixels, then for each
        classifier train, predict the whole image block-wise and evaluate.
        Finally a CSV and a text comparison report are written.  Widgets are
        never touched directly; UI changes go through ``log`` / ``_post_ui``.

        Args:
            selected_classifiers: classifier codes (keys of
                ``backend.get_all_classifiers()``), run in order.
            params: snapshot returned by ``_read_params``.  When omitted it is
                read here, which is only safe when called on the Tk thread.
        """
        import traceback

        img = None
        try:
            if params is None:
                params = self._read_params()

            out_dir = Path(params['out_dir'])
            out_dir.mkdir(parents=True, exist_ok=True)
            n_selected = len(selected_classifiers)

            # 1. Read the image (masked=True turns nodata into NaN)
            self.log(f"\n正在读取影像: {params['image_path']}")
            self._set_status("读取影像...")
            img = rxr.open_rasterio(params['image_path'], masked=True)
            self.log(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")

            if not self.is_running:
                return

            # 2. Class ids and names
            self.log("\n正在读取类别信息...")
            class_names, _, _ = self.backend.get_class_info_from_shp(
                params['train_shp'],
                params['class_attr'],
                params['name_attr']
            )
            self.log(f"检测到类别: {list(class_names.values())}")

            # 3. Training samples
            self.log("\n正在处理训练样本...")
            self._set_status("处理训练样本...")
            train_mask = self.backend.rasterize_samples(
                params['train_shp'], img, params['class_attr']
            )

            X_train, y_train, n_nan, n_inf, n_sampled = self.backend.extract_samples(
                img, train_mask,
                ignore_background=params['ignore_background'],
                max_samples=params['max_samples']
            )

            self.log(f"训练样本数: {len(y_train)}")
            if n_nan > 0:
                self.log(f"  已移除 {n_nan} 个包含NaN的样本")
            if n_inf > 0:
                self.log(f"  已移除 {n_inf} 个包含Inf的样本")
            if n_sampled > 0:
                self.log(f"  已采样减少 {n_sampled} 个样本（提高速度）")

            if not self.is_running:
                return

            # 4. Validation samples (optional)
            val_path = params['val_shp']
            val_exists = bool(val_path) and os.path.exists(val_path)
            valid_val = None
            yv_true = None
            if val_path and not val_exists:
                self.log(f"\n验证样本文件不存在，跳过验证: {val_path}")
            if val_exists:
                self.log("\n正在处理验证样本...")
                val_mask = self.backend.rasterize_samples(
                    val_path, img, params['class_attr']
                )
                valid_val = val_mask > 0
                if params['ignore_background']:
                    valid_val &= ~self.backend.get_background_mask(img)

                yv_true = val_mask[valid_val]
                self.log(f"验证样本数: {len(yv_true)}")

            # 5. Train, predict and evaluate each selected classifier
            all_classifiers = self.backend.get_all_classifiers(
                params['n_estimators'],
                fast_mode=params['fast_mode']
            )
            comparison_results = []

            for i, clf_code in enumerate(selected_classifiers):
                if not self.is_running:
                    break

                clf, clf_name, _clf_desc, needs_encoding, needs_scaling = all_classifiers[clf_code]

                self.log(f"\n{'='*80}")
                self.log(f"[{i+1}/{n_selected}] 正在测试: {clf_name} ({clf_code})")
                self.log(f"{'='*80}")
                self._set_status(f"训练 {clf_name}...")

                clf_dir = out_dir / clf_code
                clf_dir.mkdir(exist_ok=True)

                try:
                    # Preprocessing, fitted on training data only
                    label_encoder = None
                    scaler = None
                    X_train_use = X_train.copy()
                    y_train_use = y_train.copy()

                    if needs_encoding:
                        self.log("  应用标签编码...")
                        label_encoder = LabelEncoder()
                        y_train_use = label_encoder.fit_transform(y_train)

                    if needs_scaling:
                        self.log("  应用特征缩放...")
                        scaler = StandardScaler()
                        X_train_use = scaler.fit_transform(X_train_use)

                    # Training
                    self.log("开始训练...")
                    train_start = time.time()
                    clf.fit(X_train_use, y_train_use)
                    train_time = time.time() - train_start
                    self.log(f"训练完成，耗时: {train_time:.2f} 秒")

                    # Training-set accuracy, in original class ids
                    y_train_pred = clf.predict(X_train_use)
                    if label_encoder is not None:
                        y_train_pred = label_encoder.inverse_transform(y_train_pred)

                    train_metrics = self.backend.calculate_metrics(y_train, y_train_pred)
                    self.log(f"训练集精度: {train_metrics['overall_accuracy']:.4f}, "
                             f"Kappa: {train_metrics['kappa']:.4f}")

                    if not self.is_running:
                        break

                    # Whole-image prediction
                    self.log("开始预测整幅影像...")
                    self._set_status(f"预测 {clf_name}...")

                    pred_start = time.time()
                    classified_path = clf_dir / f"classified_{clf_code}.tif"

                    def update_progress(progress, _done=i):
                        # Map this classifier's 0-100 block progress onto its
                        # slice of the overall progress bar.
                        self._set_progress((_done + progress / 100.0) / n_selected * 100)

                    self.backend.predict_by_block(
                        clf, img, classified_path,
                        block_size=params['block_size'],
                        ignore_background=params['ignore_background'],
                        progress_callback=update_progress,
                        label_encoder=label_encoder,
                        scaler=scaler
                    )

                    pred_time = time.time() - pred_start
                    self.log(f"预测完成，耗时: {pred_time:.2f} 秒")

                    # Validation accuracy, read back from the written raster
                    val_metrics = {'overall_accuracy': np.nan, 'kappa': np.nan, 'f1_macro': np.nan}
                    if val_exists:
                        self.log("评估验证集...")
                        with rxr.open_rasterio(classified_path) as pred_img:
                            # First (only) band; squeeze() would also drop a
                            # y or x axis of length 1.
                            pred_arr = pred_img.values[0]

                        yv_pred = pred_arr[valid_val]
                        val_metrics = self.backend.calculate_metrics(yv_true, yv_pred)
                        self.log(f"验证集精度: {val_metrics['overall_accuracy']:.4f}, "
                                 f"Kappa: {val_metrics['kappa']:.4f}")

                    comparison_results.append({
                        '分类器代码': clf_code,
                        '分类器名称': clf_name,
                        '训练集精度': train_metrics['overall_accuracy'],
                        '训练集Kappa': train_metrics['kappa'],
                        '训练集F1': train_metrics['f1_macro'],
                        '验证集精度': val_metrics['overall_accuracy'],
                        '验证集Kappa': val_metrics['kappa'],
                        '验证集F1': val_metrics['f1_macro'],
                        '训练时间(秒)': train_time,
                        '预测时间(秒)': pred_time,
                        '总时间(秒)': train_time + pred_time
                    })

                    self.log(f"✓ {clf_name} 完成")

                except Exception as e:
                    # One failing classifier must not abort the comparison.
                    self.log(f"✗ {clf_name} 失败: {str(e)}")
                    self.log(traceback.format_exc())

                self._set_progress((i + 1) / n_selected * 100)

            # 6. Comparison report
            if comparison_results and self.is_running:
                self.log(f"\n{'='*80}")
                self.log("生成对比报告...")
                self._set_status("生成报告...")

                comparison_df = pd.DataFrame(comparison_results)
                comparison_df.to_csv(out_dir / "classifier_comparison.csv",
                                     index=False, encoding='utf-8-sig')

                best_clf = self._write_summary(
                    out_dir / "comparison_summary.txt", comparison_df,
                    n_selected=n_selected,
                    n_train=len(y_train),
                    n_val=len(yv_true) if val_exists else None,
                    sampling=params['enable_sampling'],
                    fast_mode=params['fast_mode']
                )
                _, acc_col, _, _ = self._metric_columns(comparison_df)

                self.log("\n✓ 所有任务完成！")
                self.log(f"结果保存至: {out_dir.absolute()}")
                self.log(f"成功: {len(comparison_results)}/{n_selected} 个分类器")
                self._set_status("完成")

                self.log(f"\n【最佳精度】{best_clf['分类器名称']}: {best_clf[acc_col]:.4f}")

                self._post_ui(
                    messagebox.showinfo, "完成",
                    f"分类任务完成！\n"
                    f"成功: {len(comparison_results)}/{n_selected}\n"
                    f"最佳: {best_clf['分类器名称']} ({best_clf[acc_col]:.4f})\n"
                    f"结果: {out_dir}"
                )

        except Exception as e:
            self.log(f"\n错误: {str(e)}")
            self.log(traceback.format_exc())
            self._post_ui(messagebox.showerror, "错误", f"发生错误:\n{str(e)}")
            self._set_status("错误")

        finally:
            # is_running is only False here if the user pressed "stop".
            was_stopped = not self.is_running
            self.is_running = False
            self._post_ui(self._on_run_finished, was_stopped)
            if img is not None:
                try:
                    img.close()  # release the file handle (Windows keeps it locked)
                except Exception:
                    pass  # best-effort cleanup; nothing useful to report

    def open_result_dir(self):
        """Open the output directory in the platform's file manager."""
        out_dir = Path(self.output_dir.get())
        if out_dir.exists():
            import subprocess
            import platform

            if platform.system() == "Windows":
                os.startfile(out_dir)
            elif platform.system() == "Darwin":  # macOS
                subprocess.Popen(["open", out_dir])
            else:  # Linux and other desktops
                subprocess.Popen(["xdg-open", out_dir])
        else:
            messagebox.showwarning("警告", "结果目录不存在！")

    def view_report(self):
        """Show comparison_summary.txt from the output directory in a read-only window."""
        report_file = Path(self.output_dir.get()) / "comparison_summary.txt"
        if report_file.exists():
            report_window = tk.Toplevel(self.root)
            report_window.title("分类器对比报告")
            report_window.geometry("800x600")

            text_widget = scrolledtext.ScrolledText(report_window, wrap=tk.WORD)
            text_widget.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

            with open(report_file, 'r', encoding='utf-8') as f:
                text_widget.insert(1.0, f.read())

            text_widget.config(state=tk.DISABLED)
        else:
            messagebox.showwarning("警告", "报告文件不存在！请先运行分类。")

# ==================== 主程序入口 ====================
def main():
    """Create the Tk root window, attach the GUI to it and run the event loop."""
    window = tk.Tk()
    gui = ClassificationGUI(window)  # keep a reference while the loop runs
    window.mainloop()


if __name__ == "__main__":
    main()

### SVM 训练较慢：性能优化版

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
遥感影像监督分类系统 - GUI版本 (性能优化版)
支持多分类器对比和可视化
优化：数据采样、特征缩放、参数调优、并行处理
"""

import os
import sys
import time
import threading
import queue
from pathlib import Path
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier, 
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology
import warnings
# Silence every warning category (note: this also hides library deprecation notices).
warnings.simplefilter('ignore')

# Font fallbacks for Chinese labels; keep negative signs displayable with them.
plt.rcParams.update({
    "font.sans-serif": ["SimHei", "DejaVu Sans"],
    "axes.unicode_minus": False,
})

# ==================== 后端处理类 ====================
class ClassificationBackend:
    """Processing back end for the classification GUI (performance-optimized).

    Provides the classifier catalogue, class-name/colour lookup from the
    sample shapefile, sample rasterization and extraction (with optional
    subsampling), accuracy metrics, and strip-wise prediction to GeoTIFF.
    """
    
    def __init__(self):
        # Declared as the nodata value of the output raster
        # (predict_by_block leaves background pixels at 0).
        self.BACKGROUND_VALUE = 0
        # Seed shared by all estimators and by sample subsampling.
        self.RANDOM_STATE = 42
        
        # Preset colours: a class whose name contains one of these keys
        # is drawn with the matching colour.
        self.LANDUSE_COLORS = {
            "水体": "lightblue", "河流": "blue", "湖泊": "deepskyblue",
            "植被": "forestgreen", "森林": "darkgreen", "草地": "limegreen",
            "农田": "yellowgreen", "耕地": "olivedrab",
            "建筑": "gray", "城市": "dimgray", "居民地": "slategray",
            "裸地": "tan", "沙地": "wheat", "其他": "darkred"
        }
        
        # Fallback palette, cycled for classes whose names match no key above.
        self.COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow', 
                             'darkred', 'purple', 'orange', 'pink', 'brown']
    
    def get_all_classifiers(self, n_estimators=100, fast_mode=False):
        """Build a fresh, configured instance of every available classifier.

        Parameters
        ----------
        n_estimators : int
            Ensemble size for the forest/boosting models.
        fast_mode : bool
            If True, cap ensembles at 50 estimators, use a depth limit of 10
            (rf/dt/et/lgb) and at most 200 solver iterations (lr/mlp);
            otherwise depth 20 and 500 iterations.

        Returns
        -------
        dict
            ``code -> (estimator, chinese_name, english_name,
            needs_label_encoding, needs_feature_scaling)``. The ``xgb`` and
            ``lgb`` entries exist only when xgboost / lightgbm import.
        """
        # Scale model complexity to the selected mode
        if fast_mode:
            n_est = min(50, n_estimators)
            max_depth = 10
            max_iter = 200
        else:
            n_est = n_estimators
            max_depth = 20
            max_iter = 500
        
        classifiers = {
            "rf": (
                RandomForestClassifier(
                    n_estimators=n_est, 
                    n_jobs=-1, 
                    random_state=self.RANDOM_STATE, 
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2,
                    max_features='sqrt'  # consider sqrt(n_features) per split
                ),
                "随机森林", "Random Forest", False, False
            ),
            "svm": (
                SVC(
                    kernel="rbf", 
                    C=1.0,
                    gamma='scale',
                    cache_size=500,  # larger kernel cache (MB)
                    # NOTE(review): probability=True makes fitting considerably
                    # slower; drop it if no caller needs predict_proba.
                    probability=True, 
                    random_state=self.RANDOM_STATE
                ),
                "支持向量机", "SVM", False, True  # needs feature scaling
            ),
            "dt": (
                DecisionTreeClassifier(
                    random_state=self.RANDOM_STATE,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2
                ),
                "决策树", "Decision Tree", False, False
            ),
            "knn": (
                KNeighborsClassifier(
                    n_neighbors=5,
                    n_jobs=-1,
                    algorithm='ball_tree',  # tree-based neighbour search
                    leaf_size=30
                ),
                "K近邻", "KNN", False, True  # needs feature scaling
            ),
            "nb": (
                GaussianNB(),
                "朴素贝叶斯", "Naive Bayes", False, False
            ),
            "gb": (
                GradientBoostingClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=5,  # shallow trees for speed
                    random_state=self.RANDOM_STATE, 
                    verbose=0,
                    subsample=0.8  # stochastic gradient boosting
                ),
                "梯度提升", "Gradient Boosting", False, False
            ),
            "ada": (
                # `algorithm` is intentionally not passed: 'SAMME.R' is rejected
                # by scikit-learn >= 1.6 (where the parameter is deprecated).
                # Older releases default to 'SAMME.R' and newer ones to 'SAMME',
                # so every version gets an algorithm it supports.
                AdaBoostClassifier(
                    n_estimators=n_est,
                    learning_rate=1.0,
                    random_state=self.RANDOM_STATE
                ),
                "AdaBoost", "AdaBoost", False, False
            ),
            "et": (
                ExtraTreesClassifier(
                    n_estimators=n_est,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    max_features='sqrt'
                ),
                "极端随机树", "Extra Trees", False, False
            ),
            "lr": (
                # `multi_class` is omitted: it is deprecated since scikit-learn
                # 1.5 and scheduled for removal. With lbfgs the default already
                # fits a multinomial model for multi-class targets.
                LogisticRegression(
                    max_iter=max_iter,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    solver='lbfgs'
                ),
                "逻辑回归", "Logistic Regression", False, True  # needs feature scaling
            ),
            "mlp": (
                MLPClassifier(
                    hidden_layer_sizes=(100, 50),
                    max_iter=max_iter,
                    random_state=self.RANDOM_STATE,
                    verbose=False,
                    early_stopping=True,
                    validation_fraction=0.1,
                    n_iter_no_change=10,  # early-stopping patience
                    learning_rate='adaptive'
                ),
                "神经网络", "MLP", False, True  # needs feature scaling
            ),
        }
        
        # XGBoost (optional dependency)
        try:
            from xgboost import XGBClassifier
            classifiers["xgb"] = (
                XGBClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=6,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbosity=0,
                    tree_method='hist',  # histogram-based tree construction
                    subsample=0.8,
                    colsample_bytree=0.8
                ),
                "XGBoost", "XGBoost", True, False  # needs label encoding
            )
        except ImportError:
            pass
        
        # LightGBM (optional dependency)
        try:
            from lightgbm import LGBMClassifier
            classifiers["lgb"] = (
                LGBMClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=max_depth,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=-1,
                    num_leaves=31,
                    subsample=0.8,
                    # LightGBM ignores `subsample` unless the bagging
                    # frequency is > 0; re-sample rows every iteration.
                    subsample_freq=1,
                    colsample_bytree=0.8
                ),
                "LightGBM", "LightGBM", False, False
            )
        except ImportError:
            pass
        
        return classifiers
    
    def get_background_mask(self, image):
        """Return a boolean (row, col) mask that is True where every band equals 0.

        Pixels containing NaN are NOT background (NaN == 0 is False).
        """
        data = image.values
        background_mask = np.all(data == 0, axis=0)
        return background_mask
    
    def get_class_info_from_shp(self, shp_path, class_attr, name_attr):
        """Read class ids, display names and colours from a sample vector file.

        Parameters
        ----------
        shp_path : path of the training-sample vector file
        class_attr : column holding the class id
        name_attr : column holding the class name; when absent, names are
            generated as ``Class_<id>``

        Returns
        -------
        (class_names, class_colors, sorted_class_ids)
            ``class_names`` maps id -> name (if an id appears with several
            names, the last one read wins); ``class_colors`` maps id -> a
            matplotlib colour name.
        """
        gdf = gpd.read_file(shp_path)
        
        if name_attr not in gdf.columns:
            gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")
        
        class_info = gdf[[class_attr, name_attr]].drop_duplicates()
        class_names = dict(zip(class_info[class_attr], class_info[name_attr]))
        
        class_colors = {}
        for i, (class_id, class_name) in enumerate(class_names.items()):
            # str() guards against non-text names (numeric or missing values),
            # for which the substring test would raise TypeError.
            label = str(class_name)
            matched = next(
                (color for key, color in self.LANDUSE_COLORS.items() if key in label),
                None,
            )
            if matched is None:
                matched = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)]
            class_colors[class_id] = matched
        
        return class_names, class_colors, sorted(class_names.keys())
    
    def rasterize_samples(self, shp, ref_img, attr):
        """Burn sample polygons into a uint16 label grid aligned with ``ref_img``.

        Geometries are reprojected to the image CRS; every pixel touched by a
        geometry gets that feature's ``attr`` value, all others get 0.
        """
        gdf = gpd.read_file(shp)
        gdf = gdf.to_crs(ref_img.rio.crs)
        shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[attr]))
        
        arr = features.rasterize(
            shapes=shapes,
            out_shape=ref_img.shape[1:],
            transform=ref_img.rio.transform(),
            fill=0,
            all_touched=True,
            dtype="uint16"
        )
        return arr
    
    def extract_samples(self, image, mask, ignore_background=True, max_samples=None):
        """Extract labelled pixels as a feature matrix, dropping NaN/Inf rows.

        Parameters
        ----------
        image : raster whose ``.values`` is a (band, row, col) array
        mask : 2-D label array; pixels with a value > 0 are samples
        ignore_background : also drop pixels where every band is 0
        max_samples : if given and more samples remain, subsample down to
            this many rows (stratified by class whenever possible)

        Returns
        -------
        (X, y, n_nan, n_inf, n_sampled)
            ``n_nan`` / ``n_inf`` count rows containing NaN / Inf (a row with
            both is counted in each); ``n_sampled`` is the number of rows
            removed by subsampling.
        """
        data = np.moveaxis(image.values, 0, -1)
        valid = mask > 0
        
        if ignore_background:
            valid = valid & ~self.get_background_mask(image)
        
        X = data[valid]
        y = mask[valid]
        
        # Drop rows that contain NaN or Inf in any band
        nan_mask = np.isnan(X).any(axis=1)
        inf_mask = np.isinf(X).any(axis=1)
        n_nan = np.sum(nan_mask)
        n_inf = np.sum(inf_mask)
        
        keep = ~(nan_mask | inf_mask)
        X = X[keep]
        y = y[keep]
        
        # Optional subsampling to bound training time
        n_sampled = 0
        if max_samples is not None and len(y) > max_samples:
            n_original = len(y)
            sample_idx = self._subsample_indices(X, y, max_samples)
            X = X[sample_idx]
            y = y[sample_idx]
            n_sampled = n_original - len(y)
        
        return X, y, n_nan, n_inf, n_sampled
    
    def _subsample_indices(self, X, y, max_samples):
        """Pick ``max_samples`` row indices, stratified by class when possible."""
        if len(np.unique(y)) > 1:
            splitter = StratifiedShuffleSplit(
                n_splits=1, 
                train_size=max_samples, 
                random_state=self.RANDOM_STATE
            )
            try:
                sample_idx, _ = next(splitter.split(X, y))
                return sample_idx
            except ValueError:
                # StratifiedShuffleSplit rejects classes with a single member
                # and splits whose held-out part is smaller than the number of
                # classes; fall back to plain random sampling in those cases.
                pass
        # Local generator: same draws as seeding the global RNG, without
        # mutating global numpy random state.
        rng = np.random.RandomState(self.RANDOM_STATE)
        return rng.choice(len(y), max_samples, replace=False)
    
    def calculate_metrics(self, y_true, y_pred):
        """Return overall accuracy, Cohen's kappa and macro precision/recall/F1."""
        return {
            'overall_accuracy': accuracy_score(y_true, y_pred),
            'kappa': cohen_kappa_score(y_true, y_pred),
            'precision_macro': precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall_macro': recall_score(y_true, y_pred, average='macro', zero_division=0),
            'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
        }
    
    def predict_by_block(self, model, image, out_path, block_size=512, 
                        ignore_background=True, progress_callback=None,
                        label_encoder=None, scaler=None):
        """Classify ``image`` strip by strip and write a uint16 GeoTIFF.

        The image is processed in horizontal strips of ``block_size`` rows.
        NaN/Inf feature values are replaced by 0 before prediction; ``scaler``
        (if given) transforms features and ``label_encoder`` (if given) maps
        predictions back to original class ids. With ``ignore_background``,
        all-zero pixels are not predicted and stay 0. ``progress_callback``
        receives the completed percentage after every strip. The output is
        LZW-compressed, tiled, carries the image's CRS/transform and declares
        ``BACKGROUND_VALUE`` as nodata. Returns ``out_path``.
        """
        height, width = image.shape[1], image.shape[2]
        prediction = np.zeros((height, width), dtype='uint16')
        
        if ignore_background:
            # Computed once for the full image (reads all band values).
            background_mask = self.get_background_mask(image)
        
        total_blocks = int(np.ceil(height / block_size))
        
        for i, y in enumerate(range(0, height, block_size)):
            h = min(block_size, height - y)
            block_data = image.isel(y=slice(y, y+h)).values
            data = np.moveaxis(block_data, 0, -1)
            original_shape = data.shape
            data_flat = data.reshape(-1, data.shape[-1])
            
            if ignore_background:
                block_bg_mask = background_mask[y:y+h, :].flatten()
                non_bg_indices = ~block_bg_mask
                
                if np.any(non_bg_indices):
                    data_to_predict = np.nan_to_num(data_flat[non_bg_indices], 
                                                   nan=0.0, posinf=0.0, neginf=0.0)
                    
                    if scaler is not None:
                        data_to_predict = scaler.transform(data_to_predict)
                    
                    preds_non_bg = model.predict(data_to_predict)
                    
                    # Map encoded labels back to the original class ids
                    if label_encoder is not None:
                        preds_non_bg = label_encoder.inverse_transform(preds_non_bg)
                    
                    preds_flat = np.zeros(len(data_flat), dtype='uint16')
                    preds_flat[non_bg_indices] = preds_non_bg
                    preds = preds_flat.reshape(original_shape[0], original_shape[1])
                else:
                    # Strip is entirely background
                    preds = np.zeros((original_shape[0], original_shape[1]), dtype='uint16')
            else:
                data_flat = np.nan_to_num(data_flat, nan=0.0, posinf=0.0, neginf=0.0)
                
                if scaler is not None:
                    data_flat = scaler.transform(data_flat)
                
                preds = model.predict(data_flat)
                
                if label_encoder is not None:
                    preds = label_encoder.inverse_transform(preds)
                
                preds = preds.reshape(original_shape[0], original_shape[1]).astype("uint16")
            
            prediction[y:y+h, :] = preds
            
            if progress_callback:
                progress_callback((i + 1) / total_blocks * 100)
        
        # Wrap as a georeferenced DataArray and write to disk
        prediction_da = xr.DataArray(
            prediction,
            dims=['y', 'x'],
            coords={'y': image.coords['y'], 'x': image.coords['x']}
        )
        
        prediction_da.rio.write_crs(image.rio.crs, inplace=True)
        prediction_da.rio.write_transform(image.rio.transform(), inplace=True)
        prediction_da.rio.write_nodata(self.BACKGROUND_VALUE, inplace=True)
        
        prediction_da.rio.to_raster(out_path, driver='GTiff', dtype='uint16', 
                                    compress='lzw', tiled=True)
        return out_path

# ==================== GUI主类 ====================
class ClassificationGUI:
    """遥感影像分类GUI主界面（性能优化版）"""
    
    def __init__(self, root):
        """Initialise window, Tk variables, classifier toggles and log polling.

        Parameters
        ----------
        root : tk.Tk
            Top-level window the interface is built into.
        """
        self.root = root
        root.title("遥感影像监督分类系统 v2.2 (性能优化版)")
        root.geometry("1400x900")

        # Back end doing the raster/vector work and model training
        self.backend = ClassificationBackend()

        # Input / output locations
        self.image_path = tk.StringVar()
        self.train_shp_path = tk.StringVar()
        self.val_shp_path = tk.StringVar()
        self.output_dir = tk.StringVar(value=str(Path("./results_gui")))

        # Attribute fields and general parameters
        self.class_attr = tk.StringVar(value="class")
        self.name_attr = tk.StringVar(value="name")
        self.n_estimators = tk.IntVar(value=100)
        self.block_size = tk.IntVar(value=512)
        self.ignore_background = tk.BooleanVar(value=True)

        # Performance options: sample cap and reduced-complexity models
        self.enable_sampling = tk.BooleanVar(value=True)
        self.max_samples = tk.IntVar(value=50000)
        self.fast_mode = tk.BooleanVar(value=False)

        # One checkbox variable per available classifier code, all unticked
        self.classifier_vars = {
            code: tk.BooleanVar(value=False)
            for code in self.backend.get_all_classifiers()
        }

        # Run state; log_queue carries messages from the worker thread to the UI
        self.is_running = False
        self.log_queue = queue.Queue()

        self.build_ui()

        # Begin the periodic poll that drains log_queue into the log widget
        self.update_log()
    
    def build_ui(self):
        """Build the main window: inputs, parameters, classifier choice, controls, log.

        Lays out five labelled sections plus a status bar on a grid. The log
        section (main_frame row 3) absorbs extra vertical space. Widgets used
        elsewhere are stored on ``self``: max_samples_spinbox, start_btn,
        stop_btn, progress_var, progress_bar, log_text and status_var.
        """
        # Main container frame filling the root window
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)
        
        # ===== 1. Data input section =====
        file_frame = ttk.LabelFrame(main_frame, text="1. 数据输入", padding="10")
        file_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=5)
        
        # Image file
        ttk.Label(file_frame, text="影像文件:").grid(row=0, column=0, sticky=tk.W, pady=2)
        ttk.Entry(file_frame, textvariable=self.image_path, width=60).grid(
            row=0, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_image).grid(
            row=0, column=2, padx=5
        )
        
        # Training samples
        ttk.Label(file_frame, text="训练样本:").grid(row=1, column=0, sticky=tk.W, pady=2)
        ttk.Entry(file_frame, textvariable=self.train_shp_path, width=60).grid(
            row=1, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_train_shp).grid(
            row=1, column=2, padx=5
        )
        
        # Validation samples
        ttk.Label(file_frame, text="验证样本:").grid(row=2, column=0, sticky=tk.W, pady=2)
        ttk.Entry(file_frame, textvariable=self.val_shp_path, width=60).grid(
            row=2, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_val_shp).grid(
            row=2, column=2, padx=5
        )
        
        # Output directory
        ttk.Label(file_frame, text="输出目录:").grid(row=3, column=0, sticky=tk.W, pady=2)
        ttk.Entry(file_frame, textvariable=self.output_dir, width=60).grid(
            row=3, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_output).grid(
            row=3, column=2, padx=5
        )
        
        # Let the path entries stretch with the window
        file_frame.columnconfigure(1, weight=1)
        
        # ===== 2. Parameter section =====
        param_frame = ttk.LabelFrame(main_frame, text="2. 参数配置", padding="10")
        param_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N), pady=5, padx=(0, 5))
        
        # Attribute field names
        ttk.Label(param_frame, text="类别编号字段:").grid(row=0, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.class_attr, width=15).grid(
            row=0, column=1, sticky=tk.W, padx=5
        )
        
        ttk.Label(param_frame, text="类别名称字段:").grid(row=1, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.name_attr, width=15).grid(
            row=1, column=1, sticky=tk.W, padx=5
        )
        
        # Other parameters: ensemble size and prediction strip size
        ttk.Label(param_frame, text="树模型数量:").grid(row=2, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=10, to=500, textvariable=self.n_estimators, 
                   width=13).grid(row=2, column=1, sticky=tk.W, padx=5)
        
        ttk.Label(param_frame, text="分块大小:").grid(row=3, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=256, to=2048, increment=256, 
                   textvariable=self.block_size, width=13).grid(
            row=3, column=1, sticky=tk.W, padx=5
        )
        
        # ===== Performance options =====
        ttk.Separator(param_frame, orient='horizontal').grid(
            row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=5
        )
        
        ttk.Label(param_frame, text="性能优化:", font=('', 9, 'bold')).grid(
            row=5, column=0, columnspan=2, sticky=tk.W, pady=2
        )
        
        # Sampling toggle + max-sample spinbox on one row
        sample_frame = ttk.Frame(param_frame)
        sample_frame.grid(row=6, column=0, columnspan=2, sticky=(tk.W, tk.E))
        
        # toggle_sampling enables/disables the spinbox created below
        ttk.Checkbutton(sample_frame, text="启用采样", 
                       variable=self.enable_sampling,
                       command=self.toggle_sampling).pack(side=tk.LEFT)
        
        ttk.Label(sample_frame, text="  最大样本数:").pack(side=tk.LEFT, padx=(10, 0))
        self.max_samples_spinbox = ttk.Spinbox(
            sample_frame, from_=10000, to=200000, increment=10000,
            textvariable=self.max_samples, width=10
        )
        self.max_samples_spinbox.pack(side=tk.LEFT, padx=5)
        
        ttk.Checkbutton(param_frame, text="快速模式（减少模型复杂度）", 
                       variable=self.fast_mode).grid(
            row=7, column=0, columnspan=2, sticky=tk.W, pady=2
        )
        
        ttk.Checkbutton(param_frame, text="忽略背景值（所有波段为0）", 
                       variable=self.ignore_background).grid(
            row=8, column=0, columnspan=2, sticky=tk.W, pady=2
        )
        
        # ===== 3. Classifier selection section =====
        clf_frame = ttk.LabelFrame(main_frame, text="3. 分类器选择", padding="10")
        clf_frame.grid(row=1, column=1, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)
        
        # Quick-select buttons
        btn_frame = ttk.Frame(clf_frame)
        btn_frame.grid(row=0, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(0, 5))
        ttk.Button(btn_frame, text="全选", command=self.select_all_classifiers).pack(
            side=tk.LEFT, padx=2
        )
        ttk.Button(btn_frame, text="全不选", command=self.deselect_all_classifiers).pack(
            side=tk.LEFT, padx=2
        )
        ttk.Button(btn_frame, text="推荐选择", command=self.select_recommended).pack(
            side=tk.LEFT, padx=2
        )
        ttk.Button(btn_frame, text="快速分类器", command=self.select_fast).pack(
            side=tk.LEFT, padx=2
        )
        
        # One checkbox per classifier, laid out three per row
        all_classifiers = self.backend.get_all_classifiers()
        row = 1
        col = 0
        for code, (_, name, _, _, _) in all_classifiers.items():
            cb = ttk.Checkbutton(clf_frame, text=f"{name} ({code})", 
                               variable=self.classifier_vars[code])
            cb.grid(row=row, column=col, sticky=tk.W, pady=2, padx=5)
            
            col += 1
            if col >= 3:
                col = 0
                row += 1
        
        # ===== 4. Run-control section =====
        control_frame = ttk.LabelFrame(main_frame, text="4. 运行控制", padding="10")
        control_frame.grid(row=2, column=0, sticky=(tk.W, tk.E), pady=5, padx=(0, 5))
        
        self.start_btn = ttk.Button(control_frame, text="开始分类", 
                                    command=self.start_classification, width=15)
        self.start_btn.grid(row=0, column=0, padx=5, pady=5)
        
        # Stop is disabled until a run starts
        self.stop_btn = ttk.Button(control_frame, text="停止", 
                                   command=self.stop_classification, 
                                   state=tk.DISABLED, width=15)
        self.stop_btn.grid(row=0, column=1, padx=5, pady=5)
        
        ttk.Button(control_frame, text="打开结果目录", 
                  command=self.open_result_dir, width=15).grid(
            row=0, column=2, padx=5, pady=5
        )
        
        ttk.Button(control_frame, text="查看对比报告", 
                  command=self.view_report, width=15).grid(
            row=0, column=3, padx=5, pady=5
        )
        
        # Progress bar (0-100) bound to progress_var
        ttk.Label(control_frame, text="进度:").grid(row=1, column=0, sticky=tk.W, pady=2)
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(control_frame, variable=self.progress_var, 
                                           maximum=100, length=400)
        self.progress_bar.grid(row=1, column=1, columnspan=3, sticky=(tk.W, tk.E), 
                              padx=5, pady=2)
        
        control_frame.columnconfigure(3, weight=1)
        
        # ===== 5. Log output section =====
        log_frame = ttk.LabelFrame(main_frame, text="5. 运行日志", padding="10")
        log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)
        
        self.log_text = scrolledtext.ScrolledText(log_frame, wrap=tk.WORD, 
                                                  height=20, width=100)
        self.log_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        
        log_frame.columnconfigure(0, weight=1)
        log_frame.rowconfigure(0, weight=1)
        
        # Status bar at the bottom
        self.status_var = tk.StringVar(value="就绪")
        status_bar = ttk.Label(main_frame, textvariable=self.status_var, 
                              relief=tk.SUNKEN, anchor=tk.W)
        status_bar.grid(row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(5, 0))
    
    def toggle_sampling(self):
        """切换采样功能"""
        if self.enable_sampling.get():
            self.max_samples_spinbox.config(state=tk.NORMAL)
        else:
            self.max_samples_spinbox.config(state=tk.DISABLED)
    
    # ===== 文件浏览函数 =====
    def _pick_file_into(self, target_var, title, filetypes):
        """Ask for a file and store the chosen path in *target_var*; no-op on cancel."""
        chosen = filedialog.askopenfilename(title=title, filetypes=filetypes)
        if chosen:
            target_var.set(chosen)

    def browse_image(self):
        """Choose the input raster."""
        self._pick_file_into(
            self.image_path, "选择影像文件",
            [("GeoTIFF", "*.tif *.tiff"), ("所有文件", "*.*")],
        )

    def browse_train_shp(self):
        """Choose the training-sample vector file."""
        self._pick_file_into(
            self.train_shp_path, "选择训练样本",
            [("Shapefile", "*.shp"), ("所有文件", "*.*")],
        )

    def browse_val_shp(self):
        """Choose the validation-sample vector file."""
        self._pick_file_into(
            self.val_shp_path, "选择验证样本",
            [("Shapefile", "*.shp"), ("所有文件", "*.*")],
        )

    def browse_output(self):
        """Choose the output directory."""
        chosen = filedialog.askdirectory(title="选择输出目录")
        if chosen:
            self.output_dir.set(chosen)
    
    # ===== 分类器选择函数 =====
    def select_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(True)
    
    def deselect_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(False)
    
    def select_recommended(self):
        """选择推荐的分类器（精度优先）"""
        recommended = ["rf", "xgb", "svm", "et", "gb", "lgb"]
        for code, var in self.classifier_vars.items():
            var.set(code in recommended)
    
    def select_fast(self):
        """选择快速分类器（速度优先）"""
        fast = ["rf", "dt", "nb", "et", "xgb", "lgb"]
        for code, var in self.classifier_vars.items():
            var.set(code in fast)
    
    # ===== 日志相关函数 =====
    def log(self, message):
        """添加日志消息"""
        self.log_queue.put(message)
    
    def update_log(self):
        """更新日志显示"""
        try:
            while True:
                message = self.log_queue.get_nowait()
                self.log_text.insert(tk.END, message + "\n")
                self.log_text.see(tk.END)
        except queue.Empty:
            pass
        
        self.root.after(100, self.update_log)
    
    # ===== 主要功能函数 =====
    def start_classification(self):
        """Validate the inputs and launch run_classification on a daemon thread.

        All checks run before the Start/Stop buttons are toggled and
        ``is_running`` is set. A rejected input therefore leaves the UI idle
        instead of stuck with Start disabled.
        """
        if not self.image_path.get():
            messagebox.showerror("错误", "请选择影像文件！")
            return
        
        if not self.train_shp_path.get():
            messagebox.showerror("错误", "请选择训练样本！")
            return
        
        # At least one classifier must be ticked
        selected_classifiers = [code for code, var in self.classifier_vars.items() 
                               if var.get()]
        if not selected_classifiers:
            messagebox.showerror("错误", "请至少选择一个分类器！")
            return
        
        # The numeric spinboxes accept free text, and reading an invalid
        # value raises TclError. Validate before any state changes.
        # Previously a bad max-samples entry raised after Start was disabled,
        # which locked the UI.
        numeric_fields = [(self.n_estimators, "树模型数量"), (self.block_size, "分块大小")]
        if self.enable_sampling.get():
            numeric_fields.append((self.max_samples, "最大样本数"))
        for var, label in numeric_fields:
            try:
                value = var.get()
            except (tk.TclError, ValueError):
                value = None
            if value is None or value <= 0:
                messagebox.showerror("错误", f"{label}必须是正整数！")
                return
        
        # Switch buttons: Start off, Stop on
        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.is_running = True
        
        # Clear the log widget and write the run header
        self.log_text.delete(1.0, tk.END)
        self.log("="*80)
        self.log("开始分类任务... (性能优化版)")
        self.log("="*80)
        
        if self.enable_sampling.get():
            self.log(f"✓ 启用数据采样：最大 {self.max_samples.get()} 个样本")
        if self.fast_mode.get():
            self.log("✓ 启用快速模式：减少模型复杂度")
        
        # Run the heavy work off the Tk main loop; daemon so it won't block exit
        thread = threading.Thread(target=self.run_classification, 
                                 args=(selected_classifiers,), daemon=True)
        thread.start()
    
    def stop_classification(self):
        """停止分类"""
        self.is_running = False
        self.log("\n用户请求停止...")
        self.status_var.set("已停止")
    
    def run_classification(self, selected_classifiers):
        """执行分类（在后台线程中运行）"""
        try:
            # 创建输出目录
            out_dir = Path(self.output_dir.get())
            out_dir.mkdir(exist_ok=True)
            
            # 1. 读取影像
            self.log(f"\n正在读取影像: {self.image_path.get()}")
            self.status_var.set("读取影像...")
            img = rxr.open_rasterio(self.image_path.get(), masked=True)
            self.log(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")
            
            if not self.is_running:
                return
            
            # 2. 读取类别信息
            self.log("\n正在读取类别信息...")
            class_names, class_colors, _ = self.backend.get_class_info_from_shp(
                self.train_shp_path.get(), 
                self.class_attr.get(), 
                self.name_attr.get()
            )
            self.log(f"检测到类别: {list(class_names.values())}")
            
            # 3. 提取训练样本
            self.log("\n正在处理训练样本...")
            self.status_var.set("处理训练样本...")
            train_mask = self.backend.rasterize_samples(
                self.train_shp_path.get(), img, self.class_attr.get()
            )
            
            # 使用采样功能
            max_samples = self.max_samples.get() if self.enable_sampling.get() else None
            
            X_train, y_train, n_nan, n_inf, n_sampled = self.backend.extract_samples(
                img, train_mask, 
                ignore_background=self.ignore_background.get(),
                max_samples=max_samples
            )
            
            self.log(f"训练样本数: {len(y_train)}")
            if n_nan > 0:
                self.log(f"  已移除 {n_nan} 个包含NaN的样本")
            if n_inf > 0:
                self.log(f"  已移除 {n_inf} 个包含Inf的样本")
            if n_sampled > 0:
                self.log(f"  已采样减少 {n_sampled} 个样本（提高速度）")
            
            if not self.is_running:
                return
            
            # 4. 提取验证样本（如果有）
            val_exists = os.path.exists(self.val_shp_path.get())
            if val_exists:
                self.log("\n正在处理验证样本...")
                val_mask = self.backend.rasterize_samples(
                    self.val_shp_path.get(), img, self.class_attr.get()
                )
                
                if self.ignore_background.get():
                    background_mask = self.backend.get_background_mask(img)
                    valid_val = (val_mask > 0) & (~background_mask)
                else:
                    valid_val = val_mask > 0
                
                yv_true = val_mask[valid_val]
                self.log(f"验证样本数: {len(yv_true)}")
            
            # 5. 对每个选择的分类器进行训练和评估
            all_classifiers = self.backend.get_all_classifiers(
                self.n_estimators.get(), 
                fast_mode=self.fast_mode.get()
            )
            comparison_results = []
            
            for i, clf_code in enumerate(selected_classifiers):
                if not self.is_running:
                    break
                
                clf, clf_name, clf_desc, needs_encoding, needs_scaling = all_classifiers[clf_code]
                
                self.log(f"\n{'='*80}")
                self.log(f"[{i+1}/{len(selected_classifiers)}] 正在测试: {clf_name} ({clf_code})")
                self.log(f"{'='*80}")
                self.status_var.set(f"训练 {clf_name}...")
                
                # 创建分类器目录
                clf_dir = out_dir / clf_code
                clf_dir.mkdir(exist_ok=True)
                
                try:
                    # ===== 数据预处理 =====
                    label_encoder = None
                    scaler = None
                    X_train_use = X_train.copy()
                    y_train_use = y_train.copy()
                    
                    # 标签编码
                    if needs_encoding:
                        self.log("  应用标签编码...")
                        label_encoder = LabelEncoder()
                        y_train_use = label_encoder.fit_transform(y_train)
                    
                    # 特征缩放
                    if needs_scaling:
                        self.log("  应用特征缩放...")
                        scaler = StandardScaler()
                        X_train_use = scaler.fit_transform(X_train_use)
                    
                    # 训练
                    self.log("开始训练...")
                    train_start = time.time()
                    clf.fit(X_train_use, y_train_use)
                    train_time = time.time() - train_start
                    self.log(f"训练完成，耗时: {train_time:.2f} 秒")
                    
                    # 训练集精度
                    y_train_pred = clf.predict(X_train_use)
                    
                    # 反向转换
                    if label_encoder is not None:
                        y_train_pred = label_encoder.inverse_transform(y_train_pred)
                    
                    train_metrics = self.backend.calculate_metrics(y_train, y_train_pred)
                    self.log(f"训练集精度: {train_metrics['overall_accuracy']:.4f}, "
                           f"Kappa: {train_metrics['kappa']:.4f}")
                    
                    if not self.is_running:
                        break
                    
                    # 预测整幅影像
                    self.log("开始预测整幅影像...")
                    self.status_var.set(f"预测 {clf_name}...")
                    
                    pred_start = time.time()
                    classified_path = clf_dir / f"classified_{clf_code}.tif"
                    
                    def update_progress(progress):
                        self.progress_var.set(progress)
                    
                    self.backend.predict_by_block(
                        clf, img, classified_path, 
                        block_size=self.block_size.get(),
                        ignore_background=self.ignore_background.get(),
                        progress_callback=update_progress,
                        label_encoder=label_encoder,
                        scaler=scaler
                    )
                    
                    pred_time = time.time() - pred_start
                    self.log(f"预测完成，耗时: {pred_time:.2f} 秒")
                    
                    # 验证集精度
                    val_metrics = {'overall_accuracy': np.nan, 'kappa': np.nan, 'f1_macro': np.nan}
                    if val_exists:
                        self.log("评估验证集...")
                        with rxr.open_rasterio(classified_path) as pred_img:
                            pred_arr = pred_img.values.squeeze()
                        
                        yv_pred = pred_arr[valid_val]
                        val_metrics = self.backend.calculate_metrics(yv_true, yv_pred)
                        self.log(f"验证集精度: {val_metrics['overall_accuracy']:.4f}, "
                               f"Kappa: {val_metrics['kappa']:.4f}")
                    
                    # 记录结果
                    result = {
                        '分类器代码': clf_code,
                        '分类器名称': clf_name,
                        '训练集精度': train_metrics['overall_accuracy'],
                        '训练集Kappa': train_metrics['kappa'],
                        '训练集F1': train_metrics['f1_macro'],
                        '验证集精度': val_metrics['overall_accuracy'],
                        '验证集Kappa': val_metrics['kappa'],
                        '验证集F1': val_metrics['f1_macro'],
                        '训练时间(秒)': train_time,
                        '预测时间(秒)': pred_time,
                        '总时间(秒)': train_time + pred_time
                    }
                    comparison_results.append(result)
                    
                    self.log(f"✓ {clf_name} 完成")
                    
                except Exception as e:
                    self.log(f"✗ {clf_name} 失败: {str(e)}")
                    import traceback
                    self.log(traceback.format_exc())
                    continue
                
                # 更新总进度
                self.progress_var.set((i + 1) / len(selected_classifiers) * 100)
            
            # 6. 生成对比报告
            if comparison_results and self.is_running:
                self.log(f"\n{'='*80}")
                self.log("生成对比报告...")
                self.status_var.set("生成报告...")
                
                comparison_df = pd.DataFrame(comparison_results)
                comparison_df.to_csv(out_dir / "classifier_comparison.csv", 
                                   index=False, encoding='utf-8-sig')
                
                # 生成详细报告
                with open(out_dir / "comparison_summary.txt", 'w', encoding='utf-8') as f:
                    f.write("分类器性能对比摘要\n")
                    f.write("="*60 + "\n\n")
                    
                    f.write(f"成功完成: {len(comparison_results)}/{len(selected_classifiers)} 个分类器\n")
                    f.write(f"训练样本数: {len(y_train)}\n")
                    if val_exists:
                        f.write(f"验证样本数: {len(yv_true)}\n")
                    f.write(f"性能优化: 采样={self.enable_sampling.get()}, 快速模式={self.fast_mode.get()}\n\n")
                    
                    # 精度排名
                    sorted_df = comparison_df.sort_values('验证集精度', ascending=False)
                    f.write("验证集精度排名:\n")
                    f.write("-"*60 + "\n")
                    for idx, (_, row) in enumerate(sorted_df.iterrows(), 1):
                        f.write(f"{idx}. {row['分类器名称']:12s} - "
                               f"精度: {row['验证集精度']:.4f}, "
                               f"Kappa: {row['验证集Kappa']:.4f}, "
                               f"F1: {row['验证集F1']:.4f}\n")
                    
                    # 速度排名
                    f.write("\n" + "-"*60 + "\n")
                    f.write("总时间排名（训练+预测）:\n")
                    f.write("-"*60 + "\n")
                    sorted_time = comparison_df.sort_values('总时间(秒)')
                    for idx, (_, row) in enumerate(sorted_time.iterrows(), 1):
                        f.write(f"{idx}. {row['分类器名称']:12s} - "
                               f"{row['总时间(秒)']:.2f} 秒 "
                               f"(训练: {row['训练时间(秒)']:.2f}s, 预测: {row['预测时间(秒)']:.2f}s)\n")
                    
                    # 推荐
                    f.write("\n" + "="*60 + "\n")
                    f.write("推荐:\n")
                    f.write("-"*60 + "\n")
                    
                    best_acc = sorted_df.iloc[0]
                    f.write(f"最高精度: {best_acc['分类器名称']} ({best_acc['验证集精度']:.4f})\n")
                    
                    best_speed = sorted_time.iloc[0]
                    f.write(f"最快速度: {best_speed['分类器名称']} ({best_speed['总时间(秒)']:.2f}秒)\n")
                    
                    # 综合评分
                    comparison_df['综合得分'] = (
                        comparison_df['验证集精度'] * 0.7 + 
                        (1 - comparison_df['总时间(秒)'] / comparison_df['总时间(秒)'].max()) * 0.3
                    )
                    best_overall = comparison_df.loc[comparison_df['综合得分'].idxmax()]
                    f.write(f"综合最佳: {best_overall['分类器名称']} "
                           f"(精度: {best_overall['验证集精度']:.4f}, "
                           f"时间: {best_overall['总时间(秒)']:.2f}秒)\n")
                
                self.log("\n✓ 所有任务完成！")
                self.log(f"结果保存至: {out_dir.absolute()}")
                self.log(f"成功: {len(comparison_results)}/{len(selected_classifiers)} 个分类器")
                self.status_var.set("完成")
                
                # 显示最佳结果
                best_clf = comparison_df.loc[comparison_df['验证集精度'].idxmax()]
                self.log(f"\n【最佳精度】{best_clf['分类器名称']}: {best_clf['验证集精度']:.4f}")
                
                messagebox.showinfo("完成", 
                    f"分类任务完成！\n"
                    f"成功: {len(comparison_results)}/{len(selected_classifiers)}\n"
                    f"最佳: {best_clf['分类器名称']} ({best_clf['验证集精度']:.4f})\n"
                    f"结果: {out_dir}")
            
        except Exception as e:
            self.log(f"\n错误: {str(e)}")
            import traceback
            self.log(traceback.format_exc())
            messagebox.showerror("错误", f"发生错误:\n{str(e)}")
            self.status_var.set("错误")
        
        finally:
            # 恢复按钮状态
            self.start_btn.config(state=tk.NORMAL)
            self.stop_btn.config(state=tk.DISABLED)
            self.progress_var.set(0)
            self.is_running = False
    
    def open_result_dir(self):
        """Open the configured output directory in the platform's file browser.

        The directory path comes from ``self.output_dir``. If it does not
        exist, a warning dialog is shown. Failures while launching the
        external opener (e.g. ``xdg-open`` not installed, or the OS refusing
        the path) are reported in an error dialog instead of escaping the Tk
        callback as an unhandled traceback.
        """
        out_dir = Path(self.output_dir.get())
        if not out_dir.exists():
            messagebox.showwarning("警告", "结果目录不存在！")
            return

        import subprocess
        import platform

        system = platform.system()
        try:
            if system == "Windows":
                os.startfile(out_dir)
            elif system == "Darwin":  # macOS
                subprocess.Popen(["open", out_dir])
            else:  # Linux / other POSIX desktops
                subprocess.Popen(["xdg-open", out_dir])
        except OSError as err:
            # FileNotFoundError (missing opener) is a subclass of OSError.
            messagebox.showerror("错误", f"无法打开结果目录:\n{err}")
    
    def view_report(self):
        """Show comparison_summary.txt from the output directory in a read-only window.

        If the report file does not exist, only a warning dialog is shown.
        """
        report_path = Path(self.output_dir.get()) / "comparison_summary.txt"
        if not report_path.exists():
            messagebox.showwarning("警告", "报告文件不存在！请先运行分类。")
            return

        # Separate top-level window holding a scrollable, word-wrapped text area.
        win = tk.Toplevel(self.root)
        win.title("分类器对比报告")
        win.geometry("800x600")

        viewer = scrolledtext.ScrolledText(win, wrap=tk.WORD)
        viewer.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        with open(report_path, 'r', encoding='utf-8') as fh:
            viewer.insert(1.0, fh.read())

        # Disable editing once the text has been inserted.
        viewer.config(state=tk.DISABLED)

# ==================== 主程序入口 ====================
def main():
    """Create the Tk root window, attach the classification GUI and enter the event loop."""
    window = tk.Tk()
    gui = ClassificationGUI(window)  # bound to a name so the GUI object stays referenced while mainloop runs
    window.mainloop()

if __name__ == "__main__":
    main()

## deepseek


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
通用遥感影像监督分类系统 (rioxarray版本)
-------------------------------------------------
功能：
1. 自动读取多波段遥感影像；
2. 从矢量样本中提取训练/验证数据；
3. 支持随机森林 / SVM / XGBoost 分类；
4. 采用分块预测模式；
5. 输出分类结果 GeoTIFF；
6. 自动生成分类报告与混淆矩阵；
7. 显示分类影像和精度评价结果；
8. 分类面积统计（平方千米）；
9. 后处理功能（去除小图斑、形态学操作）。
"""

import os
import time
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from shapely.geometry import mapping
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from sklearn.inspection import permutation_importance
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology

# Matplotlib fonts: SimHei renders the Chinese titles/labels used by the
# figures in this script; DejaVu Sans is the fallback.
plt.rcParams["font.sans-serif"] = ["SimHei", "DejaVu Sans"]  # CJK-capable font first
plt.rcParams["axes.unicode_minus"] = False  # render negative signs correctly with the CJK font

# ------------------ Configuration ------------------
IMAGE_PATH = r"D:\code313\Geo_programe\rasterio\RF\data\2017_09_05_stack.tif"
TRAIN_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\cal.shp"
VAL_SHP = r"D:\code313\Geo_programe\rasterio\RF\data\val.shp"
CLASS_ATTRIBUTE = "class"  # sample field holding the integer class code
NAME_ATTRIBUTE = "name"    # sample field holding the class display name
OUT_DIR = Path("./results_rioxarray")

CLASSIFIER = "rf"  # one of: "rf", "svm", "xgb"
N_ESTIMATORS = 300
BLOCK_SIZE = 512
USE_GPU = False

# Post-processing parameters
POSTPROCESSING = True  # whether to post-process the classification map
MIN_PATCH_SIZE = 10    # patches with fewer pixels than this are removed (set to background)
MORPHOLOGY_OPERATION = "opening"  # "opening", "closing", "both" (opening then closing) or "none"
MORPHOLOGY_SIZE = 3     # edge length of the square structuring element

# Predefined land-cover colours (extend as needed). A class receives the
# colour of the FIRST key, in this dict's order, that occurs as a substring
# of its name; classes matching no key get a colour from COLOR_PALETTE.
LANDUSE_COLORS = {
    # Water
    "水体": "lightblue",
    "河流": "blue",
    "湖泊": "deepskyblue",
    "水库": "dodgerblue",
    "海洋": "navy",
    
    # Vegetation
    "植被": "forestgreen",
    "森林": "darkgreen",
    "草地": "limegreen",
    "农田": "yellowgreen",
    "耕地": "olivedrab",
    
    # Built-up
    "建筑": "gray",
    "城市": "dimgray",
    "居民地": "slategray",
    "工业区": "darkgray",
    
    # Other land cover
    "裸地": "tan",
    "沙地": "wheat",
    "岩石": "sienna",
    "雪": "white",
    "云": "ghostwhite",
    
    # Used only when a class name contains "其他" ("other"); this is NOT a
    # catch-all default — unmatched classes fall back to COLOR_PALETTE.
    "其他": "darkred"
}

# Fallback colours, assigned cyclically to classes whose names match no
# LANDUSE_COLORS key.
COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow', 
                'darkred', 'purple', 'orange', 'pink', 'brown', 
                'cyan', 'magenta', 'lime', 'navy', 'teal']

OUT_DIR.mkdir(parents=True, exist_ok=True)

# ------------------ Logging ------------------
# force=True replaces handlers already attached to the root logger. Without
# it, basicConfig is a no-op once an earlier notebook cell (which logs to
# ./results_v2) has configured logging in the same kernel, and this script's
# log would never be written to its own OUT_DIR.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(OUT_DIR / "classification_log.txt", encoding="utf-8"),
        logging.StreamHandler()
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# ------------------ 辅助函数 ------------------
def get_class_info_from_shp(shp_path, class_attr, name_attr):
    """Read class codes, display names and plot colours from a sample vector file.

    Parameters
    ----------
    shp_path : str or path-like
        Vector file readable by ``geopandas.read_file``.
    class_attr : str
        Field holding the class code.
    name_attr : str
        Field holding the class display name. If the field is missing, or a
        feature's name is empty (None/NaN), ``"Class_<code>"`` is used.

    Returns
    -------
    class_names : dict
        Class code -> display name. If a code appears with several different
        names, the last (code, name) pair in file order wins.
    class_colors : dict
        Class code -> matplotlib colour: the first ``LANDUSE_COLORS`` key
        contained in the name, otherwise ``COLOR_PALETTE`` cycled by the
        class's position in ``class_names``.
    unique_classes : list
        Class codes in ascending order.
    """
    gdf = gpd.read_file(shp_path)

    fallback_names = gdf[class_attr].apply(lambda x: f"Class_{x}")
    if name_attr not in gdf.columns:
        logger.warning(f"shp文件中没有找到 '{name_attr}' 字段，将使用类别编号作为名称")
        gdf[name_attr] = fallback_names
    else:
        # Empty names would become None/NaN dict values and make the
        # substring colour lookup below raise TypeError.
        gdf[name_attr] = gdf[name_attr].fillna(fallback_names)

    # Unique (code, name) pairs -> code-to-name mapping.
    class_info = gdf[[class_attr, name_attr]].drop_duplicates()
    class_names = dict(zip(class_info[class_attr], class_info[name_attr]))

    class_colors = {}
    for i, (class_id, class_name) in enumerate(class_names.items()):
        # str() so numeric name fields do not break the substring test.
        label = str(class_name)
        matched = next(
            (color for key, color in LANDUSE_COLORS.items() if key in label),
            None,
        )
        if matched is None:
            matched = COLOR_PALETTE[i % len(COLOR_PALETTE)]
        class_colors[class_id] = matched

    unique_classes = sorted(class_names.keys())

    return class_names, class_colors, unique_classes

def rasterize_samples(shp, ref_img, attr):
    """Burn vector samples into a label array on the grid of ``ref_img``.

    The vector file is reprojected to ``ref_img.rio.crs``. Each geometry is
    burned with its ``attr`` value using ``ref_img.rio.transform()`` and an
    output shape of ``ref_img.shape[1:]`` (band axis dropped). Every pixel a
    geometry touches is burned (``all_touched=True``); all other pixels are
    0. The result is a uint16 array.
    """
    import rasterio.features

    samples = gpd.read_file(shp).to_crs(ref_img.rio.crs)
    burn_pairs = zip(samples.geometry, samples[attr])

    return rasterio.features.rasterize(
        shapes=burn_pairs,
        out_shape=ref_img.shape[1:],
        transform=ref_img.rio.transform(),
        fill=0,
        all_touched=True,
        dtype="uint16",
    )

def extract_samples(image, mask):
    """Extract per-pixel feature vectors and labels where ``mask`` is non-zero.

    Parameters
    ----------
    image : object with a ``.values`` array of shape (bands, rows, cols)
        e.g. the DataArray returned by ``rxr.open_rasterio``.
    mask : np.ndarray of shape (rows, cols)
        Label raster; values <= 0 mean "no sample".

    Returns
    -------
    X : np.ndarray, shape (n_samples, bands)
    y : np.ndarray, shape (n_samples,)

    Samples with any non-finite band value (NaN/Inf, e.g. nodata pixels
    read with ``masked=True``) are dropped and a warning is logged, because
    scikit-learn estimators such as SVC reject NaN input in ``fit``.
    """
    data = np.moveaxis(image.values, 0, -1)  # (bands, rows, cols) -> (rows, cols, bands)
    valid = mask > 0
    X = data[valid]
    y = mask[valid]

    finite = np.isfinite(X).all(axis=1)
    n_dropped = int(X.shape[0] - np.count_nonzero(finite))
    if n_dropped:
        # Same logger object as the module-level ``logger`` (both getLogger(__name__)).
        logging.getLogger(__name__).warning(f"已移除 {n_dropped} 个包含NaN/Inf的样本")
        X = X[finite]
        y = y[finite]
    return X, y

def get_classifier(name):
    """Build an unfitted classifier for a short code.

    Parameters
    ----------
    name : str
        ``"rf"`` -> RandomForestClassifier (N_ESTIMATORS trees, OOB score);
        ``"svm"`` -> RBF-kernel SVC with probability estimates;
        ``"xgb"`` -> XGBClassifier (optional ``xgboost`` package).

    Raises
    ------
    ImportError
        ``name == "xgb"`` and xgboost is not installed; the original import
        error is chained as ``__cause__``.
    ValueError
        Unknown classifier code.
    """
    if name == "rf":
        return RandomForestClassifier(
            n_estimators=N_ESTIMATORS, n_jobs=-1, oob_score=True, verbose=1
        )
    if name == "svm":
        return SVC(kernel="rbf", probability=True)
    if name == "xgb":
        # Keep the try narrow: only the optional import may legitimately fail.
        try:
            from xgboost import XGBClassifier
        except ImportError as err:
            raise ImportError("未安装 xgboost，请先运行 pip install xgboost") from err
        # NOTE(review): recent xgboost versions require class labels
        # 0..n_classes-1, while labels burned from the sample file start at 1
        # (0 is the rasterize fill value). Callers likely need to label-encode
        # y before fit() — confirm.
        return XGBClassifier(
            n_estimators=N_ESTIMATORS, learning_rate=0.1, max_depth=8, n_jobs=-1
        )
    raise ValueError(f"未知分类器类型: {name}")

def plot_confusion_matrix(y_true, y_pred, class_names, save_path):
    """Save a count heatmap and a row-percentage heatmap of the confusion matrix.

    Parameters
    ----------
    y_true, y_pred : array-like
        True and predicted class codes.
    class_names : sequence of str
        Axis tick labels, one per class in ascending class-code order. If
        their number differs from the number of classes occurring in
        ``y_true`` or ``y_pred`` (e.g. the model predicted a class absent from
        ``y_true``), the numeric class codes are used as tick labels instead
        of mislabelling the axes.
    save_path : str or Path
        Output image for the count heatmap. The percentage heatmap is saved
        beside it as ``<stem>_percent<suffix>``, so the two never collide
        regardless of the file extension.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    # Same order confusion_matrix uses by default (sorted union of both
    # inputs), made explicit so the tick labels can be checked against it.
    labels = np.unique(np.concatenate([y_true, y_pred]))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    if len(class_names) == len(labels):
        tick_labels = list(class_names)
    else:
        tick_labels = [str(c) for c in labels]

    # Row-normalised percentages. A class that never occurs in y_true has a
    # zero row total; report 0 % for it instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm * 100.0, row_totals,
        out=np.zeros(cm.shape, dtype=float), where=row_totals > 0,
    )

    save_path = Path(save_path)
    percent_path = save_path.with_name(f"{save_path.stem}_percent{save_path.suffix}")

    def _save_heatmap(matrix, fmt, cbar_label, title, out_path):
        """Draw one annotated heatmap with the shared axis styling and save it."""
        fig = plt.figure(figsize=(10, 8))
        try:
            sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                        xticklabels=tick_labels, yticklabels=tick_labels,
                        cbar_kws={'label': cbar_label})
            plt.xlabel('预测类别', fontsize=12)
            plt.ylabel('真实类别', fontsize=12)
            plt.title(title, fontsize=14, fontweight='bold')
            plt.xticks(rotation=45)
            plt.yticks(rotation=0)
            plt.tight_layout()
            plt.savefig(out_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

    _save_heatmap(cm, 'd', '样本数量', '混淆矩阵', save_path)
    _save_heatmap(cm_percent, '.1f', '百分比 (%)', '混淆矩阵 (百分比)', percent_path)

def comprehensive_evaluation(y_true, y_pred, class_names, save_path):
    """Compute accuracy metrics and write a text report plus a per-class bar chart.

    Parameters
    ----------
    y_true, y_pred : array-like
        True and predicted class codes.
    class_names : sequence of str
        Display names in ascending class-code order, either for the classes
        present in ``y_true`` or for the union of ``y_true`` and ``y_pred``.
    save_path : str or Path
        Text report path. The bar chart is saved beside it as
        ``<stem>_metrics.png``.

    Returns
    -------
    overall_accuracy : float
    kappa : float
    eval_df : pandas.DataFrame
        One row per class: precision, recall, F1 and sample count in y_true.

    Raises
    ------
    ValueError
        ``class_names`` matches neither class set.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    # Pick the class codes that class_names describes, so every per-class
    # array below has exactly len(class_names) entries even when y_pred
    # contains classes that are absent from y_true.
    true_labels = np.unique(y_true)
    all_labels = np.unique(np.concatenate([y_true, y_pred]))
    if len(class_names) == len(true_labels):
        labels = true_labels
    elif len(class_names) == len(all_labels):
        labels = all_labels
    else:
        raise ValueError(
            f"类别名称数量 ({len(class_names)}) 与类别数量不一致 "
            f"(真实值: {len(true_labels)}, 真实值+预测值: {len(all_labels)})"
        )

    overall_accuracy = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    recall = recall_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    f1 = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)

    report = classification_report(y_true, y_pred, labels=labels, target_names=class_names,
                                   digits=4, zero_division=0)

    # Sample count per class looked up by class code, so non-contiguous
    # codes (e.g. 1, 3, 7) are counted correctly.
    support = np.array([np.count_nonzero(y_true == lab) for lab in labels])

    eval_df = pd.DataFrame({
        '类别': list(class_names),
        '精确率 (Precision)': precision,
        '召回率 (Recall)': recall,
        'F1分数': f1,
        '样本数量': support,
    })

    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("="*60 + "\n")
        f.write("           遥感影像分类精度评价报告\n")
        f.write("="*60 + "\n\n")
        
        f.write(f"总体精度 (Overall Accuracy): {overall_accuracy:.4f}\n")
        f.write(f"Kappa系数: {kappa:.4f}\n\n")
        
        f.write("各类别精度评价:\n")
        f.write("-"*60 + "\n")
        f.write(eval_df.to_string(index=False, float_format='%.4f'))
        f.write("\n\n")
        
        f.write("详细分类报告:\n")
        f.write("-"*60 + "\n")
        f.write(report)

    # Grouped bar chart of precision / recall / F1 per class.
    report_path = Path(save_path)
    metrics_path = report_path.with_name(f"{report_path.stem}_metrics.png")
    fig = plt.figure(figsize=(12, 6))
    try:
        x = np.arange(len(class_names))
        width = 0.25
        
        plt.bar(x - width, precision, width, label='精确率', alpha=0.8)
        plt.bar(x, recall, width, label='召回率', alpha=0.8)
        plt.bar(x + width, f1, width, label='F1分数', alpha=0.8)
        
        plt.xlabel('地物类别')
        plt.ylabel('分数')
        plt.title('各类别分类精度指标')
        plt.xticks(x, class_names, rotation=45)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(metrics_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    
    return overall_accuracy, kappa, eval_df

def plot_classification_results(original_img, classified_img, class_names, class_colors, save_path, title_suffix=""):
    """Save a side-by-side figure of the input image and the classification map.

    Left panel: the first three bands as RGB with a 2-98 % percentile
    stretch (or band 1 in grey for fewer than three bands). Right panel: the
    classes > 0 coloured via ``class_colors`` (black if missing), with a
    legend built from ``class_names``. Background (0) and NaN pixels are
    masked and drawn transparent.
    """
    from matplotlib.patches import Patch

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    try:
        if original_img.shape[0] >= 3:
            rgb_data = np.moveaxis(original_img.values[:3], 0, -1)
            # nan-aware stretch: masked reads mark nodata as NaN, which would
            # turn np.percentile (and thus the whole panel) into NaN.
            p2, p98 = np.nanpercentile(rgb_data, (2, 98))
            span = p98 - p2
            if not np.isfinite(span) or span <= 0:
                span = 1.0  # flat image: avoid division by zero
            rgb_display = np.clip((rgb_data - p2) / span, 0, 1)
            ax1.imshow(np.nan_to_num(rgb_display, nan=0.0))
        else:
            ax1.imshow(original_img.values[0], cmap='gray')
        ax1.set_title('原始遥感影像', fontsize=14, fontweight='bold')
        ax1.axis('off')

        classified_data = classified_img.values.squeeze()
        classes = np.unique(classified_data)
        classes = classes[classes > 0]  # drops background 0 (and NaN)

        colors = [class_colors.get(c, 'black') for c in classes]
        labels = [class_names.get(c, f'未知类别_{c}') for c in classes]

        if classes.size:
            cmap = mcolors.ListedColormap(colors)
            bounds = np.append(classes, classes[-1] + 1) - 0.5
            norm = mcolors.BoundaryNorm(bounds, cmap.N)
            # Unmasked, background 0 would fall below the first bound and be
            # drawn in the colormap's "under" colour, i.e. the first class's.
            display = np.ma.masked_where(~(classified_data > 0), classified_data)
            ax2.imshow(display, cmap=cmap, norm=norm)

            legend_elements = [Patch(facecolor=color, label=label)
                               for color, label in zip(colors, labels)]
            ax2.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.35, 1))

        ax2.set_title('分类结果' + title_suffix, fontsize=14, fontweight='bold')
        ax2.axis('off')

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_feature_importance(clf, feature_names, save_path):
    """Save a bar chart of ``clf.feature_importances_``, highest first.

    Estimators without ``feature_importances_`` (e.g. SVC) are ignored and
    nothing is written. ``feature_names[i]`` labels feature ``i``.
    """
    if not hasattr(clf, 'feature_importances_'):
        return

    scores = clf.feature_importances_
    order = np.argsort(scores)[::-1]
    positions = range(len(scores))
    tick_names = [feature_names[k] for k in order]

    plt.figure(figsize=(10, 6))
    plt.title('特征重要性排序', fontsize=14, fontweight='bold')
    plt.bar(positions, scores[order])
    plt.xticks(positions, tick_names, rotation=45)
    plt.xlabel('特征波段')
    plt.ylabel('重要性')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def calculate_pixel_area(transform):
    """Return the ground area of one pixel in squared CRS units.

    Parameters
    ----------
    transform : affine.Affine or sequence
        Indexable as ``(a, b, c, d, e, f)``, as returned by
        ``DataArray.rio.transform()``.

    Returns
    -------
    float
        ``|a*e - b*d|``, the area of the parallelogram spanned by one pixel.
        For north-up grids (``b == d == 0``) this equals ``|a| * |e|``; the
        determinant also covers rotated or sheared grids.

    NOTE(review): the result is in square metres only for a projected CRS
    with metre units; for a geographic CRS it is in square degrees, so
    callers converting to km² should check the CRS.
    """
    a, b = transform[0], transform[1]
    d, e = transform[3], transform[4]
    return abs(a * e - b * d)

def predict_by_block_rioxarray(model, image, out_path, block_size=BLOCK_SIZE):
    """Classify ``image`` window by window and save the label map as a GeoTIFF.

    Parameters
    ----------
    model : fitted estimator
        Anything whose ``predict`` accepts an (n_pixels, n_bands) array.
    image : xarray.DataArray
        Band-first (band, y, x) raster opened with rioxarray; its CRS,
        transform and y/x coordinates are copied to the output.
    out_path : str or Path
        Destination GeoTIFF: single band, uint16, LZW-compressed, nodata 0.
    block_size : int
        Edge length of the processing windows; also the upper bound of the
        GeoTIFF tile size.

    Returns
    -------
    out_path, unchanged.

    Pixels with any non-finite band value (e.g. nodata read as NaN via
    ``masked=True``) are not passed to the model and stay 0 in the output,
    matching the declared nodata value.
    """
    height, width = image.shape[1], image.shape[2]
    result_data = np.zeros((height, width), dtype=np.uint16)

    for y in tqdm(range(0, height, block_size), desc="Block predicting"):
        y_end = min(y + block_size, height)

        for x in range(0, width, block_size):
            x_end = min(x + block_size, width)

            block_data = image.isel(y=slice(y, y_end), x=slice(x, x_end)).values
            pixels = np.moveaxis(block_data, 0, -1)
            block_shape = pixels.shape[:2]
            pixels = pixels.reshape(-1, pixels.shape[-1])

            valid = np.isfinite(pixels).all(axis=1)
            preds = np.zeros(pixels.shape[0], dtype=np.uint16)
            if valid.any():  # predict() rejects an empty sample array
                preds[valid] = model.predict(pixels[valid]).astype("uint16")

            result_data[y:y_end, x:x_end] = preds.reshape(block_shape)

    # Radiometric scaling, units and per-band descriptions of the source
    # stack do not describe a 1-band label raster. rioxarray's to_raster
    # writes such attrs to the output (e.g. a multi-band long_name tuple or a
    # scale_factor on class codes), so they are not copied.
    source_only_attrs = {"scale_factor", "add_offset", "scales", "offsets",
                         "long_name", "units", "_FillValue"}
    result_da = xr.DataArray(
        result_data[np.newaxis, :, :],  # add the band axis
        coords={
            'band': [1],
            'y': image.y,
            'x': image.x
        },
        dims=['band', 'y', 'x'],
        attrs={k: v for k, v in image.attrs.items() if k not in source_only_attrs}
    )

    result_da.rio.write_crs(image.rio.crs, inplace=True)
    result_da.rio.write_transform(image.rio.transform(), inplace=True)
    result_da.rio.write_nodata(0, inplace=True)

    # GeoTIFF tile edges must be multiples of 16 (rasterio enforces this), and
    # should not exceed the raster; fall back to GDAL's default layout for
    # rasters smaller than a 16-pixel tile.
    tile_x = (min(block_size, width) // 16) * 16
    tile_y = (min(block_size, height) // 16) * 16
    layout = {"tiled": True, "blockxsize": tile_x, "blockysize": tile_y} if tile_x and tile_y else {}

    result_da.rio.to_raster(
        out_path,
        driver='GTiff',
        dtype='uint16',
        compress='lzw',
        **layout
    )
    
    return out_path

def postprocess_classification(classified_data, min_patch_size=10, morphology_op="opening", morphology_size=3):
    """
    Clean up a classification map class by class.

    For every class code > 0, in ascending order:

    1. connected patches (``scipy.ndimage.label`` default connectivity)
       with fewer than ``min_patch_size`` pixels are dropped;
    2. optionally, the class mask is smoothed with a binary opening,
       closing, or opening followed by closing, using a square
       ``morphology_size`` x ``morphology_size`` structuring element;
    3. pixels in the resulting mask are set to the class code, and pixels
       that originally held the class but fell out of the mask are set to 0
       (background).

    Classes are written one after another into a shared output array, so a
    pixel claimed by one class's closing can be overwritten by a later
    class; results therefore depend on class order.

    Parameters
    ----------
    classified_data : np.ndarray
        Label array (assumed 2-D when morphology is enabled); 0 = background.
    min_patch_size : int
        Minimum patch size in pixels; <= 0 disables patch removal.
    morphology_op : str
        "opening", "closing", "both" or "none"; other values are treated like "none".
    morphology_size : int
        Structuring-element edge length; <= 0 disables morphology.

    Returns
    -------
    np.ndarray
        Post-processed copy; the input array is not modified.
    """
    # Same logger object as the module-level ``logger`` (both getLogger(__name__)).
    log = logging.getLogger(__name__)
    log.info("开始后处理分类结果...")
    processed_data = classified_data.copy()

    classes = np.unique(classified_data)
    classes = classes[classes > 0]  # skip background 0

    for class_id in classes:
        binary_mask = (classified_data == class_id).astype(np.uint8)

        if min_patch_size > 0:
            labeled_array, _ = ndimage.label(binary_mask)
            component_sizes = np.bincount(labeled_array.ravel())
            size_mask = component_sizes >= min_patch_size
            size_mask[0] = False  # label 0 = pixels outside this class
            binary_mask = size_mask[labeled_array]

        if morphology_op != "none" and morphology_size > 0:
            structure = np.ones((morphology_size, morphology_size), dtype=np.uint8)
            if morphology_op == "opening":
                binary_mask = morphology.binary_opening(binary_mask, structure)
            elif morphology_op == "closing":
                binary_mask = morphology.binary_closing(binary_mask, structure)
            elif morphology_op == "both":
                binary_mask = morphology.binary_opening(binary_mask, structure)
                binary_mask = morphology.binary_closing(binary_mask, structure)

        processed_data[binary_mask > 0] = class_id
        # Removed parts of this class become background (0).
        processed_data[(classified_data == class_id) & (binary_mask == 0)] = 0

    original_nonzero = np.count_nonzero(classified_data)
    processed_nonzero = np.count_nonzero(processed_data)
    removed = original_nonzero - processed_nonzero
    # An all-background input has no reference count; report 0 % instead of
    # raising ZeroDivisionError.
    change_percent = removed / original_nonzero * 100 if original_nonzero else 0.0

    log.info(f"后处理完成: 原始非零像元数 {original_nonzero}, 处理后非零像元数 {processed_nonzero}")
    log.info(f"后处理去除了 {removed} 个像元 ({change_percent:.2f}%)")

    return processed_data

def save_classification_result_rioxarray(data, ref_image, out_path):
    """Write a 2-D label array as a single-band uint16 GeoTIFF georeferenced like ``ref_image``.

    ``data`` must have the same (rows, cols) as ``ref_image``'s y/x
    coordinates. 0 is declared as nodata. In floating-point ``data``,
    non-finite values (e.g. NaN from a masked read) are written as 0, since
    NaN has no uint16 representation. Returns ``out_path``.
    """
    data = np.asarray(data)
    if np.issubdtype(data.dtype, np.floating):
        data = np.where(np.isfinite(data), data, 0)

    # Radiometric scaling, units and per-band descriptions of the reference
    # image do not describe a 1-band label raster. rioxarray's to_raster
    # writes such attrs to the output, so they are not copied.
    source_only_attrs = {"scale_factor", "add_offset", "scales", "offsets",
                         "long_name", "units", "_FillValue"}
    result_da = xr.DataArray(
        data.reshape(1, data.shape[0], data.shape[1]),  # add the band axis
        coords={
            'band': [1],
            'y': ref_image.y,
            'x': ref_image.x
        },
        dims=['band', 'y', 'x'],
        attrs={k: v for k, v in ref_image.attrs.items() if k not in source_only_attrs}
    )

    result_da.rio.write_crs(ref_image.rio.crs, inplace=True)
    result_da.rio.write_transform(ref_image.rio.transform(), inplace=True)
    result_da.rio.write_nodata(0, inplace=True)

    result_da.rio.to_raster(
        out_path,
        driver='GTiff',
        dtype='uint16',
        compress='lzw',
        tiled=True
    )
    
    return out_path

def calculate_area_statistics(classified_data, class_names, class_colors, pixel_area_km2, suffix="", out_dir=None):
    """
    Compute per-class area statistics and write a CSV table, charts and a text summary.

    Parameters:
    - classified_data: classification result array (0 = background, ignored)
    - class_names: dict of class code -> display name
    - class_colors: dict of class code -> matplotlib colour (grey if missing);
      used by both the pie and the bar chart so a class keeps one colour
    - pixel_area_km2: area of one pixel in km²
    - suffix: appended to every output file name and chart title
    - out_dir: output directory; defaults to the module-level OUT_DIR

    Writes ``classification_statistics{suffix}.csv``,
    ``area_distribution{suffix}.png``, ``area_bar_chart{suffix}.png`` and
    ``area_summary{suffix}.txt`` into ``out_dir``. The charts are skipped when
    the map contains no class > 0.

    Returns:
    - stats_df: per-class statistics DataFrame
    - total_area_km2: total classified area in km²
    """
    out_dir = OUT_DIR if out_dir is None else Path(out_dir)

    unique, counts = np.unique(classified_data[classified_data > 0], return_counts=True)
    total_pixels = np.sum(counts)

    areas_km2 = counts * pixel_area_km2
    total_area_km2 = np.sum(areas_km2)

    stats_df = pd.DataFrame({
        '类别编号': unique,
        '类别名称': [class_names.get(c, f'未知类别_{c}') for c in unique],
        '像元数量': counts,
        '面积(km²)': np.round(areas_km2, 4),
        '面积占比 (%)': (counts / total_pixels * 100).round(2)
    })

    stats_df.to_csv(out_dir / f"classification_statistics{suffix}.csv", index=False, encoding='utf-8-sig')

    chart_colors = [class_colors.get(c, 'gray') for c in unique]

    if len(unique) == 0:
        logger.warning("分类结果中没有非背景像元，跳过面积统计图绘制")
    else:
        # Pie chart of area shares.
        plt.figure(figsize=(10, 8))
        plt.pie(stats_df['面积占比 (%)'], labels=stats_df['类别名称'], colors=chart_colors,
                autopct='%1.1f%%', startangle=90)
        plt.title(f'分类结果面积占比分布{suffix}', fontsize=14, fontweight='bold')
        plt.axis('equal')
        plt.tight_layout()
        plt.savefig(out_dir / f"area_distribution{suffix}.png", dpi=300, bbox_inches='tight')
        plt.close()

        # Bar chart of absolute areas with value labels.
        plt.figure(figsize=(12, 6))
        plt.bar(stats_df['类别名称'], stats_df['面积(km²)'], color=chart_colors)
        plt.xlabel('地物类别')
        plt.ylabel('面积 (km²)')
        plt.title(f'各类别面积统计{suffix}', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45)
        plt.grid(True, alpha=0.3)

        label_offset = areas_km2.max() * 0.01  # lift labels 1 % of the tallest bar
        for i, v in enumerate(stats_df['面积(km²)']):
            plt.text(i, v + label_offset, f'{v:.2f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.savefig(out_dir / f"area_bar_chart{suffix}.png", dpi=300, bbox_inches='tight')
        plt.close()

    with open(out_dir / f"area_summary{suffix}.txt", "w", encoding="utf-8") as f:
        f.write(f"分类面积统计摘要{suffix}\n")
        f.write("="*50 + "\n")
        f.write(f"总分类面积: {total_area_km2:.4f} km²\n")
        f.write(f"像元大小: {pixel_area_km2 * 1e6:.2f} 平方米\n")
        f.write(f"总像元数: {total_pixels}\n\n")
        
        f.write("各类别面积统计:\n")
        f.write("-"*50 + "\n")
        for _, row in stats_df.iterrows():
            f.write(f"{row['类别名称']}: {row['面积(km²)']:.4f} km² ({row['面积占比 (%)']}%)\n")
    
    return stats_df, total_area_km2

# ------------------ Main workflow ------------------
def _class_area(stats_df, class_id):
    """Return the area (km²) of ``class_id`` in ``stats_df``, or None if the class is absent.

    ``stats_df`` is a table as returned by ``calculate_area_statistics``; it must
    contain the columns '类别编号' (class id) and '面积(km²)' (area in km²).
    """
    matches = stats_df.loc[stats_df['类别编号'] == class_id, '面积(km²)']
    return matches.iloc[0] if len(matches) else None


def _write_postprocessing_report(report_path, class_names, original_stats_df, processed_stats_df,
                                 original_total_area, processed_total_area):
    """Write a text report comparing class areas before and after post-processing.

    Args:
        report_path: Output text file path.
        class_names: Mapping class id -> class name; defines which classes are reported.
        original_stats_df: Area table of the raw classification (from ``calculate_area_statistics``).
        processed_stats_df: Area table of the post-processed classification.
        original_total_area: Total raw classified area in km².
        processed_total_area: Total post-processed classified area in km².
    """
    area_change = processed_total_area - original_total_area
    # An empty raw classification (total area 0) would otherwise raise ZeroDivisionError.
    if original_total_area > 0:
        area_change_percent = area_change / original_total_area * 100
    else:
        area_change_percent = 0.0

    with open(report_path, "w", encoding="utf-8") as f:
        f.write("后处理变化报告\n")
        f.write("=" * 50 + "\n")
        f.write("后处理参数:\n")
        f.write(f"  最小图斑大小: {MIN_PATCH_SIZE} 像元\n")
        f.write(f"  形态学操作: {MORPHOLOGY_OPERATION}\n")
        f.write(f"  核大小: {MORPHOLOGY_SIZE}\n\n")

        f.write("面积变化:\n")
        f.write(f"  原始总面积: {original_total_area:.4f} km²\n")
        f.write(f"  后处理总面积: {processed_total_area:.4f} km²\n")
        f.write(f"  面积变化: {area_change:+.4f} km² ({area_change_percent:+.2f}%)\n\n")

        f.write("各类别面积变化:\n")
        f.write("-" * 50 + "\n")
        for class_id, name in class_names.items():
            orig_area = _class_area(original_stats_df, class_id)
            proc_area = _class_area(processed_stats_df, class_id)
            if orig_area is not None and proc_area is not None:
                change = proc_area - orig_area
                change_percent = (change / orig_area) * 100 if orig_area > 0 else 0
                f.write(f"{name}: {orig_area:.4f} → {proc_area:.4f} km² ({change:+.4f}, {change_percent:+.2f}%)\n")
            elif orig_area is not None:
                # Class present before post-processing but entirely removed by it.
                f.write(f"{name}: {orig_area:.4f} → 0.0000 km² (完全去除)\n")
            elif proc_area is not None:
                # Class only present after post-processing.
                f.write(f"{name}: 0.0000 → {proc_area:.4f} km² (新增)\n")
            # Classes absent from both maps are not reported.


def main():
    """Run the full supervised-classification workflow.

    Steps: read class information from the training shapefile, open the image,
    train the configured classifier on rasterized training samples, evaluate it
    on the training set, predict the whole image block by block, visualize the
    result and compute area statistics, optionally post-process, validate
    against ``VAL_SHP`` when that file exists, and write reports/CSVs to
    ``OUT_DIR``. All configuration comes from module-level constants.
    """
    t0 = time.time()
    logger.info("开始监督分类任务...")

    # 0. Read class ids, names and display colors from the training shapefile.
    logger.info("正在读取类别信息...")
    class_names, class_colors, _ = get_class_info_from_shp(TRAIN_SHP, CLASS_ATTRIBUTE, NAME_ATTRIBUTE)
    logger.info(f"检测到类别: {list(class_names.values())}")

    # 1. Open the image. The context manager releases the file handle even on error,
    #    which matters when re-running in a long-lived notebook kernel (notably on Windows).
    with rxr.open_rasterio(IMAGE_PATH, masked=True) as img:
        logger.info(f"影像尺寸: {img.shape}, 波段数: {img.rio.count}")

        # Spatial reference of the image.
        transform = img.rio.transform()
        crs = img.rio.crs
        logger.info(f"影像坐标系: {crs}")
        logger.info(f"影像变换参数: {transform}")

        # Pixel area. ':.4g' keeps sub-metre pixels readable; a fixed '.6f' printed 0.000000 km².
        pixel_area_m2 = calculate_pixel_area(transform)
        pixel_area_km2 = pixel_area_m2 / 1e6  # m² -> km²
        logger.info(f"单个像元面积: {pixel_area_m2:.4g} 平方米 ({pixel_area_km2:.4g} 平方千米)")

        # 2. Rasterize the training polygons and extract labelled pixel samples.
        logger.info("正在处理训练样本...")
        train_mask = rasterize_samples(TRAIN_SHP, img, CLASS_ATTRIBUTE)
        X_train, y_train = extract_samples(img, train_mask)
        logger.info(f"训练样本数: {len(y_train)}")

        # 3. Train the classifier.
        clf = get_classifier(CLASSIFIER)
        logger.info(f"使用分类器: {clf.__class__.__name__}")
        clf.fit(X_train, y_train)
        logger.info("模型训练完成。")

        # 4. Accuracy assessment on the training set (optimistic; for sanity checking only).
        y_pred_train = clf.predict(X_train)
        actual_train_classes = sorted(np.unique(y_train))
        train_class_names = [class_names.get(c, f'未知类别_{c}') for c in actual_train_classes if c > 0]

        overall_acc, kappa, _ = comprehensive_evaluation(
            y_train, y_pred_train, train_class_names, OUT_DIR / "train_evaluation.txt"
        )
        logger.info(f"训练集总体精度: {overall_acc:.4f}, Kappa: {kappa:.4f}")
        plot_confusion_matrix(y_train, y_pred_train, train_class_names, OUT_DIR / "train_cm.png")

        # 5. Feature importance, only for models that expose it (e.g. random forest).
        if hasattr(clf, 'feature_importances_'):
            feature_names = [f'波段{i+1}' for i in range(X_train.shape[1])]
            plot_feature_importance(clf, feature_names, OUT_DIR / "feature_importance.png")

        # 6. Block-wise prediction of the whole image.
        logger.info("开始分块预测...")
        classified_path = OUT_DIR / "classified_result.tif"
        predict_by_block_rioxarray(clf, img, classified_path)
        logger.info(f"分类结果保存至: {classified_path}")

        # 7. Visualize the raw classification and load it into memory,
        #    closing the GeoTIFF once its values have been read.
        logger.info("生成原始分类结果可视化...")
        with rxr.open_rasterio(classified_path) as classified_img:
            plot_classification_results(img, classified_img, class_names, class_colors,
                                        OUT_DIR / "classification_visualization.png", " (原始)")
            # 8. Area statistics of the raw classification.
            logger.info("计算原始分类结果面积统计...")
            original_classified_data = classified_img.values.squeeze()
        original_stats_df, original_total_area = calculate_area_statistics(
            original_classified_data, class_names, class_colors, pixel_area_km2, "_original"
        )
        logger.info(f"原始分类总面积: {original_total_area:.4f} 平方千米")

        # 9. Optional post-processing (small-patch removal + morphology).
        if POSTPROCESSING:
            logger.info("开始后处理...")
            logger.info(f"后处理参数: 最小图斑大小={MIN_PATCH_SIZE}, 形态学操作={MORPHOLOGY_OPERATION}, 核大小={MORPHOLOGY_SIZE}")

            processed_data = postprocess_classification(
                original_classified_data,
                min_patch_size=MIN_PATCH_SIZE,
                morphology_op=MORPHOLOGY_OPERATION,
                morphology_size=MORPHOLOGY_SIZE
            )

            processed_path = OUT_DIR / "classified_result_postprocessed.tif"
            save_classification_result_rioxarray(processed_data, img, processed_path)
            logger.info(f"后处理结果保存至: {processed_path}")

            logger.info("生成后处理分类结果可视化...")
            with rxr.open_rasterio(processed_path) as processed_img:
                plot_classification_results(img, processed_img, class_names, class_colors,
                                            OUT_DIR / "classification_visualization_postprocessed.png", " (后处理)")

            logger.info("计算后处理分类结果面积统计...")
            processed_stats_df, processed_total_area = calculate_area_statistics(
                processed_data, class_names, class_colors, pixel_area_km2, "_postprocessed"
            )
            logger.info(f"后处理分类总面积: {processed_total_area:.4f} 平方千米")

            _write_postprocessing_report(
                OUT_DIR / "postprocessing_report.txt", class_names,
                original_stats_df, processed_stats_df,
                original_total_area, processed_total_area,
            )

        # 10. Independent validation, evaluated on the RAW (not post-processed) map.
        if os.path.exists(VAL_SHP):
            logger.info("正在进行验证...")
            val_mask = rasterize_samples(VAL_SHP, img, CLASS_ATTRIBUTE)

            with rxr.open_rasterio(classified_path) as pred_img:
                pred_arr = pred_img.values.squeeze()

            labelled = val_mask > 0
            yv = val_mask[labelled]       # reference labels
            pred_v = pred_arr[labelled]   # predicted labels at the same pixels

            if yv.size == 0:
                # E.g. validation polygons fall outside the image; metrics cannot be computed.
                logger.warning("验证样本为空（验证矢量可能不在影像范围内），跳过验证评价。")
            else:
                val_classes = sorted(np.unique(yv))
                val_class_names = [class_names.get(c, f'未知类别_{c}') for c in val_classes if c > 0]

                val_overall_acc, val_kappa, val_eval_df = comprehensive_evaluation(
                    yv, pred_v, val_class_names, OUT_DIR / "validation_evaluation.txt"
                )
                logger.info(f"验证集总体精度: {val_overall_acc:.4f}, Kappa: {val_kappa:.4f}")
                plot_confusion_matrix(yv, pred_v, val_class_names, OUT_DIR / "val_cm.png")

                # Combined summary report.
                with open(OUT_DIR / "comprehensive_report.txt", "w", encoding="utf-8") as f:
                    f.write("遥感影像分类综合报告\n")
                    f.write("=" * 50 + "\n")
                    f.write(f"分类器: {clf.__class__.__name__}\n")
                    f.write(f"训练样本数: {len(y_train)}\n")
                    f.write(f"验证样本数: {len(yv)}\n")
                    f.write(f"类别编号字段: {CLASS_ATTRIBUTE}\n")
                    f.write(f"类别名称字段: {NAME_ATTRIBUTE}\n")
                    f.write(f"检测到的类别: {list(class_names.values())}\n")
                    f.write(f"像元面积: {pixel_area_m2:.4g} 平方米\n")
                    f.write(f"后处理: {'是' if POSTPROCESSING else '否'}\n\n")

                    f.write("精度评价汇总:\n")
                    f.write("-" * 30 + "\n")
                    f.write(f"训练集总体精度: {overall_acc:.4f}\n")
                    f.write(f"训练集Kappa系数: {kappa:.4f}\n")
                    f.write(f"验证集总体精度: {val_overall_acc:.4f}\n")
                    f.write(f"验证集Kappa系数: {val_kappa:.4f}\n\n")

                    f.write("各类别验证精度:\n")
                    f.write("-" * 30 + "\n")
                    f.write(val_eval_df.to_string(index=False, float_format='%.4f'))

    # 11. Save the class id / name / color lookup table.
    class_info_df = pd.DataFrame({
        '类别编号': list(class_names.keys()),
        '类别名称': list(class_names.values()),
        '显示颜色': [class_colors.get(c, 'black') for c in class_names.keys()]
    })
    class_info_df.to_csv(OUT_DIR / "class_information.csv", index=False, encoding='utf-8-sig')

    logger.info(f"全部任务完成，用时 {time.time()-t0:.1f} 秒。")
    logger.info(f"所有结果已保存至: {OUT_DIR.absolute()}")


if __name__ == "__main__":
    main()

2025-10-16 08:42:55,803 [INFO] 开始监督分类任务...
2025-10-16 08:42:55,804 [INFO] 正在读取类别信息...
2025-10-16 08:42:55,816 [INFO] 检测到类别: ['类1', '类2', '类3', '类4', '类5', '类6', '类7', '类8', '类9', '类10', '类11']
2025-10-16 08:42:55,825 [INFO] 影像尺寸: (14, 1024, 2098), 波段数: 14
2025-10-16 08:42:55,828 [INFO] 影像坐标系: EPSG:32633
2025-10-16 08:42:55,829 [INFO] 影像变换参数: | 0.20, 0.00, 351916.64|
| 0.00,-0.20, 5997247.36|
| 0.00, 0.00, 1.00|
2025-10-16 08:42:55,830 [INFO] 单个像元面积: 0.04 平方米 (0.000000 平方千米)
2025-10-16 08:42:55,831 [INFO] 正在处理训练样本...
2025-10-16 08:42:56,080 [INFO] 训练样本数: 15041
2025-10-16 08:42:56,082 [INFO] 使用分类器: RandomForestClassifier
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 48 concurrent workers.
[Parallel(n_jobs=-1)]: Done 104 tasks      | elapsed:    0.5s
[Parallel(n_jobs=-1)]: Done 300 out of 300 | elapsed:    1.2s finished
2025-10-16 08:42:58,063 [INFO] 模型训练完成。
[Parallel(n_jobs=48)]: Using backend ThreadingBackend with 48 concurrent workers.
[Parallel(n_jobs=48)]: Done 104 tasks