# 版本1

# In [6]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
遥感影像监督分类系统 - GUI版本 (完全优化版)
==============================================
版本: v3.0
作者: AI Assistant
日期: 2024

主要特性:
- 12种分类器 + 5种SVM变体
- 智能性能优化（采样、特征缩放、快速模式）
- SVM速度优化（线性核、SGD、核近似）
- 预测时间估算和性能警告
- 实时进度显示和详细日志
- 自动生成对比报告
"""

import os
import sys
import time
import threading
import queue
from pathlib import Path
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier, 
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.kernel_approximation import Nystroem, RBFSampler
from sklearn.pipeline import Pipeline
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from tqdm import tqdm
from scipy import ndimage
from skimage import morphology
import warnings
# Suppress every warning process-wide (equivalent to filterwarnings('ignore')).
warnings.simplefilter('ignore')

# Prefer a CJK-capable font so Chinese class names render in figures, and draw
# the minus sign as ASCII '-' (CJK fonts such as SimHei commonly lack U+2212).
plt.rcParams.update({
    "font.sans-serif": ["SimHei", "DejaVu Sans"],
    "axes.unicode_minus": False,
})

# ==================== 后端处理类 ====================
class ClassificationBackend:
    """Processing back end for the classification GUI.

    Provides the classifier catalogue, class/colour lookup from a sample
    shapefile, sample rasterisation and extraction, accuracy metrics, a rough
    prediction-time estimate, and strip-wise prediction to a GeoTIFF.
    """

    def __init__(self):
        # Label written to background / unpredicted pixels; also declared as
        # the nodata value of the output raster.
        self.BACKGROUND_VALUE = 0
        # Seed shared by all estimators and by sample subsampling.
        self.RANDOM_STATE = 42

        # Preferred colours keyed by a keyword searched for in the class name
        # (checked in this order; the first keyword found wins).
        self.LANDUSE_COLORS = {
            "水体": "lightblue", "河流": "blue", "湖泊": "deepskyblue",
            "植被": "forestgreen", "森林": "darkgreen", "草地": "limegreen",
            "农田": "yellowgreen", "耕地": "olivedrab",
            "建筑": "gray", "城市": "dimgray", "居民地": "slategray",
            "裸地": "tan", "沙地": "wheat", "其他": "darkred"
        }

        # Fallback colours, cycled by class index when no keyword matches.
        self.COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow',
                              'darkred', 'purple', 'orange', 'pink', 'brown']

    @staticmethod
    def _n_kernel_components(n_train_samples):
        """Return the feature-map size for the kernel-approximation SVMs.

        Half the training-set size, capped at 1000 and never below 1
        (``Nystroem``/``RBFSampler`` require at least one component).
        ``None`` or 0 means "unknown" and yields the cap.
        """
        if not n_train_samples:
            return 1000
        return max(1, min(1000, n_train_samples // 2))

    def get_all_classifiers(self, n_estimators=100, fast_mode=False, n_train_samples=None):
        """Build a fresh, unfitted instance of every available classifier.

        Args:
            n_estimators: ensemble size for tree/boosting models (capped at 50
                in fast mode).
            fast_mode: use shallower trees (max_depth 10 instead of 20) and
                fewer solver iterations (200 instead of 500).
            n_train_samples: training-set size, used to size the kernel
                approximation feature maps; ``None`` means unknown.

        Returns:
            dict mapping a short code to
            ``(classifier, name, desc, needs_encoding, needs_scaling, speed_tag)``.
            The two flags tell the caller whether targets should be
            label-encoded and features standardised before fitting;
            ``speed_tag`` is one of "very_fast", "fast", "medium", "slow",
            "very_slow". "xgb" and "lgb" are present only when their
            packages load successfully.
        """
        if fast_mode:
            n_est = min(50, n_estimators)
            max_depth = 10
            max_iter = 200
        else:
            n_est = n_estimators
            max_depth = 20
            max_iter = 500

        n_components = self._n_kernel_components(n_train_samples)

        classifiers = {
            # ===== Tree models (fast) =====
            "rf": (
                RandomForestClassifier(
                    n_estimators=n_est,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2,
                    max_features='sqrt'
                ),
                "随机森林", "Random Forest - 稳定可靠的集成学习",
                False, False, "fast"
            ),

            "et": (
                ExtraTreesClassifier(
                    n_estimators=n_est,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    max_features='sqrt'
                ),
                "极端随机树", "Extra Trees - 更快的随机森林",
                False, False, "fast"
            ),

            "dt": (
                DecisionTreeClassifier(
                    random_state=self.RANDOM_STATE,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2
                ),
                "决策树", "Decision Tree - 简单快速",
                False, False, "very_fast"
            ),

            # ===== SVM family (exact and approximate variants) =====
            "svm_linear": (
                SVC(
                    kernel="linear",
                    C=1.0,
                    cache_size=500,
                    probability=True,
                    random_state=self.RANDOM_STATE,
                    # NOTE: hard cap on libsvm iterations; with only 200/500
                    # the solver may stop before converging.
                    max_iter=max_iter
                ),
                "SVM-线性核", "SVM Linear - 线性可分问题",
                False, True, "medium"
            ),

            "linear_svc": (
                CalibratedClassifierCV(
                    LinearSVC(
                        C=1.0,
                        max_iter=max_iter,
                        random_state=self.RANDOM_STATE,
                        dual=False,
                        loss='squared_hinge'
                    ),
                    cv=3
                ),
                "线性SVM(快)", "Linear SVM - 快速线性分类器",
                False, True, "fast"
            ),

            "sgd_svm": (
                SGDClassifier(
                    loss='hinge',
                    penalty='l2',
                    max_iter=max_iter,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    learning_rate='optimal'
                ),
                "SGD-SVM", "SGD SVM - 极快的线性SVM",
                False, True, "very_fast"
            ),

            "nystroem_svm": (
                Pipeline([
                    ("feature_map", Nystroem(
                        kernel='rbf',
                        gamma=0.1,
                        n_components=n_components,
                        random_state=self.RANDOM_STATE
                    )),
                    ("sgd", SGDClassifier(
                        max_iter=max_iter,
                        random_state=self.RANDOM_STATE
                    ))
                ]),
                "核近似SVM", "Nystroem SVM - RBF核的快速近似",
                False, True, "fast"
            ),

            "rbf_sampler_svm": (
                Pipeline([
                    ("feature_map", RBFSampler(
                        gamma=0.1,
                        n_components=n_components,
                        random_state=self.RANDOM_STATE
                    )),
                    ("sgd", SGDClassifier(
                        max_iter=max_iter,
                        random_state=self.RANDOM_STATE
                    ))
                ]),
                "RBF采样SVM", "RBF Sampler SVM - 另一种RBF近似",
                False, True, "fast"
            ),

            "svm_rbf": (
                SVC(
                    kernel="rbf",
                    C=1.0,
                    gamma='scale',
                    cache_size=500,
                    probability=True,
                    random_state=self.RANDOM_STATE
                ),
                "SVM-RBF核⚠️", "SVM RBF - 高精度但预测极慢",
                False, True, "very_slow"
            ),

            # ===== Other classifiers =====
            "knn": (
                KNeighborsClassifier(
                    n_neighbors=5,
                    n_jobs=-1,
                    algorithm='ball_tree',
                    leaf_size=30
                ),
                "K近邻", "KNN - 基于距离的分类器",
                False, True, "slow"
            ),

            "nb": (
                GaussianNB(),
                "朴素贝叶斯", "Naive Bayes - 最快的概率分类器",
                False, False, "very_fast"
            ),

            "gb": (
                GradientBoostingClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=5,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    subsample=0.8
                ),
                "梯度提升", "Gradient Boosting - 强大的集成方法",
                False, False, "medium"
            ),

            "ada": (
                # 'algorithm' is left at its default: 'SAMME.R' was removed in
                # scikit-learn 1.6 and the parameter itself is deprecated.
                AdaBoostClassifier(
                    n_estimators=n_est,
                    learning_rate=1.0,
                    random_state=self.RANDOM_STATE
                ),
                "AdaBoost", "AdaBoost - 自适应提升",
                False, False, "medium"
            ),

            "lr": (
                # lbfgs already fits a multinomial model for multiclass targets;
                # the 'multi_class' parameter is deprecated since scikit-learn 1.5.
                LogisticRegression(
                    max_iter=max_iter,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    solver='lbfgs'
                ),
                "逻辑回归", "Logistic Regression - 经典线性分类器",
                False, True, "very_fast"
            ),

            "mlp": (
                MLPClassifier(
                    hidden_layer_sizes=(100, 50),
                    max_iter=max_iter,
                    random_state=self.RANDOM_STATE,
                    verbose=False,
                    early_stopping=True,
                    validation_fraction=0.1,
                    n_iter_no_change=10,
                    learning_rate='adaptive'
                ),
                "神经网络", "MLP - 前馈神经网络",
                False, True, "medium"
            ),
        }

        # Optional: XGBoost. Skipped when the package is missing or its native
        # library cannot be loaded (xgboost reports that as XGBoostError, a
        # ValueError subclass; a missing shared library can also be an OSError).
        try:
            from xgboost import XGBClassifier
            classifiers["xgb"] = (
                XGBClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=6,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbosity=0,
                    tree_method='hist',
                    subsample=0.8,
                    colsample_bytree=0.8
                ),
                "XGBoost", "XGBoost - 高性能梯度提升",
                True, False, "fast"
            )
        except (ImportError, OSError, ValueError):
            pass

        # Optional: LightGBM. Skipped when the package is missing, broken, or
        # its native runtime fails to load.
        try:
            from lightgbm import LGBMClassifier
            classifiers["lgb"] = (
                LGBMClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=max_depth,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=-1,
                    num_leaves=31,
                    subsample=0.8,
                    # LightGBM only applies row subsampling when a non-zero
                    # bagging frequency is set; without this, subsample is a no-op.
                    subsample_freq=1,
                    colsample_bytree=0.8
                ),
                "LightGBM", "LightGBM - 极速梯度提升",
                False, False, "very_fast"
            )
        except (ImportError, OSError, AttributeError):
            pass

        return classifiers

    @staticmethod
    def _all_zero_mask(bands):
        """Return a 2-D boolean mask, True where every band equals 0.

        *bands* is band-first, i.e. shaped (bands, rows, cols).
        """
        return np.all(bands == 0, axis=0)

    def get_background_mask(self, image):
        """Return the background mask of *image* (True where all bands are 0).

        *image* must expose a band-first ``.values`` array (presumably the
        rioxarray DataArray opened by the caller).
        """
        return self._all_zero_mask(image.values)

    def _pick_color(self, class_name, index):
        """Choose a display colour for one class.

        The first ``LANDUSE_COLORS`` keyword contained in the class name wins;
        otherwise ``COLOR_PALETTE`` is cycled by *index*. The name is compared
        as text, so numeric or NaN names from the attribute table no longer
        raise ``TypeError``.
        """
        label = str(class_name)
        for keyword, color in self.LANDUSE_COLORS.items():
            if keyword in label:
                return color
        return self.COLOR_PALETTE[index % len(self.COLOR_PALETTE)]

    def get_class_info_from_shp(self, shp_path, class_attr, name_attr):
        """Read class ids, names and colours from a sample shapefile.

        If *name_attr* is not a column, names default to ``"Class_<id>"``.

        Returns:
            ``(class_names, class_colors, sorted_class_ids)``; both dicts are
            keyed by class id. If one id appears with several names, the last
            pairing wins.
        """
        gdf = gpd.read_file(shp_path)

        if name_attr not in gdf.columns:
            gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")

        class_info = gdf[[class_attr, name_attr]].drop_duplicates()
        class_names = dict(zip(class_info[class_attr], class_info[name_attr]))

        class_colors = {
            class_id: self._pick_color(class_name, i)
            for i, (class_id, class_name) in enumerate(class_names.items())
        }

        return class_names, class_colors, sorted(class_names.keys())

    def rasterize_samples(self, shp, ref_img, attr):
        """Burn the *attr* value of each feature in *shp* onto *ref_img*'s grid.

        The vector layer is reprojected to the raster CRS. Every pixel touched
        by a geometry (``all_touched=True``) receives that feature's value; all
        other pixels are 0.

        Returns:
            uint16 array with the shape of one band of *ref_img*.
        """
        gdf = gpd.read_file(shp).to_crs(ref_img.rio.crs)
        shapes = zip(gdf.geometry, gdf[attr])

        return features.rasterize(
            shapes=shapes,
            out_shape=ref_img.shape[1:],
            transform=ref_img.rio.transform(),
            fill=0,
            all_touched=True,
            dtype="uint16"
        )

    def extract_samples(self, image, mask, ignore_background=True, max_samples=None):
        """Collect labelled pixel samples and drop invalid rows.

        Args:
            image: band-first raster (``image.values`` shaped (bands, rows, cols)).
            mask: 2-D label array aligned with *image*; labels > 0 are samples.
            ignore_background: also drop pixels whose bands are all 0.
            max_samples: if given and more samples remain, subsample down to
                this many (class-stratified when possible).

        Returns:
            ``(X, y, n_nan, n_inf, n_sampled)``: features (n, bands), labels,
            the number of rows containing NaN and containing +/-Inf (a row with
            both is counted in each; all such rows are removed), and the number
            of rows dropped by subsampling.
        """
        bands = image.values
        data = np.moveaxis(bands, 0, -1)
        valid = mask > 0

        if ignore_background:
            valid = valid & ~self._all_zero_mask(bands)

        X = data[valid]
        y = mask[valid]

        nan_rows = np.isnan(X).any(axis=1)
        inf_rows = np.isinf(X).any(axis=1)
        n_nan = np.sum(nan_rows)
        n_inf = np.sum(inf_rows)

        keep = ~(nan_rows | inf_rows)
        X = X[keep]
        y = y[keep]

        n_sampled = 0
        if max_samples is not None and len(y) > max_samples:
            n_original = len(y)
            sample_idx = self._subsample_indices(y, max_samples)
            X = X[sample_idx]
            y = y[sample_idx]
            n_sampled = n_original - len(y)

        return X, y, n_nan, n_inf, n_sampled

    def _subsample_indices(self, y, max_samples):
        """Pick *max_samples* row indices of *y* without replacement.

        Uses a class-stratified split when *y* has several classes and falls
        back to seeded uniform sampling when there is a single class or when
        ``StratifiedShuffleSplit`` cannot stratify (e.g. a class with only one
        pixel, or fewer left-over rows than classes). A private ``RandomState``
        is used so the global NumPy RNG is not reseeded.
        """
        if len(np.unique(y)) > 1:
            splitter = StratifiedShuffleSplit(
                n_splits=1,
                train_size=max_samples,
                random_state=self.RANDOM_STATE
            )
            try:
                # Only y drives the split, so a placeholder X is sufficient.
                sample_idx, _ = next(splitter.split(np.zeros(len(y)), y))
                return sample_idx
            except ValueError:
                pass
        rng = np.random.RandomState(self.RANDOM_STATE)
        return rng.choice(len(y), max_samples, replace=False)

    def calculate_metrics(self, y_true, y_pred):
        """Return overall accuracy, Cohen's kappa and macro precision/recall/F1.

        Classes with no predictions contribute 0 to the macro scores
        (``zero_division=0``) instead of raising a warning.
        """
        return {
            'overall_accuracy': accuracy_score(y_true, y_pred),
            'kappa': cohen_kappa_score(y_true, y_pred),
            'precision_macro': precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall_macro': recall_score(y_true, y_pred, average='macro', zero_division=0),
            'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
        }

    def estimate_prediction_time(self, clf_code, n_pixels, speed_tag):
        """Roughly estimate prediction time in seconds.

        Scales a fixed seconds-per-million-pixels rate for *speed_tag* by
        *n_pixels*; unknown tags use the "medium" rate. *clf_code* is currently
        unused.
        """
        seconds_per_megapixel = {
            "very_fast": 1,
            "fast": 3,
            "medium": 10,
            "slow": 30,
            "very_slow": 300,
        }

        rate = seconds_per_megapixel.get(speed_tag, 10)
        return (n_pixels / 1_000_000) * rate

    def predict_by_block(self, model, image, out_path, block_size=512,
                         ignore_background=True, progress_callback=None,
                         label_encoder=None, scaler=None):
        """Classify *image* strip by strip and save the label map as a GeoTIFF.

        The raster is processed in horizontal strips of *block_size* rows. In
        each strip NaN/+Inf/-Inf are replaced by 0, features are transformed by
        *scaler* (if given), predicted by *model*, and decoded by
        *label_encoder* (if given). With *ignore_background*, pixels whose bands
        are all 0 are not predicted and keep ``BACKGROUND_VALUE``. The
        background test is done per strip, so the full raster does not have to
        be read just to build the mask.

        Args:
            model: fitted estimator exposing ``predict``.
            image: band-first raster with a ``y`` dimension and ``.rio``
                georeferencing (assumed to be a rioxarray DataArray).
            out_path: destination GeoTIFF path.
            block_size: strip height in rows.
            ignore_background: skip all-zero pixels.
            progress_callback: called with the completed percentage after
                each strip.
            label_encoder: optional encoder whose ``inverse_transform`` maps
                model outputs back to class ids.
            scaler: optional transformer applied before ``predict``.

        Returns:
            *out_path*, after writing an LZW-compressed, tiled uint16 GeoTIFF
            with the source CRS/transform and nodata ``BACKGROUND_VALUE``.
        """
        height, width = image.shape[1], image.shape[2]
        prediction = np.zeros((height, width), dtype='uint16')

        total_blocks = int(np.ceil(height / block_size))

        for i, row0 in enumerate(range(0, height, block_size)):
            n_rows = min(block_size, height - row0)
            block = image.isel(y=slice(row0, row0 + n_rows)).values  # (bands, n_rows, width)
            pixels = np.moveaxis(block, 0, -1).reshape(-1, block.shape[0])

            if ignore_background:
                to_predict = ~self._all_zero_mask(block).ravel()
            else:
                to_predict = np.ones(len(pixels), dtype=bool)

            labels = np.full(len(pixels), self.BACKGROUND_VALUE, dtype='uint16')
            if to_predict.any():
                feats = np.nan_to_num(pixels[to_predict],
                                      nan=0.0, posinf=0.0, neginf=0.0)
                if scaler is not None:
                    feats = scaler.transform(feats)

                preds = model.predict(feats)
                if label_encoder is not None:
                    preds = label_encoder.inverse_transform(preds)

                labels[to_predict] = preds

            prediction[row0:row0 + n_rows, :] = labels.reshape(n_rows, width)

            if progress_callback:
                progress_callback((i + 1) / total_blocks * 100)

        # Georeference the label map like the input and write it out.
        prediction_da = xr.DataArray(
            prediction,
            dims=['y', 'x'],
            coords={'y': image.coords['y'], 'x': image.coords['x']}
        )

        prediction_da.rio.write_crs(image.rio.crs, inplace=True)
        prediction_da.rio.write_transform(image.rio.transform(), inplace=True)
        prediction_da.rio.write_nodata(self.BACKGROUND_VALUE, inplace=True)

        prediction_da.rio.to_raster(out_path, driver='GTiff', dtype='uint16',
                                    compress='lzw', tiled=True)
        return out_path

# ==================== GUI主类 ====================
class ClassificationGUI:
    """遥感影像分类GUI主界面（完全优化版）"""
    
    def __init__(self, root):
        """Create the main window state, build the widgets and start log polling.

        Args:
            root: the Tk root window hosting the GUI.
        """
        self.root = root
        root.title("遥感影像监督分类系统 v3.0 (完全优化版)")
        root.geometry("1450x950")

        # Processing back end (classifier catalogue, I/O, prediction).
        self.backend = ClassificationBackend()

        # Input/output paths bound to the path entries.
        self.image_path = tk.StringVar()
        self.train_shp_path = tk.StringVar()
        self.val_shp_path = tk.StringVar()
        self.output_dir = tk.StringVar(value=str(Path("./results_gui")))

        # Attribute-field names and model / tiling parameters.
        self.class_attr = tk.StringVar(value="class")
        self.name_attr = tk.StringVar(value="name")
        self.n_estimators = tk.IntVar(value=100)
        self.block_size = tk.IntVar(value=512)
        self.ignore_background = tk.BooleanVar(value=True)

        # Performance options.
        self.enable_sampling = tk.BooleanVar(value=True)
        self.max_samples = tk.IntVar(value=50000)
        self.fast_mode = tk.BooleanVar(value=False)

        # One initially-unchecked checkbox variable per available classifier code.
        self.classifier_vars = {
            code: tk.BooleanVar(value=False)
            for code in self.backend.get_all_classifiers()
        }

        # Run flag and the queue of pending log messages.
        self.is_running = False
        self.log_queue = queue.Queue()

        self.build_ui()

        # Kick off the periodic log refresh.
        self.update_log()
    
    def build_ui(self):
        """Create and lay out every widget of the main window.

        Sections, top to bottom: 1 file inputs, 2 parameters, 3 classifier
        checkboxes (scrollable, grouped by family), 4 run controls with a
        progress bar, 5 the log view, and a status bar. Widgets used by other
        methods are kept on ``self``: ``max_samples_spinbox``, ``start_btn``,
        ``stop_btn``, ``progress_var``, ``progress_bar``, ``log_text`` and
        ``status_var``.
        """
        # Main frame fills the window; column 1 and row 3 (the log) stretch.
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)

        # ===== 1. File inputs: one "label | entry | browse" row each =====
        file_frame = ttk.LabelFrame(main_frame, text="1. 数据输入", padding="10")
        file_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=5)

        path_rows = (
            ("影像文件:", self.image_path, self.browse_image),
            ("训练样本:", self.train_shp_path, self.browse_train_shp),
            ("验证样本:", self.val_shp_path, self.browse_val_shp),
            ("输出目录:", self.output_dir, self.browse_output),
        )
        for row, (label, variable, command) in enumerate(path_rows):
            self._add_path_row(file_frame, row, label, variable, command)

        file_frame.columnconfigure(1, weight=1)

        # ===== 2. Parameters =====
        param_frame = ttk.LabelFrame(main_frame, text="2. 参数配置", padding="10")
        param_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N), pady=5, padx=(0, 5))

        ttk.Label(param_frame, text="类别编号字段:").grid(row=0, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.class_attr, width=15).grid(
            row=0, column=1, sticky=tk.W, padx=5
        )

        ttk.Label(param_frame, text="类别名称字段:").grid(row=1, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.name_attr, width=15).grid(
            row=1, column=1, sticky=tk.W, padx=5
        )

        ttk.Label(param_frame, text="树模型数量:").grid(row=2, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=10, to=500, textvariable=self.n_estimators,
                    width=13).grid(row=2, column=1, sticky=tk.W, padx=5)

        ttk.Label(param_frame, text="分块大小:").grid(row=3, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=256, to=2048, increment=256,
                    textvariable=self.block_size, width=13).grid(
            row=3, column=1, sticky=tk.W, padx=5
        )

        # Performance options
        ttk.Separator(param_frame, orient='horizontal').grid(
            row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=8
        )

        ttk.Label(param_frame, text="⚡ 性能优化:", font=('', 9, 'bold')).grid(
            row=5, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        sample_frame = ttk.Frame(param_frame)
        sample_frame.grid(row=6, column=0, columnspan=2, sticky=(tk.W, tk.E))

        # Toggling this checkbox enables/disables the max-samples spinbox.
        ttk.Checkbutton(sample_frame, text="启用采样",
                        variable=self.enable_sampling,
                        command=self.toggle_sampling).pack(side=tk.LEFT)

        ttk.Label(sample_frame, text="  最大样本数:").pack(side=tk.LEFT, padx=(10, 0))
        self.max_samples_spinbox = ttk.Spinbox(
            sample_frame, from_=10000, to=200000, increment=10000,
            textvariable=self.max_samples, width=10
        )
        self.max_samples_spinbox.pack(side=tk.LEFT, padx=5)

        ttk.Checkbutton(param_frame, text="快速模式（减少模型复杂度）",
                        variable=self.fast_mode).grid(
            row=7, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        ttk.Checkbutton(param_frame, text="忽略背景值（所有波段为0）",
                        variable=self.ignore_background).grid(
            row=8, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        # ===== 3. Classifier selection =====
        clf_frame = ttk.LabelFrame(main_frame, text="3. 分类器选择 (✓推荐 ⚠️慢速)", padding="10")
        clf_frame.grid(row=1, column=1, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        # Quick-selection buttons
        btn_frame = ttk.Frame(clf_frame)
        btn_frame.grid(row=0, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(0, 5))

        quick_buttons = (
            ("全选", self.select_all_classifiers),
            ("全不选", self.deselect_all_classifiers),
            ("✓推荐组合", self.select_recommended),
            ("⚡快速分类器", self.select_fast),
            ("SVM全选", self.select_all_svm),
        )
        for text, command in quick_buttons:
            ttk.Button(btn_frame, text=text, command=command,
                       width=12).pack(side=tk.LEFT, padx=2)

        # Scrollable area: a frame embedded in a canvas with a vertical scrollbar.
        canvas = tk.Canvas(clf_frame, height=200)
        scrollbar = ttk.Scrollbar(clf_frame, orient="vertical", command=canvas.yview)
        scrollable_frame = ttk.Frame(canvas)

        # Recompute the scroll region whenever the inner frame is resized.
        scrollable_frame.bind(
            "<Configure>",
            lambda e: canvas.configure(scrollregion=canvas.bbox("all"))
        )

        canvas.create_window((0, 0), window=scrollable_frame, anchor="nw")
        canvas.configure(yscrollcommand=scrollbar.set)

        # Checkboxes grouped by family; each group starts on the next free row.
        all_classifiers = self.backend.get_all_classifiers()
        groups = (
            ("📊 SVM系列:",
             ["svm_linear", "linear_svc", "sgd_svm", "nystroem_svm",
              "rbf_sampler_svm", "svm_rbf"],
             (5, 2)),
            ("🌲 树模型系列:", ["rf", "et", "dt", "xgb", "lgb", "gb", "ada"], (10, 2)),
            ("📈 其他分类器:", ["knn", "nb", "lr", "mlp"], (10, 2)),
        )
        row = 0
        for title, codes, title_pady in groups:
            row = self._add_classifier_group(
                scrollable_frame, row, title, codes, all_classifiers, title_pady
            )

        canvas.grid(row=1, column=0, columnspan=3, sticky=(tk.W, tk.E, tk.N, tk.S))
        scrollbar.grid(row=1, column=3, sticky=(tk.N, tk.S))

        clf_frame.rowconfigure(1, weight=1)

        # ===== 4. Run controls =====
        control_frame = ttk.LabelFrame(main_frame, text="4. 运行控制", padding="10")
        control_frame.grid(row=2, column=0, sticky=(tk.W, tk.E), pady=5, padx=(0, 5))

        self.start_btn = ttk.Button(control_frame, text="▶ 开始分类",
                                    command=self.start_classification, width=15)
        self.start_btn.grid(row=0, column=0, padx=5, pady=5)

        # Stop is disabled until a run starts.
        self.stop_btn = ttk.Button(control_frame, text="⏸ 停止",
                                   command=self.stop_classification,
                                   state=tk.DISABLED, width=15)
        self.stop_btn.grid(row=0, column=1, padx=5, pady=5)

        ttk.Button(control_frame, text="📁 打开结果",
                   command=self.open_result_dir, width=15).grid(
            row=0, column=2, padx=5, pady=5
        )

        ttk.Button(control_frame, text="📊 查看报告",
                   command=self.view_report, width=15).grid(
            row=0, column=3, padx=5, pady=5
        )

        # Progress bar (0-100)
        ttk.Label(control_frame, text="进度:").grid(row=1, column=0, sticky=tk.W, pady=2)
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(control_frame, variable=self.progress_var,
                                            maximum=100, length=400)
        self.progress_bar.grid(row=1, column=1, columnspan=3, sticky=(tk.W, tk.E),
                               padx=5, pady=2)

        control_frame.columnconfigure(3, weight=1)

        # ===== 5. Log output =====
        log_frame = ttk.LabelFrame(main_frame, text="5. 运行日志", padding="10")
        log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        self.log_text = scrolledtext.ScrolledText(log_frame, wrap=tk.WORD,
                                                  height=18, width=120,
                                                  font=('Consolas', 9))
        self.log_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        log_frame.columnconfigure(0, weight=1)
        log_frame.rowconfigure(0, weight=1)

        # Status bar
        self.status_var = tk.StringVar(value="就绪 | 请选择数据文件开始")
        status_bar = ttk.Label(main_frame, textvariable=self.status_var,
                               relief=tk.SUNKEN, anchor=tk.W)
        status_bar.grid(row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(5, 0))

    def _add_path_row(self, parent, row, label, variable, command):
        """Add a "label | path entry | browse button" row to *parent* at grid *row*."""
        ttk.Label(parent, text=label).grid(row=row, column=0, sticky=tk.W, pady=2)
        ttk.Entry(parent, textvariable=variable, width=65).grid(
            row=row, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(parent, text="浏览", command=command).grid(
            row=row, column=2, padx=5
        )

    def _add_classifier_group(self, parent, row, title, codes, all_classifiers,
                              title_pady):
        """Add a bold group title at *row* followed by checkboxes, three per row.

        Codes absent from *all_classifiers* (e.g. optional models whose package
        failed to import) are skipped. Returns the first free grid row after
        the group.
        """
        ttk.Label(parent, text=title, font=('', 9, 'bold')).grid(
            row=row, column=0, columnspan=3, sticky=tk.W, pady=title_pady
        )
        available = [code for code in codes if code in all_classifiers]
        for k, code in enumerate(available):
            name = all_classifiers[code][1]
            ttk.Checkbutton(parent, text=name,
                            variable=self.classifier_vars[code]).grid(
                row=row + 1 + k // 3, column=k % 3, sticky=tk.W, pady=1, padx=5
            )
        # Title row plus ceil(len(available) / 3) checkbox rows.
        return row + 1 + (len(available) + 2) // 3
    
    def toggle_sampling(self):
        """Enable the max-samples spinbox only while sampling is switched on."""
        new_state = tk.NORMAL if self.enable_sampling.get() else tk.DISABLED
        self.max_samples_spinbox.config(state=new_state)
    
    # ===== 文件浏览函数 =====
    def browse_image(self):
        """Ask the user for the input raster and remember the chosen path."""
        picked = filedialog.askopenfilename(
            title="选择影像文件",
            filetypes=[("GeoTIFF", "*.tif *.tiff"), ("所有文件", "*.*")],
        )
        if not picked:
            # Dialog cancelled: keep the previous selection.
            return
        self.image_path.set(picked)
        self.status_var.set(f"已选择影像: {Path(picked).name}")
    
    def browse_train_shp(self):
        """Ask the user for the training-sample shapefile and remember its path."""
        picked = filedialog.askopenfilename(
            title="选择训练样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")],
        )
        if not picked:
            # Dialog cancelled: keep the previous selection.
            return
        self.train_shp_path.set(picked)
        self.status_var.set(f"已选择训练样本: {Path(picked).name}")
    
    def browse_val_shp(self):
        """Ask the user for the validation-sample shapefile and remember its path."""
        picked = filedialog.askopenfilename(
            title="选择验证样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")],
        )
        if not picked:
            # Dialog cancelled: keep the previous selection.
            return
        self.val_shp_path.set(picked)
        self.status_var.set(f"已选择验证样本: {Path(picked).name}")
    
    def browse_output(self):
        """Ask the user for the output directory and remember it."""
        folder = filedialog.askdirectory(title="选择输出目录")
        if not folder:
            # Dialog cancelled: keep the previous directory.
            return
        self.output_dir.set(folder)
        self.status_var.set(f"输出目录: {folder}")
    
    # ===== 分类器选择函数 =====
    def select_all_classifiers(self):
        """Tick every classifier checkbox."""
        checkboxes = self.classifier_vars
        for code in checkboxes:
            checkboxes[code].set(True)
    
    def deselect_all_classifiers(self):
        """Untick every classifier checkbox."""
        checkboxes = self.classifier_vars
        for code in checkboxes:
            checkboxes[code].set(False)
    
    def select_recommended(self):
        """Tick only the recommended classifiers (balance of accuracy and speed)."""
        wanted = {"rf", "xgb", "et", "lgb", "linear_svc", "nystroem_svm"}
        for code in self.classifier_vars:
            self.classifier_vars[code].set(code in wanted)
        self.status_var.set("已选择推荐组合: RF, XGBoost, ET, LightGBM, Linear SVM, Nystroem SVM")
    
    def select_fast(self):
        """Tick only the quick classifiers (speed first, for large data / quick tests)."""
        quick = {"rf", "et", "dt", "xgb", "lgb", "nb", "lr", "sgd_svm", "linear_svc"}
        for code in self.classifier_vars:
            self.classifier_vars[code].set(code in quick)
        self.status_var.set("已选择快速分类器: 适合大数据量快速测试")
    
    def select_all_svm(self):
        """Tick only the SVM variants, so different SVM implementations can be compared."""
        svm_variants = frozenset((
            "svm_linear", "linear_svc", "sgd_svm",
            "nystroem_svm", "rbf_sampler_svm", "svm_rbf",
        ))
        for code in self.classifier_vars:
            self.classifier_vars[code].set(code in svm_variants)
        self.status_var.set("已选择所有SVM变体: 用于对比不同SVM实现")
    
    # ===== 日志相关函数 =====
    def log(self, message):
        """Queue *message* for the log panel.

        This only enqueues the text. ``update_log`` later moves queued messages
        into the text widget, which is why the worker thread
        (``run_classification``) logs through this method. NOTE(review): this
        assumes ``log_queue`` is a ``queue.Queue``; confirm where it is created.
        """
        pending = self.log_queue
        pending.put(message)
    
    def update_log(self):
        """Drain all queued log messages into the text widget, then poll again in 100 ms."""
        while True:
            try:
                line = self.log_queue.get_nowait()
            except queue.Empty:
                break
            self.log_text.insert(tk.END, line + "\n")
            # Keep the newest message visible.
            self.log_text.see(tk.END)

        # Re-arm the polling timer on the Tk event loop.
        self.root.after(100, self.update_log)
    
    # ===== 主要功能函数 =====
    def start_classification(self):
        """Validate the inputs, warn about slow classifiers and start the worker thread.

        Shows an error dialog and returns when no image, no training shapefile
        or no classifier is selected. The backend marks each classifier with
        a speed tag (tuple index 5). If any selected classifier is tagged
        "very_slow", the user is asked to confirm. Classifiers tagged "slow"
        are listed in that dialog and are also noted in the run log.

        On confirmation the method disables the start button, enables the
        stop button, sets ``self.is_running``, clears the log and runs
        ``self.run_classification(selected_classifiers)`` in a daemon thread.
        """
        # Validate inputs
        if not self.image_path.get():
            messagebox.showerror("错误", "请选择影像文件！")
            return

        if not self.train_shp_path.get():
            messagebox.showerror("错误", "请选择训练样本！")
            return

        # At least one classifier must be ticked.
        selected_classifiers = [code for code, var in self.classifier_vars.items()
                                if var.get()]
        if not selected_classifiers:
            messagebox.showerror("错误", "请至少选择一个分类器！")
            return

        # Collect display names (index 1) of slow / very slow classifiers.
        all_classifiers = self.backend.get_all_classifiers()
        slow_clfs = []
        very_slow_clfs = []
        for code in selected_classifiers:
            entry = all_classifiers.get(code)
            if entry is None:
                continue
            name, speed_tag = entry[1], entry[5]
            if speed_tag == "very_slow":
                very_slow_clfs.append(name)
            elif speed_tag == "slow":
                slow_clfs.append(name)

        # Only very slow classifiers block the start with a confirmation
        # dialog; slow ones are listed there too. The dialog is plain text,
        # so no markdown emphasis is used.
        if very_slow_clfs:
            lines = ["⚠️ 性能警告", "", "以下分类器预测非常慢:"]
            lines += [f"  • {clf}" for clf in very_slow_clfs]
            lines += ["", "预计预测时间: >5分钟/分类器"]
            if slow_clfs:
                lines += ["", "以下分类器预测较慢:"]
                lines += [f"  • {clf}" for clf in slow_clfs]
            lines += [
                "",
                "建议:",
                "  • 使用 'SVM-线性核' 或 'SGD-SVM' 替代",
                "  • 或使用 '核近似SVM' 获得类似效果",
                "  • 或启用数据采样减少计算量",
                "",
                "是否继续?",
            ]
            warning_msg = "\n".join(lines)

            if not messagebox.askyesno("性能警告", warning_msg, icon='warning'):
                return

        # Disable start, enable stop
        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.is_running = True

        # Clear the log and write the task header
        self.log_text.delete("1.0", tk.END)
        self.log("="*80)
        self.log("  遥感影像监督分类系统 v3.0 - 开始分类任务")
        self.log("="*80)
        self.log(f"选择的分类器: {len(selected_classifiers)} 个")
        if slow_clfs:
            self.log(f"⚠️ 较慢分类器: {', '.join(slow_clfs)}")

        if self.enable_sampling.get():
            self.log(f"✓ 数据采样: 最大 {self.max_samples.get():,} 个样本")
        if self.fast_mode.get():
            self.log("✓ 快速模式: 启用")
        self.log("")

        # Run the classification in a background daemon thread so the GUI stays responsive.
        thread = threading.Thread(target=self.run_classification,
                                  args=(selected_classifiers,))
        thread.daemon = True
        thread.start()
    
    def stop_classification(self):
        """Ask the running task to stop.

        Only clears ``self.is_running``. ``run_classification`` checks that
        flag between its stages and classifiers and stops at the next check.
        """
        self.is_running = False
        notice = "\n⏸ 用户请求停止..."
        self.log(notice)
        self.status_var.set("已停止")
    
    def run_classification(self, selected_classifiers):
        """Run the complete classification task (executed in a background thread).

        Steps:
          1. read the raster with ``rxr.open_rasterio(..., masked=True)``;
          2. read the class names from the training shapefile;
          3. rasterize the training shapefile and extract the training samples,
             optionally sub-sampled to ``self.max_samples``;
          4. if the validation shapefile exists, rasterize it and collect the
             validation labels;
          5. for each selected classifier: optional label encoding and feature
             scaling, fit, training-set metrics, block-wise prediction of the
             whole image to ``<out_dir>/<code>/classified_<code>.tif``, and
             validation metrics read back from that file;
          6. write ``classifier_comparison.csv`` and ``comparison_summary.txt``.

        ``self.is_running`` is checked between stages so
        ``stop_classification`` can interrupt the run. Codes that the backend
        does not provide are skipped with a log message. Without validation
        samples all validation metrics are NaN, so ranking falls back to the
        training-set metrics; ``idxmax()`` on an all-NaN column cannot pick a
        winner. The ``finally`` clause restores the buttons and progress bar
        and closes the raster.

        Args:
            selected_classifiers: classifier codes (keys of
                ``self.backend.get_all_classifiers()``), processed in order.
        """
        img = None
        try:
            # Create the output directory. parents=True so that a nested path
            # whose parents do not exist yet does not raise FileNotFoundError.
            out_dir = Path(self.output_dir.get())
            out_dir.mkdir(parents=True, exist_ok=True)

            # 1. Read the image
            self.log("📁 正在读取影像...")
            self.log(f"   路径: {self.image_path.get()}")
            self.status_var.set("读取影像...")
            img = rxr.open_rasterio(self.image_path.get(), masked=True)
            n_pixels = img.shape[1] * img.shape[2]
            self.log(f"   尺寸: {img.shape[1]} × {img.shape[2]} = {n_pixels:,} 像元")
            self.log(f"   波段数: {img.rio.count}")

            if not self.is_running:
                return

            # 2. Read class information
            self.log("\n📊 正在读取类别信息...")
            class_names, class_colors, _ = self.backend.get_class_info_from_shp(
                self.train_shp_path.get(),
                self.class_attr.get(),
                self.name_attr.get()
            )
            self.log(f"   检测到 {len(class_names)} 个类别: {list(class_names.values())}")

            # 3. Extract training samples
            self.log("\n🎯 正在处理训练样本...")
            self.status_var.set("处理训练样本...")
            train_mask = self.backend.rasterize_samples(
                self.train_shp_path.get(), img, self.class_attr.get()
            )

            max_samples = self.max_samples.get() if self.enable_sampling.get() else None

            X_train, y_train, n_nan, n_inf, n_sampled = self.backend.extract_samples(
                img, train_mask,
                ignore_background=self.ignore_background.get(),
                max_samples=max_samples
            )

            self.log(f"   训练样本数: {len(y_train):,}")
            if n_nan > 0:
                self.log(f"   └─ 移除NaN样本: {n_nan:,}")
            if n_inf > 0:
                self.log(f"   └─ 移除Inf样本: {n_inf:,}")
            if n_sampled > 0:
                self.log(f"   └─ 采样减少: {n_sampled:,} (提速优化)")

            if not self.is_running:
                return

            # 4. Extract validation samples (optional)
            yv_true = None
            valid_val = None
            val_exists = os.path.exists(self.val_shp_path.get())
            if val_exists:
                self.log("\n✅ 正在处理验证样本...")
                val_mask = self.backend.rasterize_samples(
                    self.val_shp_path.get(), img, self.class_attr.get()
                )

                if self.ignore_background.get():
                    background_mask = self.backend.get_background_mask(img)
                    valid_val = (val_mask > 0) & (~background_mask)
                else:
                    valid_val = val_mask > 0

                yv_true = val_mask[valid_val]
                self.log(f"   验证样本数: {len(yv_true):,}")

            # 5. Train and evaluate each classifier
            all_classifiers = self.backend.get_all_classifiers(
                self.n_estimators.get(),
                fast_mode=self.fast_mode.get(),
                n_train_samples=len(y_train)
            )

            comparison_results = []
            total_start_time = time.time()
            n_selected = len(selected_classifiers)

            for i, clf_code in enumerate(selected_classifiers):
                if not self.is_running:
                    break

                entry = all_classifiers.get(clf_code)
                if entry is None:
                    # Presumably an optional library (e.g. XGBoost/LightGBM)
                    # is missing. Skip this code instead of aborting the run.
                    self.log(f"\n⚠️ 分类器 '{clf_code}' 不可用，已跳过")
                    continue
                clf, clf_name, clf_desc, needs_encoding, needs_scaling, speed_tag = entry

                self.log(f"\n{'='*80}")
                self.log(f"[{i+1}/{n_selected}] {clf_name}")
                self.log(f"{'='*80}")
                self.log(f"描述: {clf_desc}")

                # Estimated prediction time (only reported when > 10 s)
                est_pred_time = self.backend.estimate_prediction_time(clf_code, n_pixels, speed_tag)
                if est_pred_time > 60:
                    self.log(f"⏱️  预计预测时间: ~{est_pred_time/60:.1f} 分钟")
                elif est_pred_time > 10:
                    self.log(f"⏱️  预计预测时间: ~{est_pred_time:.0f} 秒")

                self.status_var.set(f"[{i+1}/{n_selected}] 训练 {clf_name}...")

                clf_dir = out_dir / clf_code
                clf_dir.mkdir(exist_ok=True)

                try:
                    # Preprocessing
                    label_encoder = None
                    scaler = None
                    X_train_use = X_train.copy()
                    y_train_use = y_train.copy()

                    if needs_encoding:
                        self.log("   🔄 应用标签编码...")
                        label_encoder = LabelEncoder()
                        y_train_use = label_encoder.fit_transform(y_train)

                    if needs_scaling:
                        self.log("   📏 应用特征缩放...")
                        scaler = StandardScaler()
                        X_train_use = scaler.fit_transform(X_train_use)

                    # Training
                    self.log("   🔨 训练中...")
                    train_start = time.time()
                    clf.fit(X_train_use, y_train_use)
                    train_time = time.time() - train_start
                    self.log(f"   ✓ 训练完成: {train_time:.2f} 秒")

                    # Training-set accuracy (labels decoded back when encoded)
                    y_train_pred = clf.predict(X_train_use)

                    if label_encoder is not None:
                        y_train_pred = label_encoder.inverse_transform(y_train_pred)

                    train_metrics = self.backend.calculate_metrics(y_train, y_train_pred)
                    self.log(f"   📈 训练集 - 精度: {train_metrics['overall_accuracy']:.4f}, "
                             f"Kappa: {train_metrics['kappa']:.4f}")

                    if not self.is_running:
                        break

                    # Predict the whole image
                    self.log("   🗺️  预测整幅影像...")
                    self.status_var.set(f"[{i+1}/{n_selected}] 预测 {clf_name}...")

                    pred_start = time.time()
                    classified_path = clf_dir / f"classified_{clf_code}.tif"

                    def update_progress(progress):
                        self.progress_var.set(progress)

                    self.backend.predict_by_block(
                        clf, img, classified_path,
                        block_size=self.block_size.get(),
                        ignore_background=self.ignore_background.get(),
                        progress_callback=update_progress,
                        label_encoder=label_encoder,
                        scaler=scaler
                    )

                    pred_time = time.time() - pred_start
                    self.log(f"   ✓ 预测完成: {pred_time:.2f} 秒")

                    # Validation accuracy, read back from the written raster
                    val_metrics = {'overall_accuracy': np.nan, 'kappa': np.nan, 'f1_macro': np.nan}
                    if val_exists:
                        with rxr.open_rasterio(classified_path) as pred_img:
                            pred_arr = pred_img.values.squeeze()

                        yv_pred = pred_arr[valid_val]
                        val_metrics = self.backend.calculate_metrics(yv_true, yv_pred)
                        self.log(f"   📊 验证集 - 精度: {val_metrics['overall_accuracy']:.4f}, "
                                 f"Kappa: {val_metrics['kappa']:.4f}")

                    # Record the result
                    result = {
                        '分类器代码': clf_code,
                        '分类器名称': clf_name,
                        '速度等级': speed_tag,
                        '训练集精度': train_metrics['overall_accuracy'],
                        '训练集Kappa': train_metrics['kappa'],
                        '训练集F1': train_metrics['f1_macro'],
                        '验证集精度': val_metrics['overall_accuracy'],
                        '验证集Kappa': val_metrics['kappa'],
                        '验证集F1': val_metrics['f1_macro'],
                        '训练时间(秒)': train_time,
                        '预测时间(秒)': pred_time,
                        '总时间(秒)': train_time + pred_time
                    }
                    comparison_results.append(result)

                    self.log(f"   ✅ {clf_name} 完成!")

                except Exception as e:
                    # One failing classifier must not abort the others.
                    self.log(f"   ❌ {clf_name} 失败: {str(e)}")
                    import traceback
                    self.log(f"   {traceback.format_exc()}")
                    continue

                # Overall progress
                self.progress_var.set((i + 1) / n_selected * 100)

            # 6. Comparison report
            if comparison_results and self.is_running:
                total_time = time.time() - total_start_time

                self.log(f"\n{'='*80}")
                self.log("📝 生成对比报告...")
                self.status_var.set("生成报告...")

                comparison_df = pd.DataFrame(comparison_results)
                comparison_df.to_csv(out_dir / "classifier_comparison.csv",
                                     index=False, encoding='utf-8-sig')

                # Rank by validation metrics when available, otherwise by training metrics.
                prefix = self._metric_prefix(comparison_df)
                acc_col = f'{prefix}精度'

                header_lines = [
                    f"生成时间: {time.strftime('%Y-%m-%d %H:%M:%S')}",
                    f"影像尺寸: {img.shape[1]} × {img.shape[2]} = {n_pixels:,} 像元",
                    f"训练样本数: {len(y_train):,}",
                ]
                if val_exists:
                    header_lines.append(f"验证样本数: {len(yv_true):,}")
                header_lines += [
                    f"类别数量: {len(class_names)}",
                    f"性能优化: 采样={self.enable_sampling.get()}, "
                    f"快速模式={self.fast_mode.get()}",
                    f"成功完成: {len(comparison_results)}/{n_selected} 个分类器",
                    f"总耗时: {total_time/60:.1f} 分钟",
                ]
                if prefix != '验证集':
                    header_lines.append("注意: 无可用验证样本，排名基于训练集精度")

                self._write_comparison_report(
                    out_dir / "comparison_summary.txt", comparison_df, prefix, header_lines
                )

                self.log("✅ 所有任务完成!")
                self.log(f"📁 结果保存至: {out_dir.absolute()}")
                self.log(f"📊 成功: {len(comparison_results)}/{n_selected} 个分类器")
                self.log(f"⏱️  总耗时: {total_time/60:.1f} 分钟")

                # Best result (prefix guarantees at least one non-NaN value)
                best_clf = comparison_df.loc[comparison_df[acc_col].idxmax()]
                best_name = best_clf['分类器名称']
                best_acc = best_clf[acc_col]
                self.log(f"\n🏆 最佳精度: {best_name} ({best_acc:.4f})")

                self.status_var.set(f"✅ 完成! 最佳: {best_name} ({best_acc:.4f})")

                messagebox.showinfo("任务完成",
                    f"🎉 分类任务完成!\n\n"
                    f"✅ 成功: {len(comparison_results)}/{n_selected}\n"
                    f"🏆 最佳: {best_name} ({best_acc:.4f})\n"
                    f"⏱️  耗时: {total_time/60:.1f} 分钟\n\n"
                    f"📁 结果: {out_dir}")

        except Exception as e:
            self.log(f"\n❌ 错误: {str(e)}")
            import traceback
            self.log(traceback.format_exc())
            messagebox.showerror("错误", f"发生错误:\n{str(e)}")
            self.status_var.set("❌ 错误")

        finally:
            # Restore the button states
            self.start_btn.config(state=tk.NORMAL)
            self.stop_btn.config(state=tk.DISABLED)
            self.progress_var.set(0)
            self.is_running = False
            # Release the raster opened in step 1 (done last so the UI is
            # restored even if closing fails).
            if img is not None:
                img.close()

    @staticmethod
    def _metric_prefix(comparison_df):
        """Return the metric-column prefix used for ranking: '验证集' or '训练集'.

        Validation metrics are preferred. They are all NaN when no validation
        samples were available; ``idxmax()`` on such a column cannot pick a
        row, so the training metrics are used instead.
        """
        return '验证集' if comparison_df['验证集精度'].notna().any() else '训练集'

    @staticmethod
    def _write_comparison_report(report_path, comparison_df, prefix, header_lines):
        """Write the human-readable classifier comparison report (UTF-8).

        Args:
            report_path: destination text file.
            comparison_df: one row per finished classifier, with the columns
                built in ``run_classification``. A '综合得分' column
                (70 % accuracy + 30 % relative speed) is added to it in place.
            prefix: metric column prefix, '验证集' or '训练集'.
            header_lines: summary lines written right after the title.
        """
        acc_col = f'{prefix}精度'
        kappa_col = f'{prefix}Kappa'
        f1_col = f'{prefix}F1'

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("="*70 + "\n")
            f.write("        遥感影像分类器性能对比报告\n")
            f.write("="*70 + "\n\n")

            for line in header_lines:
                f.write(line + "\n")
            f.write("\n")

            # Accuracy ranking (NaN values sort last)
            sorted_df = comparison_df.sort_values(acc_col, ascending=False)
            f.write("-"*70 + "\n")
            f.write(f"📊 {prefix}精度排名:\n")
            f.write("-"*70 + "\n")
            for idx, (_, row) in enumerate(sorted_df.iterrows(), 1):
                f.write(f"{idx:2d}. {row['分类器名称']:18s} - "
                        f"精度: {row[acc_col]:.4f}, "
                        f"Kappa: {row[kappa_col]:.4f}, "
                        f"F1: {row[f1_col]:.4f}\n")

            # Speed ranking
            f.write("\n" + "-"*70 + "\n")
            f.write("⚡ 总时间排名（训练+预测）:\n")
            f.write("-"*70 + "\n")
            sorted_time = comparison_df.sort_values('总时间(秒)')
            for idx, (_, row) in enumerate(sorted_time.iterrows(), 1):
                f.write(f"{idx:2d}. {row['分类器名称']:18s} - "
                        f"{row['总时间(秒)']:7.2f}秒 "
                        f"(训练: {row['训练时间(秒)']:6.2f}s, "
                        f"预测: {row['预测时间(秒)']:6.2f}s)\n")

            # Per speed-tag statistics
            f.write("\n" + "-"*70 + "\n")
            f.write("📈 速度分类统计:\n")
            f.write("-"*70 + "\n")
            speed_stats = comparison_df.groupby('速度等级').agg({
                '分类器名称': 'count',
                acc_col: 'mean',
                '总时间(秒)': 'mean'
            }).round(4)
            f.write(speed_stats.to_string())

            # Recommendations
            f.write("\n\n" + "="*70 + "\n")
            f.write("💡 推荐:\n")
            f.write("-"*70 + "\n")

            best_acc = sorted_df.iloc[0]
            f.write(f"🏆 最高精度: {best_acc['分类器名称']} "
                    f"(精度: {best_acc[acc_col]:.4f}, "
                    f"时间: {best_acc['总时间(秒)']:.2f}秒)\n")

            best_speed = sorted_time.iloc[0]
            f.write(f"⚡ 最快速度: {best_speed['分类器名称']} "
                    f"(时间: {best_speed['总时间(秒)']:.2f}秒, "
                    f"精度: {best_speed[acc_col]:.4f})\n")

            # Composite score: 70 % accuracy + 30 % speed relative to the slowest run.
            # Guard the division in case every recorded time is zero.
            max_time = comparison_df['总时间(秒)'].max()
            if max_time > 0:
                speed_score = 1 - comparison_df['总时间(秒)'] / max_time
            else:
                speed_score = 0.0
            comparison_df['综合得分'] = comparison_df[acc_col] * 0.7 + speed_score * 0.3
            best_overall = comparison_df.loc[comparison_df['综合得分'].idxmax()]
            f.write(f"⭐ 综合最佳: {best_overall['分类器名称']} "
                    f"(精度: {best_overall[acc_col]:.4f}, "
                    f"时间: {best_overall['总时间(秒)']:.2f}秒, "
                    f"得分: {best_overall['综合得分']:.4f})\n")
    
    def open_result_dir(self):
        """Open the output directory in the platform's file manager.

        Uses ``os.startfile`` on Windows, ``open`` on macOS and ``xdg-open``
        on other systems. Shows a warning when the directory does not exist.
        Shows an error dialog, instead of leaving an unhandled traceback in
        the Tk callback, when the file manager cannot be launched (e.g.
        ``xdg-open`` is not installed).
        """
        import platform
        import subprocess

        out_dir = Path(self.output_dir.get())
        if not out_dir.is_dir():
            messagebox.showwarning("警告", "结果目录不存在！")
            return

        system = platform.system()
        try:
            if system == "Windows":
                os.startfile(out_dir)
            elif system == "Darwin":
                subprocess.Popen(["open", str(out_dir)])
            else:
                subprocess.Popen(["xdg-open", str(out_dir)])
        except OSError as exc:
            messagebox.showerror("错误", f"无法打开结果目录:\n{exc}")
    
    def view_report(self):
        """Show ``comparison_summary.txt`` from the output directory in a new window.

        The window has a toolbar with two buttons. "Open CSV" opens
        ``classifier_comparison.csv`` with the platform's default application.
        This used to call the Windows-only ``os.startfile`` on every platform.
        "Refresh" reloads the report text via ``refresh_report``. Shows a
        warning when the report file does not exist yet.
        """
        report_file = Path(self.output_dir.get()) / "comparison_summary.txt"
        if not report_file.exists():
            messagebox.showwarning("警告", "报告文件不存在！\n请先运行分类任务。")
            return

        # New window for the report
        report_window = tk.Toplevel(self.root)
        report_window.title("📊 分类器对比报告")
        report_window.geometry("900x700")

        def open_csv():
            """Open the comparison CSV with the platform's default application."""
            import platform
            import subprocess

            csv_path = Path(self.output_dir.get()) / "classifier_comparison.csv"
            if not csv_path.exists():
                messagebox.showwarning("警告", f"CSV文件不存在:\n{csv_path}",
                                       parent=report_window)
                return
            system = platform.system()
            try:
                if system == "Windows":
                    os.startfile(csv_path)
                elif system == "Darwin":
                    subprocess.Popen(["open", str(csv_path)])
                else:
                    subprocess.Popen(["xdg-open", str(csv_path)])
            except OSError as exc:
                messagebox.showerror("错误", f"无法打开CSV文件:\n{exc}",
                                     parent=report_window)

        # Toolbar
        toolbar = ttk.Frame(report_window)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        # Text area. Created before the refresh button so the button's
        # callback does not depend on a name assigned later.
        text_widget = scrolledtext.ScrolledText(report_window, wrap=tk.WORD,
                                                font=('Consolas', 10))

        ttk.Button(toolbar, text="📁 打开CSV",
                   command=open_csv).pack(side=tk.LEFT, padx=2)

        ttk.Button(toolbar, text="🔄 刷新",
                   command=lambda: self.refresh_report(text_widget, report_file)).pack(
            side=tk.LEFT, padx=2)

        text_widget.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        text_widget.insert(1.0, report_file.read_text(encoding='utf-8'))

        # Read-only display
        text_widget.config(state=tk.DISABLED)
    
    def refresh_report(self, text_widget, report_file):
        """Replace the contents of the read-only *text_widget* with *report_file* (UTF-8)."""
        # The widget is kept DISABLED for display; unlock it while rewriting.
        text_widget.config(state=tk.NORMAL)
        text_widget.delete(1.0, tk.END)

        with open(report_file, 'r', encoding='utf-8') as handle:
            text_widget.insert(1.0, handle.read())

        # Lock it again.
        text_widget.config(state=tk.DISABLED)

# ==================== 主程序入口 ====================
def main():
    """Launch the GUI: create the Tk root, attach ClassificationGUI, log a welcome banner, run the event loop."""
    root = tk.Tk()
    # Optional window icon: root.iconbitmap('icon.ico')

    app = ClassificationGUI(root)

    # Welcome banner shown in the log panel
    separator = "=" * 80
    welcome_lines = (
        separator,
        "  遥感影像监督分类系统 v3.0",
        separator,
        "支持的分类器:",
        "  📊 SVM系列: 线性核、RBF核、SGD-SVM、核近似等",
        "  🌲 树模型: RF、XGBoost、LightGBM、ET、GB、DT等",
        "  📈 其他: KNN、朴素贝叶斯、逻辑回归、神经网络等",
        "",
        "优化特性:",
        "  ⚡ 数据采样 - 加快训练速度",
        "  📏 特征缩放 - 提升SVM/KNN性能",
        "  🚀 快速模式 - 减少模型复杂度",
        "  ⚠️  性能警告 - 避免选择慢速分类器",
        "",
        "💡 提示: 点击上方'✓推荐组合'或'⚡快速分类器'快速选择",
        separator,
        "",
    )
    for line in welcome_lines:
        app.log(line)

    root.mainloop()

if __name__ == "__main__":
    main()

# 版本2

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
遥感影像监督分类系统 - 完整版
=====================================
版本: v3.0.1 (错误修复版)
作者: AI Assistant
日期: 2024

主要特性:
- 15+种分类器（包含6种SVM变体）
- 智能性能优化（采样、特征缩放、快速模式）
- 完善的错误处理
- 图形化界面
- 自动生成对比报告

使用说明:
1. 安装依赖: pip install numpy pandas matplotlib seaborn geopandas rioxarray scikit-learn scikit-image（可选: xgboost lightgbm）
2. 运行程序: python classification_system_v3.py
3. 按界面提示操作

注意事项:
- XGBoost 和 LightGBM 为可选依赖，未安装不影响其他分类器使用
- 推荐使用虚拟环境运行
"""

import os
import sys
import time
import threading
import queue
from pathlib import Path
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier, 
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.kernel_approximation import Nystroem, RBFSampler
from sklearn.pipeline import Pipeline
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from scipy import ndimage
from skimage import morphology
import warnings
warnings.filterwarnings('ignore')

# 设置matplotlib中文显示
plt.rcParams["font.sans-serif"] = ["SimHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

# ==================== 后端处理类 ====================
class ClassificationBackend:
    """分类处理后端（完全优化版）"""
    
    def __init__(self):
        self.BACKGROUND_VALUE = 0
        self.RANDOM_STATE = 42
        
        # 预定义颜色
        self.LANDUSE_COLORS = {
            "水体": "lightblue", "河流": "blue", "湖泊": "deepskyblue",
            "植被": "forestgreen", "森林": "darkgreen", "草地": "limegreen",
            "农田": "yellowgreen", "耕地": "olivedrab",
            "建筑": "gray", "城市": "dimgray", "居民地": "slategray",
            "裸地": "tan", "沙地": "wheat", "其他": "darkred"
        }
        
        self.COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow', 
                             'darkred', 'purple', 'orange', 'pink', 'brown']
        
        # 检查可选库的可用性
        self.check_optional_libraries()
    
    def check_optional_libraries(self):
        """检查可选库是否可用"""
        self.has_xgboost = False
        self.has_lightgbm = False
        
        # 检查 XGBoost
        try:
            import xgboost
            from xgboost import XGBClassifier
            # 测试能否正常实例化
            _ = XGBClassifier(n_estimators=10, verbosity=0)
            self.has_xgboost = True
            print("✓ XGBoost 可用")
        except Exception as e:
            print(f"✗ XGBoost 不可用: {type(e).__name__}")
            self.has_xgboost = False
        
        # 检查 LightGBM（增强错误处理）
        try:
            import lightgbm
            from lightgbm import LGBMClassifier
            # 测试能否正常实例化
            _ = LGBMClassifier(n_estimators=10, verbose=-1)
            self.has_lightgbm = True
            print("✓ LightGBM 可用")
        except Exception as e:
            print(f"✗ LightGBM 不可用: {type(e).__name__}")
            self.has_lightgbm = False
    
    def get_all_classifiers(self, n_estimators=100, fast_mode=False, n_train_samples=None):
        """
        获取所有可用分类器
        
        返回格式: {code: (classifier, name, desc, needs_encoding, needs_scaling, speed_tag)}
        """
        # 根据模式调整参数
        if fast_mode:
            n_est = min(50, n_estimators)
            max_depth = 10
            max_iter = 200
        else:
            n_est = n_estimators
            max_depth = 20
            max_iter = 500
        
        # 核近似的组件数
        if n_train_samples:
            n_components = min(1000, n_train_samples // 2)
        else:
            n_components = 1000
        
        classifiers = {
            # ===== 树模型系列 =====
            "rf": (
                RandomForestClassifier(
                    n_estimators=n_est, 
                    n_jobs=-1, 
                    random_state=self.RANDOM_STATE, 
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2,
                    max_features='sqrt'
                ),
                "随机森林", "Random Forest - 稳定可靠", 
                False, False, "fast"
            ),
            
            "et": (
                ExtraTreesClassifier(
                    n_estimators=n_est,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    max_depth=max_depth,
                    min_samples_split=5,
                    max_features='sqrt'
                ),
                "极端随机树", "Extra Trees - 更快的RF", 
                False, False, "fast"
            ),
            
            "dt": (
                DecisionTreeClassifier(
                    random_state=self.RANDOM_STATE,
                    max_depth=max_depth,
                    min_samples_split=5,
                    min_samples_leaf=2
                ),
                "决策树", "Decision Tree - 简单快速", 
                False, False, "very_fast"
            ),
            
            # ===== SVM系列 =====
            "svm_linear": (
                SVC(
                    kernel="linear",
                    C=1.0,
                    cache_size=500,
                    probability=True, 
                    random_state=self.RANDOM_STATE,
                    max_iter=max_iter
                ),
                "SVM-线性核", "SVM Linear", 
                False, True, "medium"
            ),
            
            "linear_svc": (
                CalibratedClassifierCV(
                    LinearSVC(
                        C=1.0,
                        max_iter=max_iter,
                        random_state=self.RANDOM_STATE,
                        dual=False,
                        loss='squared_hinge'
                    ),
                    cv=3
                ),
                "线性SVM(快)", "Linear SVM - 快速版", 
                False, True, "fast"
            ),
            
            "sgd_svm": (
                SGDClassifier(
                    loss='hinge',
                    penalty='l2',
                    max_iter=max_iter,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    learning_rate='optimal'
                ),
                "SGD-SVM", "SGD SVM - 极快", 
                False, True, "very_fast"
            ),
            
            "nystroem_svm": (
                Pipeline([
                    ("feature_map", Nystroem(
                        kernel='rbf',
                        gamma=0.1,
                        n_components=n_components,
                        random_state=self.RANDOM_STATE
                    )),
                    ("sgd", SGDClassifier(
                        max_iter=max_iter,
                        random_state=self.RANDOM_STATE
                    ))
                ]),
                "核近似SVM", "Nystroem - RBF近似", 
                False, True, "fast"
            ),
            
            "rbf_sampler_svm": (
                Pipeline([
                    ("feature_map", RBFSampler(
                        gamma=0.1,
                        n_components=n_components,
                        random_state=self.RANDOM_STATE
                    )),
                    ("sgd", SGDClassifier(
                        max_iter=max_iter,
                        random_state=self.RANDOM_STATE
                    ))
                ]),
                "RBF采样SVM", "RBF Sampler", 
                False, True, "fast"
            ),
            
            "svm_rbf": (
                SVC(
                    kernel="rbf", 
                    C=1.0,
                    gamma='scale',
                    cache_size=500,
                    probability=True, 
                    random_state=self.RANDOM_STATE
                ),
                "SVM-RBF核⚠️", "SVM RBF - 慢但精确", 
                False, True, "very_slow"
            ),
            
            # ===== 其他分类器 =====
            "knn": (
                KNeighborsClassifier(
                    n_neighbors=5,
                    n_jobs=-1,
                    algorithm='ball_tree',
                    leaf_size=30
                ),
                "K近邻", "KNN", 
                False, True, "slow"
            ),
            
            "nb": (
                GaussianNB(),
                "朴素贝叶斯", "Naive Bayes - 最快", 
                False, False, "very_fast"
            ),
            
            "gb": (
                GradientBoostingClassifier(
                    n_estimators=n_est,
                    learning_rate=0.1,
                    max_depth=5,
                    random_state=self.RANDOM_STATE, 
                    verbose=0,
                    subsample=0.8
                ),
                "梯度提升", "Gradient Boosting", 
                False, False, "medium"
            ),
            
            "ada": (
                AdaBoostClassifier(
                    n_estimators=n_est,
                    learning_rate=1.0,
                    random_state=self.RANDOM_STATE,
                    algorithm='SAMME.R'
                ),
                "AdaBoost", "AdaBoost", 
                False, False, "medium"
            ),
            
            "lr": (
                LogisticRegression(
                    max_iter=max_iter,
                    n_jobs=-1,
                    random_state=self.RANDOM_STATE,
                    verbose=0,
                    solver='lbfgs',
                    multi_class='multinomial'
                ),
                "逻辑回归", "Logistic Regression", 
                False, True, "very_fast"
            ),
            
            "mlp": (
                MLPClassifier(
                    hidden_layer_sizes=(100, 50),
                    max_iter=max_iter,
                    random_state=self.RANDOM_STATE,
                    verbose=False,
                    early_stopping=True,
                    validation_fraction=0.1,
                    n_iter_no_change=10,
                    learning_rate='adaptive'
                ),
                "神经网络", "MLP", 
                False, True, "medium"
            ),
        }
        
        # 添加 XGBoost（如果可用）
        if self.has_xgboost:
            try:
                from xgboost import XGBClassifier
                classifiers["xgb"] = (
                    XGBClassifier(
                        n_estimators=n_est,
                        learning_rate=0.1,
                        max_depth=6,
                        n_jobs=-1,
                        random_state=self.RANDOM_STATE,
                        verbosity=0,
                        tree_method='hist',
                        subsample=0.8,
                        colsample_bytree=0.8
                    ),
                    "XGBoost", "XGBoost - 高性能", 
                    True, False, "fast"
                )
            except Exception as e:
                print(f"⚠️  XGBoost 实例化失败: {e}")
        
        # 添加 LightGBM（如果可用）
        if self.has_lightgbm:
            try:
                from lightgbm import LGBMClassifier
                classifiers["lgb"] = (
                    LGBMClassifier(
                        n_estimators=n_est,
                        learning_rate=0.1,
                        max_depth=max_depth,
                        n_jobs=-1,
                        random_state=self.RANDOM_STATE,
                        verbose=-1,
                        num_leaves=31,
                        subsample=0.8,
                        colsample_bytree=0.8,
                        force_col_wise=True  # 避免警告
                    ),
                    "LightGBM", "LightGBM - 极速", 
                    False, False, "very_fast"
                )
            except Exception as e:
                print(f"⚠️  LightGBM 实例化失败: {e}")
        
        return classifiers
    
    def get_background_mask(self, image):
        """Return a boolean (rows, cols) mask of background / no-data pixels.

        A pixel is background when every band is either 0 or NaN. NaN has to
        be included because the image is opened with
        ``rxr.open_rasterio(..., masked=True)`` (as ``run_classification``
        does), which turns declared no-data values into NaN; testing only
        ``== 0`` would let such pixels through as valid and they would then be
        classified (after ``nan_to_num``) instead of left as background.

        Args:
            image: band-first raster exposing ``.values`` as a
                (bands, rows, cols) array; the reduction runs over axis 0.

        Returns:
            np.ndarray[bool] with shape ``image.values.shape[1:]``.
        """
        data = image.values
        empty = data == 0
        # Integer rasters cannot hold NaN; only floating data needs the NaN test.
        if np.issubdtype(data.dtype, np.floating):
            empty |= np.isnan(data)
        return np.all(empty, axis=0)
    
    def get_class_info_from_shp(self, shp_path, class_attr, name_attr):
        """Read class ids, display names and map colors from a sample vector file.

        Args:
            shp_path: path to the sample vector file (read with geopandas).
            class_attr: column holding the class id.
            name_attr: column holding the class display name; if the column is
                missing, names of the form ``Class_<id>`` are generated.

        Returns:
            (class_names, class_colors, sorted_ids): ``class_names`` maps
            id -> name, ``class_colors`` maps id -> color, and ``sorted_ids``
            lists the ids in ascending order.
        """
        gdf = gpd.read_file(shp_path)

        if name_attr not in gdf.columns:
            gdf[name_attr] = gdf[class_attr].apply(lambda x: f"Class_{x}")

        class_info = gdf[[class_attr, name_attr]].drop_duplicates()
        # If one id appears with several names, the last pair wins (dict semantics).
        class_names = dict(zip(class_info[class_attr], class_info[name_attr]))

        class_colors = {}
        for i, (class_id, class_name) in enumerate(class_names.items()):
            # Names may be non-strings (numbers, or NaN for empty cells), and
            # `key in <float>` raises TypeError, so match on the text form.
            name_text = str(class_name)
            for key, color in self.LANDUSE_COLORS.items():
                if key in name_text:
                    class_colors[class_id] = color
                    break
            else:
                # No keyword matched: cycle through the fallback palette.
                class_colors[class_id] = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)]

        return class_names, class_colors, sorted(class_names.keys())
    
    def rasterize_samples(self, shp, ref_img, attr):
        """Burn vector samples into a label raster on ``ref_img``'s grid.

        Geometries are reprojected to the raster CRS and burned with the value
        of their ``attr`` column. Every pixel a geometry touches is labelled
        (``all_touched=True``); pixels touched by no geometry are 0.

        Returns:
            uint16 array of shape ``ref_img.shape[1:]`` (rows, cols).
        """
        samples = gpd.read_file(shp).to_crs(ref_img.rio.crs)
        burn_pairs = zip(samples.geometry, samples[attr])
        return features.rasterize(
            shapes=burn_pairs,
            out_shape=ref_img.shape[1:],
            transform=ref_img.rio.transform(),
            fill=0,
            all_touched=True,
            dtype="uint16",
        )
    
    def extract_samples(self, image, mask, ignore_background=True, max_samples=None):
        """提取样本并清理NaN值，可选分层采样"""
        data = np.moveaxis(image.values, 0, -1)
        valid = mask > 0
        
        if ignore_background:
            background_mask = self.get_background_mask(image)
            valid = valid & (~background_mask)
        
        X = data[valid]
        y = mask[valid]
        
        # 清理NaN和Inf值
        nan_mask = np.isnan(X).any(axis=1)
        inf_mask = np.isinf(X).any(axis=1)
        bad_mask = nan_mask | inf_mask
        
        n_nan = np.sum(nan_mask)
        n_inf = np.sum(inf_mask)
        
        X = X[~bad_mask]
        y = y[~bad_mask]
        
        # 分层采样
        n_sampled = 0
        if max_samples is not None and len(y) > max_samples:
            n_original = len(y)
            unique_classes = np.unique(y)
            
            if len(unique_classes) > 1:
                splitter = StratifiedShuffleSplit(
                    n_splits=1, 
                    train_size=max_samples, 
                    random_state=self.RANDOM_STATE
                )
                sample_idx, _ = next(splitter.split(X, y))
                X = X[sample_idx]
                y = y[sample_idx]
                n_sampled = n_original - len(y)
            else:
                np.random.seed(self.RANDOM_STATE)
                sample_idx = np.random.choice(len(y), max_samples, replace=False)
                X = X[sample_idx]
                y = y[sample_idx]
                n_sampled = n_original - len(y)
        
        return X, y, n_nan, n_inf, n_sampled
    
    def calculate_metrics(self, y_true, y_pred):
        """Summarise agreement between reference and predicted labels.

        Returns a dict with overall accuracy, Cohen's kappa and the
        macro-averaged precision / recall / F1. The macro scores use
        ``zero_division=0``, so classes with an undefined score count as 0.
        """
        macro_kwargs = dict(average='macro', zero_division=0)
        metrics = {
            'overall_accuracy': accuracy_score(y_true, y_pred),
            'kappa': cohen_kappa_score(y_true, y_pred),
        }
        for key, scorer in (('precision_macro', precision_score),
                            ('recall_macro', recall_score),
                            ('f1_macro', f1_score)):
            metrics[key] = scorer(y_true, y_pred, **macro_kwargs)
        return metrics
    
    def estimate_prediction_time(self, clf_code, n_pixels, speed_tag):
        """Return a rough whole-image prediction time in seconds.

        The estimate scales a per-speed-class cost (seconds per million
        pixels) linearly with ``n_pixels``. Unknown tags cost 10 s per million
        pixels. ``clf_code`` is accepted but not used by the estimate.
        """
        seconds_per_megapixel = {
            "very_fast": 1,
            "fast": 3,
            "medium": 10,
            "slow": 30,
            "very_slow": 300,
        }.get(speed_tag, 10)
        megapixels = n_pixels / 1_000_000
        return megapixels * seconds_per_megapixel
    
    def predict_by_block(self, model, image, out_path, block_size=512,
                         ignore_background=True, progress_callback=None,
                         label_encoder=None, scaler=None):
        """Classify ``image`` strip by strip and write the label map to ``out_path``.

        The image is processed in horizontal strips of ``block_size`` rows.
        For each strip, NaN/±Inf feature values are replaced with 0. The
        pixels are then optionally transformed by ``scaler``, predicted with
        ``model`` and optionally decoded with ``label_encoder``. When
        ``ignore_background`` is set, pixels flagged by
        ``get_background_mask`` are not predicted and stay 0.

        Args:
            model: fitted estimator exposing ``predict``.
            image: band-first raster (bands, y, x) with rioxarray georeferencing.
            out_path: destination GeoTIFF path.
            block_size: number of rows per strip.
            ignore_background: skip background pixels (see above).
            progress_callback: optional callable receiving percent done (0-100].
            label_encoder: optional encoder whose ``inverse_transform`` maps
                model outputs back to class ids.
            scaler: optional fitted scaler applied before ``predict``.

        Returns:
            ``out_path``, after writing an LZW-compressed, tiled uint16 GeoTIFF
            carrying the CRS/transform of ``image`` and nodata
            ``self.BACKGROUND_VALUE``.
        """
        n_rows, n_cols = image.shape[1], image.shape[2]
        label_map = np.zeros((n_rows, n_cols), dtype='uint16')
        skip_mask = self.get_background_mask(image) if ignore_background else None
        n_strips = int(np.ceil(n_rows / block_size))

        def classify(pixels):
            # Shared per-pixel pipeline: sanitise -> scale -> predict -> decode.
            pixels = np.nan_to_num(pixels, nan=0.0, posinf=0.0, neginf=0.0)
            if scaler is not None:
                pixels = scaler.transform(pixels)
            labels = model.predict(pixels)
            if label_encoder is not None:
                labels = label_encoder.inverse_transform(labels)
            return labels

        for strip_no, row0 in enumerate(range(0, n_rows, block_size)):
            rows = min(block_size, n_rows - row0)
            strip = np.moveaxis(image.isel(y=slice(row0, row0 + rows)).values, 0, -1)
            strip_rows, strip_cols = strip.shape[0], strip.shape[1]
            pixels = strip.reshape(-1, strip.shape[-1])

            if skip_mask is None:
                strip_labels = classify(pixels).reshape(strip_rows, strip_cols).astype("uint16")
            else:
                keep = ~skip_mask[row0:row0 + rows, :].flatten()
                flat_labels = np.zeros(len(pixels), dtype='uint16')
                if np.any(keep):
                    flat_labels[keep] = classify(pixels[keep])
                strip_labels = flat_labels.reshape(strip_rows, strip_cols)

            label_map[row0:row0 + rows, :] = strip_labels

            if progress_callback:
                progress_callback((strip_no + 1) / n_strips * 100)

        # Wrap as a georeferenced DataArray and write it to disk.
        label_da = xr.DataArray(
            label_map,
            dims=['y', 'x'],
            coords={'y': image.coords['y'], 'x': image.coords['x']}
        )
        label_da.rio.write_crs(image.rio.crs, inplace=True)
        label_da.rio.write_transform(image.rio.transform(), inplace=True)
        label_da.rio.write_nodata(self.BACKGROUND_VALUE, inplace=True)

        label_da.rio.to_raster(out_path, driver='GTiff', dtype='uint16',
                               compress='lzw', tiled=True)
        return out_path

# ==================== GUI主类 ====================
class ClassificationGUI:
    """遥感影像分类GUI主界面"""
    
    def __init__(self, root):
        """Set up the window state, build the widgets and start log polling.

        Args:
            root: the Tk root window that hosts the GUI.
        """
        self.root = root
        root.title("遥感影像监督分类系统 v3.0.1")
        root.geometry("1450x950")

        # Processing backend (classifiers, rasterisation, prediction).
        self.backend = ClassificationBackend()

        # --- input / output paths ---
        self.image_path = tk.StringVar()
        self.train_shp_path = tk.StringVar()
        self.val_shp_path = tk.StringVar()
        self.output_dir = tk.StringVar(value=str(Path("./results_gui")))

        # --- attribute fields, model size and tiling ---
        self.class_attr = tk.StringVar(value="class")
        self.name_attr = tk.StringVar(value="name")
        self.n_estimators = tk.IntVar(value=100)
        self.block_size = tk.IntVar(value=512)
        self.ignore_background = tk.BooleanVar(value=True)

        # --- performance options ---
        self.enable_sampling = tk.BooleanVar(value=True)
        self.max_samples = tk.IntVar(value=50000)
        self.fast_mode = tk.BooleanVar(value=False)

        # One (initially unchecked) checkbox variable per classifier code.
        self.classifier_vars = {
            code: tk.BooleanVar(value=False)
            for code in self.backend.get_all_classifiers()
        }

        # Worker state; log lines travel through the queue to the Tk thread.
        self.is_running = False
        self.log_queue = queue.Queue()

        self.build_ui()
        self.update_log()
    
    def build_ui(self):
        """Create and lay out every widget of the main window.

        ``main_frame`` grid layout:
        - row 0: data inputs (spans both columns)
        - rows 1-2, column 1: classifier picker
        - row 1, column 0: parameters
        - row 2, column 0: run controls
        - row 3: log (the only row that stretches vertically)
        - row 4: status bar
        """
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)

        # 1. Data inputs: one label / entry / browse-button row per path.
        file_frame = ttk.LabelFrame(main_frame, text="1. 数据输入", padding="10")
        file_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=5)

        file_rows = [
            ("影像文件:", self.image_path, self.browse_image),
            ("训练样本:", self.train_shp_path, self.browse_train_shp),
            ("验证样本:", self.val_shp_path, self.browse_val_shp),
            ("输出目录:", self.output_dir, self.browse_output),
        ]
        for r, (label_text, path_var, browse_cmd) in enumerate(file_rows):
            ttk.Label(file_frame, text=label_text).grid(row=r, column=0, sticky=tk.W, pady=2)
            ttk.Entry(file_frame, textvariable=path_var, width=65).grid(
                row=r, column=1, sticky=(tk.W, tk.E), padx=5
            )
            ttk.Button(file_frame, text="浏览", command=browse_cmd).grid(row=r, column=2, padx=5)

        file_frame.columnconfigure(1, weight=1)

        # 2. Parameters
        param_frame = ttk.LabelFrame(main_frame, text="2. 参数配置", padding="10")
        param_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N), pady=5, padx=(0, 5))

        ttk.Label(param_frame, text="类别编号字段:").grid(row=0, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.class_attr, width=15).grid(
            row=0, column=1, sticky=tk.W, padx=5
        )

        ttk.Label(param_frame, text="类别名称字段:").grid(row=1, column=0, sticky=tk.W, pady=2)
        ttk.Entry(param_frame, textvariable=self.name_attr, width=15).grid(
            row=1, column=1, sticky=tk.W, padx=5
        )

        ttk.Label(param_frame, text="树模型数量:").grid(row=2, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=10, to=500, textvariable=self.n_estimators,
                    width=13).grid(row=2, column=1, sticky=tk.W, padx=5)

        ttk.Label(param_frame, text="分块大小:").grid(row=3, column=0, sticky=tk.W, pady=2)
        ttk.Spinbox(param_frame, from_=256, to=2048, increment=256,
                    textvariable=self.block_size, width=13).grid(
            row=3, column=1, sticky=tk.W, padx=5
        )

        # Performance options
        ttk.Separator(param_frame, orient='horizontal').grid(
            row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=8
        )

        ttk.Label(param_frame, text="⚡ 性能优化:", font=('', 9, 'bold')).grid(
            row=5, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        sample_frame = ttk.Frame(param_frame)
        sample_frame.grid(row=6, column=0, columnspan=2, sticky=(tk.W, tk.E))

        ttk.Checkbutton(sample_frame, text="启用采样",
                        variable=self.enable_sampling,
                        command=self.toggle_sampling).pack(side=tk.LEFT)

        ttk.Label(sample_frame, text="  最大样本数:").pack(side=tk.LEFT, padx=(10, 0))
        # Kept on self so toggle_sampling can enable/disable it.
        self.max_samples_spinbox = ttk.Spinbox(
            sample_frame, from_=10000, to=200000, increment=10000,
            textvariable=self.max_samples, width=10
        )
        self.max_samples_spinbox.pack(side=tk.LEFT, padx=5)

        ttk.Checkbutton(param_frame, text="快速模式", variable=self.fast_mode).grid(
            row=7, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        ttk.Checkbutton(param_frame, text="忽略背景值", variable=self.ignore_background).grid(
            row=8, column=0, columnspan=2, sticky=tk.W, pady=2
        )

        # 3. Classifier selection
        clf_frame = ttk.LabelFrame(main_frame, text="3. 分类器选择", padding="10")
        clf_frame.grid(row=1, column=1, rowspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        btn_frame = ttk.Frame(clf_frame)
        btn_frame.grid(row=0, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(0, 5))

        ttk.Button(btn_frame, text="全选", command=self.select_all_classifiers,
                   width=12).pack(side=tk.LEFT, padx=2)
        ttk.Button(btn_frame, text="全不选", command=self.deselect_all_classifiers,
                   width=12).pack(side=tk.LEFT, padx=2)
        ttk.Button(btn_frame, text="✓推荐", command=self.select_recommended,
                   width=12).pack(side=tk.LEFT, padx=2)
        ttk.Button(btn_frame, text="⚡快速", command=self.select_fast,
                   width=12).pack(side=tk.LEFT, padx=2)

        # Scrollable area: a frame embedded in a canvas; the scroll region is
        # refreshed whenever the inner frame is resized.
        canvas = tk.Canvas(clf_frame, height=200)
        scrollbar = ttk.Scrollbar(clf_frame, orient="vertical", command=canvas.yview)
        scrollable_frame = ttk.Frame(canvas)

        scrollable_frame.bind(
            "<Configure>",
            lambda e: canvas.configure(scrollregion=canvas.bbox("all"))
        )

        canvas.create_window((0, 0), window=scrollable_frame, anchor="nw")
        canvas.configure(yscrollcommand=scrollbar.set)

        all_classifiers = self.backend.get_all_classifiers()

        # Three titled checkbox groups, three boxes per row. Codes the backend
        # does not provide (e.g. the optional xgb/lgb) are skipped.
        row = self._add_classifier_group(
            scrollable_frame, "📊 SVM系列:",
            ["svm_linear", "linear_svc", "sgd_svm", "nystroem_svm",
             "rbf_sampler_svm", "svm_rbf"],
            all_classifiers, start_row=0, top_pad=5,
        )
        row = self._add_classifier_group(
            scrollable_frame, "🌲 树模型:",
            ["rf", "et", "dt", "xgb", "lgb", "gb", "ada"],
            all_classifiers, start_row=row, top_pad=10,
        )
        self._add_classifier_group(
            scrollable_frame, "📈 其他:",
            ["knn", "nb", "lr", "mlp"],
            all_classifiers, start_row=row, top_pad=10,
        )

        canvas.grid(row=1, column=0, columnspan=3, sticky=(tk.W, tk.E, tk.N, tk.S))
        scrollbar.grid(row=1, column=3, sticky=(tk.N, tk.S))
        clf_frame.rowconfigure(1, weight=1)

        # 4. Run controls
        control_frame = ttk.LabelFrame(main_frame, text="4. 运行控制", padding="10")
        control_frame.grid(row=2, column=0, sticky=(tk.W, tk.E), pady=5, padx=(0, 5))

        self.start_btn = ttk.Button(control_frame, text="▶ 开始分类",
                                    command=self.start_classification, width=15)
        self.start_btn.grid(row=0, column=0, padx=5, pady=5)

        self.stop_btn = ttk.Button(control_frame, text="⏸ 停止",
                                   command=self.stop_classification,
                                   state=tk.DISABLED, width=15)
        self.stop_btn.grid(row=0, column=1, padx=5, pady=5)

        ttk.Button(control_frame, text="📁 打开结果",
                   command=self.open_result_dir, width=15).grid(row=0, column=2, padx=5, pady=5)

        ttk.Button(control_frame, text="📊 查看报告",
                   command=self.view_report, width=15).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(control_frame, text="进度:").grid(row=1, column=0, sticky=tk.W, pady=2)
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(control_frame, variable=self.progress_var,
                                            maximum=100, length=400)
        self.progress_bar.grid(row=1, column=1, columnspan=3, sticky=(tk.W, tk.E),
                               padx=5, pady=2)

        control_frame.columnconfigure(3, weight=1)

        # 5. Log output
        log_frame = ttk.LabelFrame(main_frame, text="5. 运行日志", padding="10")
        log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        self.log_text = scrolledtext.ScrolledText(log_frame, wrap=tk.WORD,
                                                  height=18, width=120,
                                                  font=('Consolas', 9))
        self.log_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        log_frame.columnconfigure(0, weight=1)
        log_frame.rowconfigure(0, weight=1)

        # Status bar
        self.status_var = tk.StringVar(value="就绪")
        status_bar = ttk.Label(main_frame, textvariable=self.status_var,
                               relief=tk.SUNKEN, anchor=tk.W)
        status_bar.grid(row=4, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(5, 0))

    def _add_classifier_group(self, parent, title, codes, all_classifiers,
                              start_row, top_pad):
        """Grid a bold group title plus one checkbox per available classifier.

        Checkboxes are laid out three per row under the title; codes missing
        from ``all_classifiers`` are skipped. Returns the first free grid row
        below the group, so the next group can start there.
        """
        ttk.Label(parent, text=title, font=('', 9, 'bold')).grid(
            row=start_row, column=0, columnspan=3, sticky=tk.W, pady=(top_pad, 2)
        )
        row, col = start_row + 1, 0
        for code in codes:
            if code not in all_classifiers:
                continue
            _, name, _, _, _, _ = all_classifiers[code]
            ttk.Checkbutton(parent, text=name,
                            variable=self.classifier_vars[code]).grid(
                row=row, column=col, sticky=tk.W, pady=1, padx=5
            )
            col += 1
            if col >= 3:
                col, row = 0, row + 1
        if col > 0:
            row += 1
        return row
    
    def toggle_sampling(self):
        """切换采样功能"""
        if self.enable_sampling.get():
            self.max_samples_spinbox.config(state=tk.NORMAL)
        else:
            self.max_samples_spinbox.config(state=tk.DISABLED)
    
    def browse_image(self):
        """Ask for the image file; store it and show its name in the status bar."""
        chosen = filedialog.askopenfilename(
            title="选择影像文件",
            filetypes=[("GeoTIFF", "*.tif *.tiff"), ("所有文件", "*.*")]
        )
        if not chosen:
            return  # dialog cancelled
        self.image_path.set(chosen)
        self.status_var.set(f"已选择影像: {Path(chosen).name}")
    
    def browse_train_shp(self):
        """Ask for the training-sample shapefile and store the chosen path."""
        chosen = filedialog.askopenfilename(
            title="选择训练样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")]
        )
        if not chosen:
            return  # dialog cancelled
        self.train_shp_path.set(chosen)
    
    def browse_val_shp(self):
        """Ask for the validation-sample shapefile and store the chosen path."""
        chosen = filedialog.askopenfilename(
            title="选择验证样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")]
        )
        if not chosen:
            return  # dialog cancelled
        self.val_shp_path.set(chosen)
    
    def browse_output(self):
        """Ask for the output directory and store it unless the dialog is cancelled."""
        chosen_dir = filedialog.askdirectory(title="选择输出目录")
        if not chosen_dir:
            return
        self.output_dir.set(chosen_dir)
    
    def select_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(True)
    
    def deselect_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(False)
    
    def select_recommended(self):
        """推荐组合"""
        recommended = ["rf", "xgb", "et", "lgb", "linear_svc", "nystroem_svm"]
        for code, var in self.classifier_vars.items():
            var.set(code in recommended)
    
    def select_fast(self):
        """快速分类器"""
        fast = ["rf", "et", "dt", "xgb", "lgb", "nb", "lr", "sgd_svm", "linear_svc"]
        for code, var in self.classifier_vars.items():
            var.set(code in fast)
    
    def log(self, message):
        """添加日志"""
        self.log_queue.put(message)
    
    def update_log(self):
        """更新日志显示"""
        try:
            while True:
                message = self.log_queue.get_nowait()
                self.log_text.insert(tk.END, message + "\n")
                self.log_text.see(tk.END)
        except queue.Empty:
            pass
        self.root.after(100, self.update_log)
    
    def start_classification(self):
        """Validate the inputs, warn about very slow classifiers, then start the worker.

        Shows an error dialog and returns if the image, the training samples
        or every classifier is missing. If any selected classifier is tagged
        ``"very_slow"``, the user must confirm before continuing. On start,
        the Start button is disabled, Stop is enabled, the log is cleared and
        ``run_classification`` is launched in a daemon thread.
        """
        if not self.image_path.get():
            messagebox.showerror("错误", "请选择影像文件！")
            return
        if not self.train_shp_path.get():
            messagebox.showerror("错误", "请选择训练样本！")
            return

        selected = [code for code, var in self.classifier_vars.items() if var.get()]
        if not selected:
            messagebox.showerror("错误", "请至少选择一个分类器！")
            return

        # Element 1 of each classifier tuple is the display name, element 5 the speed tag.
        catalog = self.backend.get_all_classifiers()
        slow_names = [
            catalog[code][1]
            for code in selected
            if code in catalog and catalog[code][5] == "very_slow"
        ]
        if slow_names:
            warning_msg = (
                "⚠️ 以下分类器预测非常慢:\n"
                + "".join(f"  • {name}\n" for name in slow_names)
                + "\n建议使用其他SVM变体\n\n是否继续?"
            )
            if not messagebox.askyesno("性能警告", warning_msg, icon='warning'):
                return

        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.is_running = True

        self.log_text.delete(1.0, tk.END)
        banner = "=" * 80
        for line in (banner,
                     "  遥感影像监督分类系统 v3.0.1",
                     banner,
                     f"选择的分类器: {len(selected)} 个",
                     ""):
            self.log(line)

        worker = threading.Thread(target=self.run_classification,
                                  args=(selected,), daemon=True)
        worker.start()
    
    def stop_classification(self):
        """停止分类"""
        self.is_running = False
        self.log("\n⏸ 用户请求停止...")
        self.status_var.set("已停止")
    
    def run_classification(self, selected_classifiers):
        """执行分类"""
        try:
            out_dir = Path(self.output_dir.get())
            out_dir.mkdir(exist_ok=True)
            
            # 读取影像
            self.log(f"📁 读取影像...")
            self.status_var.set("读取影像...")
            img = rxr.open_rasterio(self.image_path.get(), masked=True)
            n_pixels = img.shape[1] * img.shape[2]
            self.log(f"   尺寸: {img.shape[1]}×{img.shape[2]} = {n_pixels:,} 像元")
            
            if not self.is_running:
                return
            
            # 读取类别信息
            self.log(f"\n📊 读取类别信息...")
            class_names, class_colors, _ = self.backend.get_class_info_from_shp(
                self.train_shp_path.get(), self.class_attr.get(), self.name_attr.get()
            )
            self.log(f"   类别: {list(class_names.values())}")
            
            # 提取训练样本
            self.log(f"\n🎯 处理训练样本...")
            self.status_var.set("处理训练样本...")
            train_mask = self.backend.rasterize_samples(
                self.train_shp_path.get(), img, self.class_attr.get()
            )
            
            max_samples = self.max_samples.get() if self.enable_sampling.get() else None
            
            X_train, y_train, n_nan, n_inf, n_sampled = self.backend.extract_samples(
                img, train_mask, 
                ignore_background=self.ignore_background.get(),
                max_samples=max_samples
            )
            
            self.log(f"   训练样本数: {len(y_train):,}")
            if n_nan > 0:
                self.log(f"   └─ 移除NaN: {n_nan:,}")
            if n_inf > 0:
                self.log(f"   └─ 移除Inf: {n_inf:,}")
            if n_sampled > 0:
                self.log(f"   └─ 采样减少: {n_sampled:,}")
            
            if not self.is_running:
                return
            
            # 提取验证样本
            val_exists = os.path.exists(self.val_shp_path.get())
            if val_exists:
                self.log(f"\n✅ 处理验证样本...")
                val_mask = self.backend.rasterize_samples(
                    self.val_shp_path.get(), img, self.class_attr.get()
                )
                
                if self.ignore_background.get():
                    background_mask = self.backend.get_background_mask(img)
                    valid_val = (val_mask > 0) & (~background_mask)
                else:
                    valid_val = val_mask > 0
                
                yv_true = val_mask[valid_val]
                self.log(f"   验证样本数: {len(yv_true):,}")
            
            # 分类器训练和评估
            all_classifiers = self.backend.get_all_classifiers(
                self.n_estimators.get(), 
                fast_mode=self.fast_mode.get(),
                n_train_samples=len(y_train)
            )
            
            comparison_results = []
            total_start_time = time.time()
            
            for i, clf_code in enumerate(selected_classifiers):
                if not self.is_running:
                    break
                
                clf, clf_name, clf_desc, needs_encoding, needs_scaling, speed_tag = all_classifiers[clf_code]
                
                self.log(f"\n{'='*80}")
                self.log(f"[{i+1}/{len(selected_classifiers)}] {clf_name}")
                self.log(f"{'='*80}")
                
                # 预估时间
                est_pred_time = self.backend.estimate_prediction_time(clf_code, n_pixels, speed_tag)
                if est_pred_time > 60:
                    self.log(f"⏱️  预计预测: ~{est_pred_time/60:.1f} 分钟")
                
                self.status_var.set(f"[{i+1}/{len(selected_classifiers)}] 训练 {clf_name}...")
                
                clf_dir = out_dir / clf_code
                clf_dir.mkdir(exist_ok=True)
                
                try:
                    # 数据预处理
                    label_encoder = None
                    scaler = None
                    X_train_use = X_train.copy()
                    y_train_use = y_train.copy()
                    
                    if needs_encoding:
                        self.log("   🔄 标签编码...")
                        label_encoder = LabelEncoder()
                        y_train_use = label_encoder.fit_transform(y_train)
                    
                    if needs_scaling:
                        self.log("   📏 特征缩放...")
                        scaler = StandardScaler()
                        X_train_use = scaler.fit_transform(X_train_use)
                    
                    # 训练
                    self.log("   🔨 训练中...")
                    train_start = time.time()
                    clf.fit(X_train_use, y_train_use)
                    train_time = time.time() - train_start
                    self.log(f"   ✓ 训练完成: {train_time:.2f}秒")
                    
                    # 训练集精度
                    y_train_pred = clf.predict(X_train_use)
                    
                    if label_encoder is not None:
                        y_train_pred = label_encoder.inverse_transform(y_train_pred)
                    
                    train_metrics = self.backend.calculate_metrics(y_train, y_train_pred)
                    self.log(f"   📈 训练集 - 精度: {train_metrics['overall_accuracy']:.4f}")
                    
                    if not self.is_running:
                        break
                    
                    # 预测整幅影像
                    self.log("   🗺️  预测影像...")
                    self.status_var.set(f"[{i+1}/{len(selected_classifiers)}] 预测 {clf_name}...")
                    
                    pred_start = time.time()
                    classified_path = clf_dir / f"classified_{clf_code}.tif"
                    
                    def update_progress(progress):
                        self.progress_var.set(progress)
                    
                    self.backend.predict_by_block(
                        clf, img, classified_path, 
                        block_size=self.block_size.get(),
                        ignore_background=self.ignore_background.get(),
                        progress_callback=update_progress,
                        label_encoder=label_encoder,
                        scaler=scaler
                    )
                    
                    pred_time = time.time() - pred_start
                    self.log(f"   ✓ 预测完成: {pred_time:.2f}秒")
                    
                    # 验证集精度
                    val_metrics = {'overall_accuracy': np.nan, 'kappa': np.nan, 'f1_macro': np.nan}
                    if val_exists:
                        with rxr.open_rasterio(classified_path) as pred_img:
                            pred_arr = pred_img.values.squeeze()
                        
                        yv_pred = pred_arr[valid_val]
                        val_metrics = self.backend.calculate_metrics(yv_true, yv_pred)
                        self.log(f"   📊 验证集 - 精度: {val_metrics['overall_accuracy']:.4f}")
                    
                    # 记录结果
                    result = {
                        '分类器代码': clf_code,
                        '分类器名称': clf_name,
                        '速度等级': speed_tag,
                        '训练集精度': train_metrics['overall_accuracy'],
                        '训练集Kappa': train_metrics['kappa'],
                        '验证集精度': val_metrics['overall_accuracy'],
                        '验证集Kappa': val_metrics['kappa'],
                        '训练时间(秒)': train_time,
                        '预测时间(秒)': pred_time,
                        '总时间(秒)': train_time + pred_time
                    }
                    comparison_results.append(result)
                    
                    self.log(f"   ✅ {clf_name} 完成!")
                    
                except Exception as e:
                    self.log(f"   ❌ {clf_name} 失败: {str(e)}")
                    continue
                
                self.progress_var.set((i + 1) / len(selected_classifiers) * 100)
            
            # 生成对比报告
            if comparison_results and self.is_running:
                total_time = time.time() - total_start_time
                
                self.log(f"\n{'='*80}")
                self.log("📝 生成报告...")
                self.status_var.set("生成报告...")
                
                comparison_df = pd.DataFrame(comparison_results)
                comparison_df.to_csv(out_dir / "classifier_comparison.csv", 
                                   index=False, encoding='utf-8-sig')
                
                # 生成详细报告
                with open(out_dir / "comparison_summary.txt", 'w', encoding='utf-8') as f:
                    f.write("="*70 + "\n")
                    f.write("        遥感影像分类器性能对比报告\n")
                    f.write("="*70 + "\n\n")
                    
                    f.write(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
                    f.write(f"影像: {img.shape[1]}×{img.shape[2]} = {n_pixels:,} 像元\n")
                    f.write(f"训练样本: {len(y_train):,}\n")
                    if val_exists:
                        f.write(f"验证样本: {len(yv_true):,}\n")
                    f.write(f"类别数: {len(class_names)}\n")
                    f.write(f"成功: {len(comparison_results)}/{len(selected_classifiers)}\n")
                    f.write(f"总耗时: {total_time/60:.1f} 分钟\n\n")
                    
                    # 精度排名
                    sorted_df = comparison_df.sort_values('验证集精度', ascending=False)
                    f.write("-"*70 + "\n")
                    f.write("📊 验证集精度排名:\n")
                    f.write("-"*70 + "\n")
                    for idx, (_, row) in enumerate(sorted_df.iterrows(), 1):
                        f.write(f"{idx:2d}. {row['分类器名称']:18s} - "
                               f"精度: {row['验证集精度']:.4f}, "
                               f"Kappa: {row['验证集Kappa']:.4f}\n")
                    
                    # 速度排名
                    f.write("\n" + "-"*70 + "\n")
                    f.write("⚡ 总时间排名:\n")
                    f.write("-"*70 + "\n")
                    sorted_time = comparison_df.sort_values('总时间(秒)')
                    for idx, (_, row) in enumerate(sorted_time.iterrows(), 1):
                        f.write(f"{idx:2d}. {row['分类器名称']:18s} - "
                               f"{row['总时间(秒)']:7.2f}秒\n")
                    
                    # 推荐
                    f.write("\n" + "="*70 + "\n")
                    f.write("💡 推荐:\n")
                    f.write("-"*70 + "\n")
                    
                    best_acc = sorted_df.iloc[0]
                    f.write(f"🏆 最高精度: {best_acc['分类器名称']} ({best_acc['验证集精度']:.4f})\n")
                    
                    best_speed = sorted_time.iloc[0]
                    f.write(f"⚡ 最快速度: {best_speed['分类器名称']} ({best_speed['总时间(秒)']:.2f}秒)\n")
                
                self.log("✅ 所有任务完成!")
                self.log(f"📁 结果: {out_dir.absolute()}")
                self.log(f"⏱️  总耗时: {total_time/60:.1f} 分钟")
                
                best_clf = comparison_df.loc[comparison_df['验证集精度'].idxmax()]
                self.log(f"\n🏆 最佳: {best_clf['分类器名称']} ({best_clf['验证集精度']:.4f})")
                
                self.status_var.set(f"✅ 完成! 最佳: {best_clf['分类器名称']}")
                
                messagebox.showinfo("任务完成", 
                    f"🎉 分类任务完成!\n\n"
                    f"✅ 成功: {len(comparison_results)}/{len(selected_classifiers)}\n"
                    f"🏆 最佳: {best_clf['分类器名称']} ({best_clf['验证集精度']:.4f})\n"
                    f"⏱️  耗时: {total_time/60:.1f} 分钟")
            
        except Exception as e:
            self.log(f"\n❌ 错误: {str(e)}")
            import traceback
            self.log(traceback.format_exc())
            messagebox.showerror("错误", f"发生错误:\n{str(e)}")
            self.status_var.set("❌ 错误")
        
        finally:
            self.start_btn.config(state=tk.NORMAL)
            self.stop_btn.config(state=tk.DISABLED)
            self.progress_var.set(0)
            self.is_running = False
    
    def open_result_dir(self):
        """Open the configured output directory in the system file browser.

        Uses ``os.startfile`` on Windows, ``open`` on macOS and ``xdg-open``
        everywhere else. When the directory does not exist a warning dialog is
        shown instead.
        """
        target_dir = Path(self.output_dir.get())
        if not target_dir.exists():
            messagebox.showwarning("警告", "结果目录不存在！")
            return

        import subprocess
        import platform

        system_name = platform.system()
        if system_name == "Windows":
            os.startfile(target_dir)
            return

        opener = "open" if system_name == "Darwin" else "xdg-open"
        subprocess.Popen([opener, target_dir])
    
    def view_report(self):
        """Show the text comparison report of the last run in a read-only window.

        Reads ``comparison_summary.txt`` from the current output directory. The
        window's toolbar has a button that opens ``classifier_comparison.csv``
        from that same directory with the operating system's default
        application. If the report file is missing, a warning dialog is shown
        instead and no window is created.
        """
        out_dir = Path(self.output_dir.get())
        report_file = out_dir / "comparison_summary.txt"
        if not report_file.exists():
            messagebox.showwarning("警告", "报告文件不存在！请先运行分类。")
            return

        # Read before building the window so a read error does not leave an
        # empty window behind.
        content = report_file.read_text(encoding='utf-8')
        # Bind the CSV to the directory whose report is being displayed.
        csv_file = out_dir / "classifier_comparison.csv"

        def open_csv():
            """Open the comparison CSV with the platform's default application."""
            import platform
            import subprocess

            if not csv_file.exists():
                messagebox.showwarning("警告", f"CSV文件不存在！\n{csv_file}")
                return
            # os.startfile exists only on Windows; calling it elsewhere raises
            # AttributeError, so use the desktop opener on macOS / Linux.
            system_name = platform.system()
            if system_name == "Windows":
                os.startfile(csv_file)
            elif system_name == "Darwin":
                subprocess.Popen(["open", str(csv_file)])
            else:
                subprocess.Popen(["xdg-open", str(csv_file)])

        report_window = tk.Toplevel(self.root)
        report_window.title("📊 分类器对比报告")
        report_window.geometry("900x700")

        toolbar = ttk.Frame(report_window)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
        ttk.Button(toolbar, text="📁 打开CSV", command=open_csv).pack(side=tk.LEFT, padx=2)

        text_widget = scrolledtext.ScrolledText(report_window, wrap=tk.WORD,
                                                font=('Consolas', 10))
        text_widget.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        text_widget.insert("1.0", content)
        # Make the report view read-only.
        text_widget.config(state=tk.DISABLED)

# ==================== 主程序入口 ====================
def main():
    """Launch the GUI application.

    Prints a console banner, creates the Tk root window and the main
    ``ClassificationGUI``, writes a welcome message into the in-app log, and
    then blocks in the Tk event loop.
    """
    rule = "=" * 80
    title = "  遥感影像监督分类系统 v3.0.1"

    for banner_line in (rule, title, rule):
        print(banner_line)
    print("\n正在检查依赖库...")

    root = tk.Tk()
    app = ClassificationGUI(root)

    # Welcome text shown in the GUI log panel
    welcome = [
        rule,
        title,
        rule,
        "支持的分类器:",
        "  📊 SVM系列: 线性核、RBF核、SGD-SVM、核近似等",
        "  🌲 树模型: RF、XGBoost、LightGBM、ET、GB、DT等",
        "  📈 其他: KNN、朴素贝叶斯、逻辑回归、神经网络等",
        "",
        "优化特性:",
        "  ⚡ 数据采样 - 加快训练",
        "  📏 特征缩放 - 提升性能",
        "  🚀 快速模式 - 减少复杂度",
        "",
        "💡 提示: 点击'✓推荐'或'⚡快速'快速选择分类器",
        rule,
        "",
    ]
    for message in welcome:
        app.log(message)

    print("\n✓ 系统启动成功!")
    print("请在GUI界面中操作...")

    root.mainloop()


if __name__ == "__main__":
    main()

  遥感影像监督分类系统 v3.0.1

正在检查依赖库...
✓ XGBoost 可用
✓ LightGBM 可用

✓ 系统启动成功!
请在GUI界面中操作...


Exception in thread Thread-5 (run_classification):
Traceback (most recent call last):
  File "C:\Users\xyt556\AppData\Local\Temp\7\ipykernel_109364\1781054574.py", line 1080, in run_classification
  File "D:\Python310\lib\tkinter\__init__.py", line 402, in set
    return self._tk.globalsetvar(self._name, value)
RuntimeError: main thread is not in main loop

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\xyt556\AppData\Local\Temp\7\ipykernel_109364\1781054574.py", line 1253, in run_classification
  File "D:\Python310\lib\tkinter\__init__.py", line 402, in set
    return self._tk.globalsetvar(self._name, value)
RuntimeError: main thread is not in main loop

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\Python310\lib\threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "d:\geog_2025\envi\lib\site-packages\ipykernel\ipkernel.py", line

# 版本3

In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
遥感影像监督分类系统 - 专业版 v4.1
=====================================
新增:
- 完善结果预览显示
- Excel格式报告输出
- 混淆矩阵可视化
- 图表实时刷新优化
"""

import os
import sys
import time
import threading
import queue
from pathlib import Path
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from matplotlib.figure import Figure
import matplotlib.colors as mcolors
import seaborn as sns
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio import features
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier, 
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.kernel_approximation import Nystroem, RBFSampler
from sklearn.pipeline import Pipeline
from sklearn.metrics import (confusion_matrix, accuracy_score, classification_report, 
                           cohen_kappa_score, precision_score, recall_score, f1_score)
from scipy import ndimage
from skimage import morphology
import warnings
warnings.filterwarnings('ignore')

sns.set_style("whitegrid")

# Matplotlib: font fallback chain so Chinese labels can render, and draw the
# minus sign as a plain hyphen (axes.unicode_minus=False).
plt.rcParams.update({
    "font.sans-serif": ["SimHei", "DejaVu Sans", "Arial Unicode MS"],
    "axes.unicode_minus": False,
})

# openpyxl is optional: without it only Excel export is unavailable, so a
# missing install is reported instead of aborting the module import.
try:
    import openpyxl
except ImportError:
    HAS_OPENPYXL = False
    print("⚠️  未安装openpyxl，将无法导出Excel文件")
    print("   安装: pip install openpyxl")
else:
    HAS_OPENPYXL = True

# ==================== 后端处理类 ====================
class ClassificationBackend:
    """Non-GUI processing backend for supervised raster classification.

    Builds the catalogue of candidate classifiers, reads class information and
    rasterizes sample polygons, extracts (optionally subsampled) training
    pixels, computes accuracy metrics, and predicts a whole image block by
    block into a GeoTIFF.
    """

    def __init__(self):
        self.RANDOM_STATE = 42

        # Fixed colours for common land-use names: a class whose name contains
        # one of these keys (substring match) gets the corresponding colour.
        self.LANDUSE_COLORS = {
            "水体": "lightblue", "河流": "blue", "湖泊": "deepskyblue",
            "植被": "forestgreen", "森林": "darkgreen", "草地": "limegreen",
            "农田": "yellowgreen", "耕地": "olivedrab",
            "建筑": "gray", "城市": "dimgray", "居民地": "slategray",
            "裸地": "tan", "沙地": "wheat", "其他": "darkred"
        }

        # Fallback colours, assigned cyclically to classes without a keyword match.
        self.COLOR_PALETTE = ['forestgreen', 'lightblue', 'gray', 'tan', 'yellow',
                              'darkred', 'purple', 'orange', 'pink', 'brown']

        # Probe optional boosting libraries (sets has_xgboost / has_lightgbm).
        self.check_optional_libraries()

    def check_optional_libraries(self):
        """Detect whether XGBoost and LightGBM can be imported and instantiated.

        Sets ``self.has_xgboost`` / ``self.has_lightgbm`` and prints the result.
        Any exception during the probe marks the library as unavailable.
        """
        self.has_xgboost = False
        self.has_lightgbm = False

        try:
            import xgboost
            from xgboost import XGBClassifier
            _ = XGBClassifier(n_estimators=10, verbosity=0)
            self.has_xgboost = True
            print("✓ XGBoost 可用")
        except Exception:
            print("✗ XGBoost 不可用")

        try:
            import lightgbm
            from lightgbm import LGBMClassifier
            _ = LGBMClassifier(n_estimators=10, verbose=-1)
            self.has_lightgbm = True
            print("✓ LightGBM 可用")
        except Exception:
            print("✗ LightGBM 不可用")

    def get_all_classifiers(self, n_estimators=100, fast_mode=False, n_train_samples=None):
        """Build fresh, unfitted instances of every available classifier.

        Args:
            n_estimators: ensemble size for tree ensembles / boosting models
                (capped at 50 in fast mode).
            fast_mode: use shallower trees (max_depth 10 instead of 20) and fewer
                solver iterations (200 instead of 500).
            n_train_samples: number of training samples; caps the number of
                kernel-approximation components at half of it.

        Returns:
            dict mapping a short code (e.g. ``"rf"``) to a tuple
            ``(estimator, display_name, english_name, needs_label_encoding,
            needs_feature_scaling, speed_tag)``.
        """
        if fast_mode:
            n_est = min(50, n_estimators)
            max_depth = 10
            max_iter = 200
        else:
            n_est = n_estimators
            max_depth = 20
            max_iter = 500

        if n_train_samples:
            # Nystroem / RBFSampler require at least one component.
            n_components = max(1, min(1000, n_train_samples // 2))
        else:
            n_components = 1000

        classifiers = {
            "rf": (RandomForestClassifier(n_estimators=n_est, n_jobs=-1, random_state=self.RANDOM_STATE,
                                          verbose=0, max_depth=max_depth, min_samples_split=5,
                                          max_features='sqrt'),
                   "随机森林", "Random Forest", False, False, "fast"),

            "et": (ExtraTreesClassifier(n_estimators=n_est, n_jobs=-1, random_state=self.RANDOM_STATE,
                                        verbose=0, max_depth=max_depth, min_samples_split=5, max_features='sqrt'),
                   "极端随机树", "Extra Trees", False, False, "fast"),

            "dt": (DecisionTreeClassifier(random_state=self.RANDOM_STATE, max_depth=max_depth,
                                          min_samples_split=5, min_samples_leaf=2),
                   "决策树", "Decision Tree", False, False, "very_fast"),

            "svm_linear": (SVC(kernel="linear", C=1.0, cache_size=500, probability=True,
                               random_state=self.RANDOM_STATE, max_iter=max_iter),
                           "SVM-线性核", "SVM Linear", False, True, "medium"),

            "linear_svc": (CalibratedClassifierCV(LinearSVC(C=1.0, max_iter=max_iter, random_state=self.RANDOM_STATE,
                                                            dual=False, loss='squared_hinge'), cv=3),
                           "线性SVM(快)", "Linear SVM", False, True, "fast"),

            "sgd_svm": (SGDClassifier(loss='hinge', penalty='l2', max_iter=max_iter, n_jobs=-1,
                                      random_state=self.RANDOM_STATE, learning_rate='optimal'),
                        "SGD-SVM", "SGD SVM", False, True, "very_fast"),

            "nystroem_svm": (Pipeline([
                ("feature_map", Nystroem(kernel='rbf', gamma=0.1, n_components=n_components,
                                         random_state=self.RANDOM_STATE)),
                ("sgd", SGDClassifier(max_iter=max_iter, random_state=self.RANDOM_STATE))
            ]), "核近似SVM", "Nystroem SVM", False, True, "fast"),

            "rbf_sampler_svm": (Pipeline([
                ("feature_map", RBFSampler(gamma=0.1, n_components=n_components, random_state=self.RANDOM_STATE)),
                ("sgd", SGDClassifier(max_iter=max_iter, random_state=self.RANDOM_STATE))
            ]), "RBF采样SVM", "RBF Sampler", False, True, "fast"),

            "svm_rbf": (SVC(kernel="rbf", C=1.0, gamma='scale', cache_size=500, probability=True,
                            random_state=self.RANDOM_STATE),
                        "SVM-RBF核⚠️", "SVM RBF", False, True, "very_slow"),

            "knn": (KNeighborsClassifier(n_neighbors=5, n_jobs=-1, algorithm='ball_tree', leaf_size=30),
                    "K近邻", "KNN", False, True, "slow"),

            "nb": (GaussianNB(), "朴素贝叶斯", "Naive Bayes", False, False, "very_fast"),

            "gb": (GradientBoostingClassifier(n_estimators=n_est, learning_rate=0.1, max_depth=5,
                                              random_state=self.RANDOM_STATE, verbose=0, subsample=0.8),
                   "梯度提升", "Gradient Boosting", False, False, "medium"),

            # algorithm='SAMME.R' was deprecated in scikit-learn 1.4 and removed in
            # 1.6 (where fitting with it fails); rely on the library default so the
            # model works across versions.
            "ada": (AdaBoostClassifier(n_estimators=n_est, learning_rate=1.0,
                                       random_state=self.RANDOM_STATE),
                    "AdaBoost", "AdaBoost", False, False, "medium"),

            # multi_class is deprecated since scikit-learn 1.5; with the lbfgs solver
            # multi-class targets are already fit as a multinomial model.
            "lr": (LogisticRegression(max_iter=max_iter, n_jobs=-1, random_state=self.RANDOM_STATE,
                                      verbose=0, solver='lbfgs'),
                   "逻辑回归", "Logistic Regression", False, True, "very_fast"),

            "mlp": (MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=max_iter, random_state=self.RANDOM_STATE,
                                  verbose=False, early_stopping=True, validation_fraction=0.1,
                                  n_iter_no_change=10, learning_rate='adaptive'),
                    "神经网络", "MLP", False, True, "medium"),
        }

        if self.has_xgboost:
            try:
                from xgboost import XGBClassifier
                # XGBoost needs labels 0..n_classes-1, hence needs_label_encoding=True.
                classifiers["xgb"] = (
                    XGBClassifier(n_estimators=n_est, learning_rate=0.1, max_depth=6, n_jobs=-1,
                                  random_state=self.RANDOM_STATE, verbosity=0, tree_method='hist',
                                  subsample=0.8, colsample_bytree=0.8),
                    "XGBoost", "XGBoost", True, False, "fast"
                )
            except Exception:
                pass

        if self.has_lightgbm:
            try:
                from lightgbm import LGBMClassifier
                # LightGBM only applies `subsample` (bagging) when subsample_freq > 0;
                # without it the requested 0.8 row subsampling was silently ignored.
                classifiers["lgb"] = (
                    LGBMClassifier(n_estimators=n_est, learning_rate=0.1, max_depth=max_depth, n_jobs=-1,
                                   random_state=self.RANDOM_STATE, verbose=-1, num_leaves=31,
                                   subsample=0.8, subsample_freq=1, colsample_bytree=0.8,
                                   force_col_wise=True),
                    "LightGBM", "LightGBM", False, False, "very_fast"
                )
            except Exception:
                pass

        return classifiers

    def get_background_mask(self, image, background_value=0):
        """Return a 2-D boolean mask of pixels whose every band equals ``background_value``.

        Args:
            image: band-first image exposing ``.values`` of shape (bands, rows, cols).
            background_value: value identifying background pixels.
        """
        return np.all(image.values == background_value, axis=0)

    def get_shapefile_fields(self, shp_path):
        """Return the column names of a vector file, or ``[]`` if it cannot be read."""
        try:
            gdf = gpd.read_file(shp_path)
            return list(gdf.columns)
        except Exception as e:
            print(f"读取shapefile字段失败: {e}")
            return []

    def get_class_info_from_shp(self, shp_path, class_attr, name_attr):
        """Read class codes, display names and colours from a vector sample file.

        Args:
            shp_path: path of the vector file.
            class_attr: column holding the class code.
            name_attr: column holding the class name. If it is missing or equals
                ``class_attr``, names are generated as ``"类别_<code>"``.

        Returns:
            ``(class_names, class_colors, sorted_codes)``: dict code -> name,
            dict code -> colour, and the sorted list of codes.
        """
        gdf = gpd.read_file(shp_path)
        class_ids = gdf[class_attr]

        # Build names in a separate Series. The class-code column must not be
        # overwritten, otherwise (when name_attr == class_attr) the codes would
        # become "类别_x" strings and no longer match the rasterized labels.
        if name_attr in gdf.columns and name_attr != class_attr:
            class_labels = gdf[name_attr]
        else:
            class_labels = class_ids.apply(lambda x: f"类别_{x}")

        class_info = pd.DataFrame({"code": class_ids, "name": class_labels}).drop_duplicates()
        class_names = dict(zip(class_info["code"], class_info["name"]))

        class_colors = {}
        for i, (class_id, class_name) in enumerate(class_names.items()):
            color_found = False
            for key, color in self.LANDUSE_COLORS.items():
                if key in str(class_name):
                    class_colors[class_id] = color
                    color_found = True
                    break
            if not color_found:
                class_colors[class_id] = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)]

        return class_names, class_colors, sorted(class_names.keys())

    def rasterize_samples(self, shp, ref_img, attr):
        """Burn the ``attr`` value of each sample polygon into a uint16 grid aligned with ``ref_img``.

        The vector data is reprojected to the image CRS. Uncovered pixels are 0,
        so class codes should be non-zero. Features without a class value are
        skipped, because a missing (NaN) value cannot be burned into a uint16
        grid and would make rasterization fail.
        """
        gdf = gpd.read_file(shp)
        gdf = gdf.to_crs(ref_img.rio.crs)
        gdf = gdf[gdf[attr].notna()]
        shapes = ((geom, value) for geom, value in zip(gdf.geometry, gdf[attr]))

        arr = features.rasterize(shapes=shapes, out_shape=ref_img.shape[1:],
                                 transform=ref_img.rio.transform(), fill=0,
                                 all_touched=True, dtype="uint16")
        return arr

    def extract_samples(self, image, mask, ignore_background=True, background_value=0, max_samples=None):
        """Collect labelled pixels from ``image`` wherever ``mask`` > 0.

        Args:
            image: band-first image exposing ``.values`` of shape (bands, rows, cols).
            mask: 2-D label raster (0 = unlabelled), e.g. from ``rasterize_samples``.
            ignore_background: drop pixels whose every band equals ``background_value``.
            background_value: value identifying background pixels.
            max_samples: if given and more samples are available, keep this many.
                Sampling is class-stratified when possible. Plain random sampling
                is used when only one class is present or a stratified split
                cannot be made.

        Returns:
            ``(X, y, n_nan, n_inf, n_sampled)``:
            - ``X`` has shape (n, bands).
            - ``n_nan`` / ``n_inf`` count the rows dropped for NaN / infinite
              values; a row with both is counted in each.
            - ``n_sampled`` is the number of rows removed by sampling.
        """
        data = np.moveaxis(image.values, 0, -1)
        valid = mask > 0

        if ignore_background:
            valid &= ~self.get_background_mask(image, background_value)

        X = data[valid]
        y = mask[valid]

        # Drop rows containing NaN or +/-Inf in any band.
        nan_mask = np.isnan(X).any(axis=1)
        inf_mask = np.isinf(X).any(axis=1)
        keep = ~(nan_mask | inf_mask)

        n_nan = np.sum(nan_mask)
        n_inf = np.sum(inf_mask)

        X = X[keep]
        y = y[keep]

        n_sampled = 0
        if max_samples is not None and len(y) > max_samples:
            n_original = len(y)
            sample_idx = None

            if len(np.unique(y)) > 1:
                splitter = StratifiedShuffleSplit(n_splits=1, train_size=max_samples,
                                                  random_state=self.RANDOM_STATE)
                try:
                    sample_idx, _ = next(splitter.split(X, y))
                except ValueError:
                    # A stratified split is impossible when a class has a single
                    # member, or when train/test would hold fewer samples than there
                    # are classes. Fall back to random sampling instead of aborting.
                    sample_idx = None

            if sample_idx is None:
                # A local RNG gives the same draw as seeding the global NumPy RNG,
                # without clobbering the global random state.
                rng = np.random.RandomState(self.RANDOM_STATE)
                sample_idx = rng.choice(n_original, max_samples, replace=False)

            X = X[sample_idx]
            y = y[sample_idx]
            n_sampled = n_original - len(y)

        return X, y, n_nan, n_inf, n_sampled

    def calculate_metrics(self, y_true, y_pred):
        """Return accuracy, Cohen's kappa, and macro precision/recall/F1 as a dict."""
        return {
            'overall_accuracy': accuracy_score(y_true, y_pred),
            'kappa': cohen_kappa_score(y_true, y_pred),
            'precision_macro': precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall_macro': recall_score(y_true, y_pred, average='macro', zero_division=0),
            'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
        }

    def estimate_prediction_time(self, clf_code, n_pixels, speed_tag):
        """Roughly estimate whole-image prediction time in seconds.

        Uses a fixed seconds-per-million-pixels rate per ``speed_tag``; unknown
        tags use 10. ``clf_code`` is currently unused.
        """
        time_per_million = {"very_fast": 1, "fast": 3, "medium": 10, "slow": 30, "very_slow": 300}
        base_time = time_per_million.get(speed_tag, 10)
        return (n_pixels / 1_000_000) * base_time

    def _predict_pixels(self, model, pixels, scaler=None, label_encoder=None):
        """Predict class codes for a 2-D (n_pixels, bands) feature array.

        NaN/Inf features are replaced by 0, the optional scaler is applied, and
        encoded predictions are mapped back through the optional label encoder.
        """
        pixels = np.nan_to_num(pixels, nan=0.0, posinf=0.0, neginf=0.0)
        if scaler is not None:
            pixels = scaler.transform(pixels)
        preds = model.predict(pixels)
        if label_encoder is not None:
            preds = label_encoder.inverse_transform(preds)
        return preds

    def predict_by_block(self, model, image, out_path, block_size=512,
                         ignore_background=True, background_value=0, progress_callback=None,
                         label_encoder=None, scaler=None):
        """Classify ``image`` in horizontal strips and write the result as a GeoTIFF.

        Args:
            model: fitted estimator with a ``predict`` method.
            image: band-first rioxarray DataArray with ``y``/``x`` coordinates.
            out_path: destination of the uint16, LZW-compressed, tiled GeoTIFF.
            block_size: number of image rows predicted per strip.
            ignore_background: leave background pixels (all bands equal to
                ``background_value``) at 0 instead of classifying them.
            background_value: background value; also written as the raster nodata.
            progress_callback: optional callable receiving percent complete
                (0-100) after each strip.
            label_encoder: optional encoder used to map predictions back to codes.
            scaler: optional feature scaler applied before prediction.

        Returns:
            ``out_path``.
        """
        height, width = image.shape[1], image.shape[2]
        prediction = np.zeros((height, width), dtype='uint16')

        background_mask = (self.get_background_mask(image, background_value)
                           if ignore_background else None)

        total_blocks = int(np.ceil(height / block_size))

        for i, row in enumerate(range(0, height, block_size)):
            h = min(block_size, height - row)
            block = np.moveaxis(image.isel(y=slice(row, row + h)).values, 0, -1)
            pixels = block.reshape(-1, block.shape[-1])

            block_pred = np.zeros(pixels.shape[0], dtype='uint16')
            if background_mask is None:
                block_pred[:] = self._predict_pixels(model, pixels, scaler, label_encoder)
            else:
                foreground = ~background_mask[row:row + h, :].ravel()
                # Background-only strips stay 0 without calling the model.
                if foreground.any():
                    block_pred[foreground] = self._predict_pixels(
                        model, pixels[foreground], scaler, label_encoder)

            prediction[row:row + h, :] = block_pred.reshape(block.shape[0], block.shape[1])

            if progress_callback:
                progress_callback((i + 1) / total_blocks * 100)

        # Save with the source image's georeferencing.
        prediction_da = xr.DataArray(prediction, dims=['y', 'x'],
                                     coords={'y': image.coords['y'], 'x': image.coords['x']})

        prediction_da.rio.write_crs(image.rio.crs, inplace=True)
        prediction_da.rio.write_transform(image.rio.transform(), inplace=True)
        prediction_da.rio.write_nodata(background_value, inplace=True)

        prediction_da.rio.to_raster(out_path, driver='GTiff', dtype='uint16',
                                    compress='lzw', tiled=True)
        return out_path

# ==================== GUI主类 ====================
class ClassificationGUI:
    """遥感影像分类GUI主界面（专业版）"""
    
    def __init__(self, root):
        """Initialise window state, Tk variables and the widget tree.

        Args:
            root: Tk root (or toplevel) window that hosts the GUI.

        Sets the window title and size, creates the classification backend,
        every Tk variable bound to the form widgets, builds the UI and starts
        the periodic log-queue poller.
        """
        self.root = root
        window = self.root
        window.title("遥感影像监督分类系统 v4.1 - 专业版")
        window.geometry("1600x900")

        # Processing backend.
        self.backend = ClassificationBackend()

        # Input / output paths bound to the file-selection entries.
        self.image_path = tk.StringVar()
        self.train_shp_path = tk.StringVar()
        self.val_shp_path = tk.StringVar()
        self.output_dir = tk.StringVar(value=str(Path("./results_gui")))

        # Attribute-field selection for the training shapefile.
        self.train_fields = []
        self.class_attr = tk.StringVar()
        self.name_attr = tk.StringVar()

        # Background value handling.
        self.background_value = tk.IntVar(value=0)
        self.ignore_background = tk.BooleanVar(value=True)

        # General classification parameters.
        self.n_estimators = tk.IntVar(value=100)
        self.block_size = tk.IntVar(value=512)

        # Performance options.
        self.enable_sampling = tk.BooleanVar(value=True)
        self.max_samples = tk.IntVar(value=50000)
        self.fast_mode = tk.BooleanVar(value=False)

        # One check-box variable per classifier code, all initially unchecked.
        self.classifier_vars = {
            code: tk.BooleanVar(value=False)
            for code in self.backend.get_all_classifiers()
        }

        # Run state; log lines travel through this queue to the Tk main loop.
        self.is_running = False
        self.log_queue = queue.Queue()

        # Results accumulated during a run.
        self.comparison_results = []
        self.current_confusion_matrix = self.current_y_true = self.current_y_pred = None
        self.class_names_dict, self.class_colors_dict = {}, {}
        self.best_result_path = None

        # Widgets first, then the log poller that writes into them.
        self.build_ui()
        self.update_log()
    
    def build_ui(self):
        """Build the user interface as a left/right split layout.

        Left: a vertically scrollable column of parameter sections (data
        files, field configuration, background value, classification
        parameters, performance options, classifier check boxes, run
        controls). Right: a Notebook whose first tab is the run log, followed
        by four matplotlib figure tabs (accuracy, confusion matrix, timing,
        result preview).

        Stores on ``self`` the widgets other methods use: class_attr_combo,
        name_attr_combo, max_samples_spinbox, start_btn, stop_btn,
        progress_var, progress_bar, status_var, notebook, log_text, and the
        figure/canvas pairs accuracy_*, cm_*, time_* and result_*.
        """
        # Main horizontal PanedWindow: settings on the left, figures on the right.
        main_paned = ttk.PanedWindow(self.root, orient=tk.HORIZONTAL)
        main_paned.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
        
        # ===== Left panel: parameter settings =====
        left_frame = ttk.Frame(main_paned, width=600)
        main_paned.add(left_frame, weight=1)
        
        # Scrollable area: a Canvas hosting an inner Frame. The scroll region
        # is recomputed every time the inner frame is resized.
        canvas = tk.Canvas(left_frame)
        scrollbar = ttk.Scrollbar(left_frame, orient="vertical", command=canvas.yview)
        scrollable_left = ttk.Frame(canvas)
        
        scrollable_left.bind("<Configure>", 
                            lambda e: canvas.configure(scrollregion=canvas.bbox("all")))
        
        canvas.create_window((0, 0), window=scrollable_left, anchor="nw")
        canvas.configure(yscrollcommand=scrollbar.set)
        
        canvas.pack(side="left", fill="both", expand=True)
        scrollbar.pack(side="right", fill="y")
        
        # 1. Data files: image, training/validation shapefiles, output directory.
        file_frame = ttk.LabelFrame(scrollable_left, text="📁 数据文件", padding="10")
        file_frame.pack(fill=tk.X, padx=5, pady=5)
        
        ttk.Label(file_frame, text="影像文件:").grid(row=0, column=0, sticky=tk.W, pady=3)
        ttk.Entry(file_frame, textvariable=self.image_path, width=40).grid(
            row=0, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_image).grid(row=0, column=2)
        
        ttk.Label(file_frame, text="训练样本:").grid(row=1, column=0, sticky=tk.W, pady=3)
        ttk.Entry(file_frame, textvariable=self.train_shp_path, width=40).grid(
            row=1, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_train_shp).grid(row=1, column=2)
        
        ttk.Label(file_frame, text="验证样本:").grid(row=2, column=0, sticky=tk.W, pady=3)
        ttk.Entry(file_frame, textvariable=self.val_shp_path, width=40).grid(
            row=2, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_val_shp).grid(row=2, column=2)
        
        ttk.Label(file_frame, text="输出目录:").grid(row=3, column=0, sticky=tk.W, pady=3)
        ttk.Entry(file_frame, textvariable=self.output_dir, width=40).grid(
            row=3, column=1, sticky=(tk.W, tk.E), padx=5
        )
        ttk.Button(file_frame, text="浏览", command=self.browse_output).grid(row=3, column=2)
        
        # Let the entry column absorb extra horizontal space.
        file_frame.columnconfigure(1, weight=1)
        
        # 2. Field configuration. The read-only combo boxes are filled by
        #    refresh_fields().
        field_frame = ttk.LabelFrame(scrollable_left, text="🏷️ 字段配置", padding="10")
        field_frame.pack(fill=tk.X, padx=5, pady=5)
        
        ttk.Label(field_frame, text="类别编号字段:").grid(row=0, column=0, sticky=tk.W, pady=3)
        self.class_attr_combo = ttk.Combobox(field_frame, textvariable=self.class_attr, 
                                            width=20, state="readonly")
        self.class_attr_combo.grid(row=0, column=1, sticky=(tk.W, tk.E), padx=5)
        
        ttk.Label(field_frame, text="类别名称字段:").grid(row=1, column=0, sticky=tk.W, pady=3)
        self.name_attr_combo = ttk.Combobox(field_frame, textvariable=self.name_attr, 
                                           width=20, state="readonly")
        self.name_attr_combo.grid(row=1, column=1, sticky=(tk.W, tk.E), padx=5)
        
        # A single refresh button spans both combo-box rows.
        ttk.Button(field_frame, text="🔄 刷新字段列表", 
                  command=self.refresh_fields).grid(row=0, column=2, rowspan=2, padx=5)
        
        field_frame.columnconfigure(1, weight=1)
        
        # 3. Background value settings.
        bg_frame = ttk.LabelFrame(scrollable_left, text="🎨 背景值设置", padding="10")
        bg_frame.pack(fill=tk.X, padx=5, pady=5)
        
        ttk.Checkbutton(bg_frame, text="忽略背景值", 
                       variable=self.ignore_background).grid(row=0, column=0, sticky=tk.W, pady=3)
        
        ttk.Label(bg_frame, text="背景值:").grid(row=1, column=0, sticky=tk.W, pady=3)
        ttk.Spinbox(bg_frame, from_=-9999, to=9999, textvariable=self.background_value, 
                   width=15).grid(row=1, column=1, sticky=tk.W, padx=5)
        ttk.Label(bg_frame, text="(默认0, 常见: -9999, 255)", 
                 font=('', 8), foreground='gray').grid(row=1, column=2, sticky=tk.W)
        
        # 4. Classification parameters: tree count and prediction block size.
        param_frame = ttk.LabelFrame(scrollable_left, text="⚙️ 分类参数", padding="10")
        param_frame.pack(fill=tk.X, padx=5, pady=5)
        
        ttk.Label(param_frame, text="树模型数量:").grid(row=0, column=0, sticky=tk.W, pady=3)
        ttk.Spinbox(param_frame, from_=10, to=500, textvariable=self.n_estimators, 
                   width=15).grid(row=0, column=1, sticky=tk.W, padx=5)
        
        ttk.Label(param_frame, text="分块大小:").grid(row=1, column=0, sticky=tk.W, pady=3)
        ttk.Spinbox(param_frame, from_=256, to=2048, increment=256, 
                   textvariable=self.block_size, width=15).grid(row=1, column=1, sticky=tk.W, padx=5)
        
        # 5. Performance options: sampling cap and fast mode.
        opt_frame = ttk.LabelFrame(scrollable_left, text="⚡ 性能优化", padding="10")
        opt_frame.pack(fill=tk.X, padx=5, pady=5)
        
        sample_frame = ttk.Frame(opt_frame)
        sample_frame.pack(fill=tk.X, pady=2)
        
        # toggle_sampling() enables/disables the max-samples spinbox below.
        ttk.Checkbutton(sample_frame, text="启用采样", 
                       variable=self.enable_sampling,
                       command=self.toggle_sampling).pack(side=tk.LEFT)
        
        ttk.Label(sample_frame, text="最大样本数:").pack(side=tk.LEFT, padx=(10, 0))
        self.max_samples_spinbox = ttk.Spinbox(sample_frame, from_=10000, to=200000, 
                                              increment=10000, textvariable=self.max_samples, 
                                              width=10)
        self.max_samples_spinbox.pack(side=tk.LEFT, padx=5)
        
        ttk.Checkbutton(opt_frame, text="快速模式", 
                       variable=self.fast_mode).pack(anchor=tk.W, pady=2)
        
        # 6. Classifier selection.
        clf_frame = ttk.LabelFrame(scrollable_left, text="🤖 分类器选择", padding="10")
        clf_frame.pack(fill=tk.X, padx=5, pady=5)
        
        # Quick-select buttons (all / none / recommended / fast).
        btn_frame = ttk.Frame(clf_frame)
        btn_frame.pack(fill=tk.X, pady=(0, 5))
        
        ttk.Button(btn_frame, text="全选", command=self.select_all_classifiers, 
                  width=10).pack(side=tk.LEFT, padx=2)
        ttk.Button(btn_frame, text="全不选", command=self.deselect_all_classifiers, 
                  width=10).pack(side=tk.LEFT, padx=2)
        ttk.Button(btn_frame, text="✓推荐", command=self.select_recommended, 
                  width=10).pack(side=tk.LEFT, padx=2)
        ttk.Button(btn_frame, text="⚡快速", command=self.select_fast, 
                  width=10).pack(side=tk.LEFT, padx=2)
        
        # Classifier check boxes live in a fixed-height (150 px) scrollable
        # canvas, grouped as SVM / tree models / others. Each catalogue entry
        # is a 6-tuple; only the display name (index 1) is used here. Codes
        # missing from the backend catalogue are skipped (presumably optional
        # dependencies such as xgb/lgb — confirm in ClassificationBackend).
        all_classifiers = self.backend.get_all_classifiers()
        
        clf_canvas = tk.Canvas(clf_frame, height=150)
        clf_scrollbar = ttk.Scrollbar(clf_frame, orient="vertical", command=clf_canvas.yview)
        clf_scrollable = ttk.Frame(clf_canvas)
        
        clf_scrollable.bind("<Configure>", 
                           lambda e: clf_canvas.configure(scrollregion=clf_canvas.bbox("all")))
        
        clf_canvas.create_window((0, 0), window=clf_scrollable, anchor="nw")
        clf_canvas.configure(yscrollcommand=clf_scrollbar.set)
        
        # SVM group header; `row` tracks the next free grid row across groups.
        ttk.Label(clf_scrollable, text="SVM:", font=('', 9, 'bold')).grid(
            row=0, column=0, sticky=tk.W, pady=2
        )
        row = 1
        svm_codes = ["svm_linear", "linear_svc", "sgd_svm", "nystroem_svm", 
                     "rbf_sampler_svm", "svm_rbf"]
        for code in svm_codes:
            if code in all_classifiers:
                _, name, _, _, _, _ = all_classifiers[code]
                ttk.Checkbutton(clf_scrollable, text=name, 
                              variable=self.classifier_vars[code]).grid(
                    row=row, column=0, sticky=tk.W, padx=20
                )
                row += 1
        
        # Tree-model group.
        ttk.Label(clf_scrollable, text="树模型:", font=('', 9, 'bold')).grid(
            row=row, column=0, sticky=tk.W, pady=(5, 2)
        )
        row += 1
        tree_codes = ["rf", "et", "dt", "xgb", "lgb", "gb", "ada"]
        for code in tree_codes:
            if code in all_classifiers:
                _, name, _, _, _, _ = all_classifiers[code]
                ttk.Checkbutton(clf_scrollable, text=name,
                              variable=self.classifier_vars[code]).grid(
                    row=row, column=0, sticky=tk.W, padx=20
                )
                row += 1
        
        # Other classifiers.
        ttk.Label(clf_scrollable, text="其他:", font=('', 9, 'bold')).grid(
            row=row, column=0, sticky=tk.W, pady=(5, 2)
        )
        row += 1
        other_codes = ["knn", "nb", "lr", "mlp"]
        for code in other_codes:
            if code in all_classifiers:
                _, name, _, _, _, _ = all_classifiers[code]
                ttk.Checkbutton(clf_scrollable, text=name,
                              variable=self.classifier_vars[code]).grid(
                    row=row, column=0, sticky=tk.W, padx=20
                )
                row += 1
        
        clf_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        clf_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        
        # 7. Run controls: start / stop / open results.
        control_frame = ttk.LabelFrame(scrollable_left, text="🎮 运行控制", padding="10")
        control_frame.pack(fill=tk.X, padx=5, pady=5)
        
        btn_control_frame = ttk.Frame(control_frame)
        btn_control_frame.pack(fill=tk.X)
        
        self.start_btn = ttk.Button(btn_control_frame, text="▶ 开始分类", 
                                    command=self.start_classification, width=15)
        self.start_btn.pack(side=tk.LEFT, padx=5)
        
        # Stop stays disabled until start_classification() enables it.
        self.stop_btn = ttk.Button(btn_control_frame, text="⏸ 停止", 
                                   command=self.stop_classification, 
                                   state=tk.DISABLED, width=15)
        self.stop_btn.pack(side=tk.LEFT, padx=5)
        
        ttk.Button(btn_control_frame, text="📁 打开结果", 
                  command=self.open_result_dir, width=15).pack(side=tk.LEFT, padx=5)
        
        # Progress bar (0-100).
        ttk.Label(control_frame, text="进度:").pack(anchor=tk.W, pady=(10, 0))
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(control_frame, variable=self.progress_var, 
                                           maximum=100)
        self.progress_bar.pack(fill=tk.X, pady=5)
        
        # Status line.
        self.status_var = tk.StringVar(value="就绪")
        ttk.Label(control_frame, textvariable=self.status_var, 
                 relief=tk.SUNKEN, anchor=tk.W).pack(fill=tk.X)
        
        # ===== Right panel: figure display =====
        right_frame = ttk.Frame(main_paned, width=900)
        main_paned.add(right_frame, weight=2)
        
        # Notebook of result tabs. Tab order matters: start_classification()
        # calls self.notebook.select(0) to show the log tab.
        self.notebook = ttk.Notebook(right_frame)
        self.notebook.pack(fill=tk.BOTH, expand=True)
        
        # Tab 1 (index 0): run log, filled by update_log().
        log_tab = ttk.Frame(self.notebook)
        self.notebook.add(log_tab, text="📝 运行日志")
        
        self.log_text = scrolledtext.ScrolledText(log_tab, wrap=tk.WORD, 
                                                  font=('Consolas', 9))
        self.log_text.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
        
        # Tab 2: accuracy comparison. Each figure tab embeds a matplotlib
        # Figure with a navigation toolbar in a frame below it.
        # NOTE(review): NavigationToolbar2Tk is not imported in the file
        # header shown (only FigureCanvasTkAgg is) — confirm it is imported
        # elsewhere, otherwise this raises NameError.
        accuracy_tab = ttk.Frame(self.notebook)
        self.notebook.add(accuracy_tab, text="📊 精度对比")
        
        self.accuracy_fig = Figure(figsize=(10, 6), dpi=100)
        self.accuracy_canvas = FigureCanvasTkAgg(self.accuracy_fig, accuracy_tab)
        self.accuracy_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        
        toolbar_acc = ttk.Frame(accuracy_tab)
        toolbar_acc.pack(fill=tk.X)
        NavigationToolbar2Tk(self.accuracy_canvas, toolbar_acc)
        
        # Tab 3: confusion matrix.
        cm_tab = ttk.Frame(self.notebook)
        self.notebook.add(cm_tab, text="🔥 混淆矩阵")
        
        self.cm_fig = Figure(figsize=(8, 6), dpi=100)
        self.cm_canvas = FigureCanvasTkAgg(self.cm_fig, cm_tab)
        self.cm_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        
        toolbar_cm = ttk.Frame(cm_tab)
        toolbar_cm.pack(fill=tk.X)
        NavigationToolbar2Tk(self.cm_canvas, toolbar_cm)
        
        # Tab 4: training / prediction time comparison.
        time_tab = ttk.Frame(self.notebook)
        self.notebook.add(time_tab, text="⏱️ 时间对比")
        
        self.time_fig = Figure(figsize=(10, 6), dpi=100)
        self.time_canvas = FigureCanvasTkAgg(self.time_fig, time_tab)
        self.time_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        
        toolbar_time = ttk.Frame(time_tab)
        toolbar_time.pack(fill=tk.X)
        NavigationToolbar2Tk(self.time_canvas, toolbar_time)
        
        # Tab 5: classification result preview.
        result_tab = ttk.Frame(self.notebook)
        self.notebook.add(result_tab, text="🗺️ 结果预览")
        
        self.result_fig = Figure(figsize=(10, 6), dpi=100)
        self.result_canvas = FigureCanvasTkAgg(self.result_fig, result_tab)
        self.result_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        
        toolbar_result = ttk.Frame(result_tab)
        toolbar_result.pack(fill=tk.X)
        NavigationToolbar2Tk(self.result_canvas, toolbar_result)
    
    # ===== 辅助函数 =====
    def toggle_sampling(self):
        if self.enable_sampling.get():
            self.max_samples_spinbox.config(state=tk.NORMAL)
        else:
            self.max_samples_spinbox.config(state=tk.DISABLED)
    
    def refresh_fields(self):
        """Reload the attribute-field lists from the chosen training shapefile.

        Fills both field combo boxes (excluding the geometry column) and
        pre-selects defaults. The class-code field prefers 'class', then
        'Class', then the first field. The class-name field prefers 'name',
        then 'Name', then the second field, then the first. The outcome is
        reported in a message box.
        """
        train_shp = self.train_shp_path.get()
        if not (train_shp and os.path.exists(train_shp)):
            messagebox.showwarning("警告", "请先选择训练样本文件！")
            return

        fields = self.backend.get_shapefile_fields(train_shp)
        if not fields:
            messagebox.showerror("错误", "无法读取字段列表！")
            return

        # The geometry column is never a valid attribute choice.
        fields = [field for field in fields if field.lower() != 'geometry']
        for combo in (self.class_attr_combo, self.name_attr_combo):
            combo['values'] = fields

        class_default = next((c for c in ('class', 'Class') if c in fields), None)
        if class_default is None and fields:
            class_default = fields[0]
        if class_default is not None:
            self.class_attr.set(class_default)

        name_default = next((c for c in ('name', 'Name') if c in fields), None)
        if name_default is None and fields:
            name_default = fields[1] if len(fields) > 1 else fields[0]
        if name_default is not None:
            self.name_attr.set(name_default)

        messagebox.showinfo("成功", f"已加载 {len(fields)} 个字段")
    
    def browse_image(self):
        """Ask for the input raster and remember its path (ignored if cancelled)."""
        chosen = filedialog.askopenfilename(
            title="选择影像文件",
            filetypes=[("GeoTIFF", "*.tif *.tiff"), ("所有文件", "*.*")],
        )
        if not chosen:
            return
        self.image_path.set(chosen)
        self.status_var.set(f"已选择影像: {Path(chosen).name}")
    
    def browse_train_shp(self):
        """Ask for the training shapefile, then reload its attribute fields."""
        chosen = filedialog.askopenfilename(
            title="选择训练样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")],
        )
        if not chosen:
            return
        self.train_shp_path.set(chosen)
        self.refresh_fields()
    
    def browse_val_shp(self):
        """Ask for the (optional) validation shapefile and remember its path."""
        chosen = filedialog.askopenfilename(
            title="选择验证样本",
            filetypes=[("Shapefile", "*.shp"), ("所有文件", "*.*")],
        )
        if not chosen:
            return
        self.val_shp_path.set(chosen)
    
    def browse_output(self):
        """Ask for the output directory and store it unless the dialog is cancelled."""
        chosen_dir = filedialog.askdirectory(title="选择输出目录")
        if not chosen_dir:
            return
        self.output_dir.set(chosen_dir)
    
    def select_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(True)
    
    def deselect_all_classifiers(self):
        for var in self.classifier_vars.values():
            var.set(False)
    
    def select_recommended(self):
        recommended = ["rf", "xgb", "et", "lgb", "linear_svc", "nystroem_svm"]
        for code, var in self.classifier_vars.items():
            var.set(code in recommended)
    
    def select_fast(self):
        fast = ["rf", "et", "dt", "xgb", "lgb", "nb", "lr", "sgd_svm", "linear_svc"]
        for code, var in self.classifier_vars.items():
            var.set(code in fast)
    
    def log(self, message):
        self.log_queue.put(message)
    
    def update_log(self):
        try:
            while True:
                message = self.log_queue.get_nowait()
                self.log_text.insert(tk.END, message + "\n")
                self.log_text.see(tk.END)
        except queue.Empty:
            pass
        self.root.after(100, self.update_log)
    
    def update_accuracy_plot(self):
        """更新精度对比图"""
        if not self.comparison_results:
            return
        
        df = pd.DataFrame(self.comparison_results)
        
        self.accuracy_fig.clear()
        
        # 创建子图
        ax1 = self.accuracy_fig.add_subplot(121)
        ax2 = self.accuracy_fig.add_subplot(122)
        
        # 精度对比
        x = np.arange(len(df))
        width = 0.35
        
        ax1.bar(x - width/2, df['训练集精度'], width, label='训练集', alpha=0.8, color='steelblue')
        ax1.bar(x + width/2, df['验证集精度'], width, label='验证集', alpha=0.8, color='coral')
        
        ax1.set_xlabel('分类器', fontsize=11)
        ax1.set_ylabel('精度', fontsize=11)
        ax1.set_title('总体精度对比', fontsize=12, fontweight='bold')
        ax1.set_xticks(x)
        ax1.set_xticklabels(df['分类器名称'], rotation=45, ha='right', fontsize=9)
        ax1.legend()
        ax1.grid(True, alpha=0.3, axis='y')
        ax1.set_ylim([0, 1.05])
        
        # 添加数值标签
        for i, (train_acc, val_acc) in enumerate(zip(df['训练集精度'], df['验证集精度'])):
            ax1.text(i - width/2, train_acc + 0.01, f'{train_acc:.3f}', 
                    ha='center', va='bottom', fontsize=8)
            ax1.text(i + width/2, val_acc + 0.01, f'{val_acc:.3f}', 
                    ha='center', va='bottom', fontsize=8)
        
        # Kappa对比
        ax2.bar(x - width/2, df['训练集Kappa'], width, label='训练集', alpha=0.8, color='steelblue')
        ax2.bar(x + width/2, df['验证集Kappa'], width, label='验证集', alpha=0.8, color='coral')
        
        ax2.set_xlabel('分类器', fontsize=11)
        ax2.set_ylabel('Kappa系数', fontsize=11)
        ax2.set_title('Kappa系数对比', fontsize=12, fontweight='bold')
        ax2.set_xticks(x)
        ax2.set_xticklabels(df['分类器名称'], rotation=45, ha='right', fontsize=9)
        ax2.legend()
        ax2.grid(True, alpha=0.3, axis='y')
        ax2.set_ylim([0, 1.05])
        
        self.accuracy_fig.tight_layout()
        self.accuracy_canvas.draw()
    
    def update_confusion_matrix(self, y_true, y_pred, class_names):
        """Draw the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated heatmap.

        Args:
            y_true: reference labels (rows).
            y_pred: predicted labels (columns).
            class_names: tick labels used for both axes.

        NOTE(review): ``class_names`` is applied positionally, while
        ``confusion_matrix`` orders rows/columns by the sorted union of labels
        present in ``y_true`` and ``y_pred``. Callers must pass names in that
        order and of that length — confirm at the call site.
        """
        fig = self.cm_fig
        fig.clear()
        ax = fig.add_subplot(111)

        matrix = confusion_matrix(y_true, y_pred)

        sns.heatmap(
            matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': '样本数量'}, ax=ax,
        )

        ax.set_xlabel('预测类别', fontsize=11)
        ax.set_ylabel('真实类别', fontsize=11)
        ax.set_title('最佳分类器混淆矩阵', fontsize=12, fontweight='bold')

        # Slant the predicted-class labels so long names do not overlap.
        for tick_label in ax.get_xticklabels():
            tick_label.set_rotation(45)
            tick_label.set_horizontalalignment('right')

        fig.tight_layout()
        self.cm_canvas.draw()
    
    def update_time_plot(self):
        """更新时间对比图"""
        if not self.comparison_results:
            return
        
        df = pd.DataFrame(self.comparison_results)
        
        self.time_fig.clear()
        ax = self.time_fig.add_subplot(111)
        
        x = np.arange(len(df))
        width = 0.35
        
        bars1 = ax.bar(x - width/2, df['训练时间(秒)'], width, label='训练时间', 
                      alpha=0.8, color='lightgreen')
        bars2 = ax.bar(x + width/2, df['预测时间(秒)'], width, label='预测时间', 
                      alpha=0.8, color='lightcoral')
        
        ax.set_xlabel('分类器', fontsize=11)
        ax.set_ylabel('时间 (秒)', fontsize=11)
        ax.set_title('训练和预测时间对比', fontsize=12, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(df['分类器名称'], rotation=45, ha='right', fontsize=9)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        
        # 添加数值标签
        for bars in [bars1, bars2]:
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                       f'{height:.1f}s', ha='center', va='bottom', fontsize=8)
        
        self.time_fig.tight_layout()
        self.time_canvas.draw()
    
    def update_result_preview(self, image_path, classified_path, class_names, class_colors):
        """更新分类结果预览"""
        try:
            self.result_fig.clear()
            
            # 读取影像和分类结果
            img = rxr.open_rasterio(image_path, masked=True)
            classified = rxr.open_rasterio(classified_path)
            
            # 创建子图
            ax1 = self.result_fig.add_subplot(121)
            ax2 = self.result_fig.add_subplot(122)
            
            # 显示原始影像
            if img.shape[0] >= 3:
                rgb_data = np.moveaxis(img.values[:3], 0, -1)
                p2, p98 = np.percentile(rgb_data[rgb_data > 0], (2, 98))
                rgb_display = np.clip((rgb_data - p2) / (p98 - p2), 0, 1)
                ax1.imshow(rgb_display)
            else:
                ax1.imshow(img.values[0], cmap='gray')
            
            ax1.set_title('原始遥感影像', fontsize=12, fontweight='bold')
            ax1.axis('off')
            
            # 显示分类结果
            classified_data = classified.values.squeeze()
            
            # 获取类别
            classes = np.unique(classified_data)
            classes = classes[classes > 0]
            
            # 创建颜色映射
            colors = [class_colors.get(c, 'black') for c in classes]
            labels = [class_names.get(c, f'类别_{c}') for c in classes]
            
            cmap = mcolors.ListedColormap(colors)
            bounds = np.append(classes, classes[-1] + 1) - 0.5
            norm = mcolors.BoundaryNorm(bounds, cmap.N)
            
            # 背景设为透明
            display_data = classified_data.astype(float)
            display_data[classified_data == 0] = np.nan
            
            im = ax2.imshow(display_data, cmap=cmap, norm=norm)
            ax2.set_title('分类结果（最佳分类器）', fontsize=12, fontweight='bold')
            ax2.axis('off')
            
            # 添加图例
            from matplotlib.patches import Patch
            legend_elements = [Patch(facecolor=color, label=label) 
                              for color, label in zip(colors, labels)]
            ax2.legend(handles=legend_elements, loc='upper left', 
                      bbox_to_anchor=(1.05, 1), fontsize=9)
            
            self.result_fig.tight_layout()
            self.result_canvas.draw()
            
        except Exception as e:
            self.log(f"预览显示错误: {str(e)}")
    
    def export_to_excel(self, out_dir):
        """导出结果到Excel"""
        if not HAS_OPENPYXL:
            self.log("⚠️  未安装openpyxl，无法导出Excel")
            return
        
        if not self.comparison_results:
            return
        
        try:
            df = pd.DataFrame(self.comparison_results)
            
            excel_path = out_dir / "classification_comparison.xlsx"
            
            with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
                # 主结果表
                df.to_excel(writer, sheet_name='分类器对比', index=False)
                
                # 获取工作簿和工作表
                workbook = writer.book
                worksheet = writer.sheets['分类器对比']
                
                # 设置列宽
                for column in worksheet.columns:
                    max_length = 0
                    column_letter = column[0].column_letter
                    for cell in column:
                        try:
                            if len(str(cell.value)) > max_length:
                                max_length = len(str(cell.value))
                        except:
                            pass
                    adjusted_width = min(max_length + 2, 50)
                    worksheet.column_dimensions[column_letter].width = adjusted_width
                
                # 添加统计摘要表
                summary_data = {
                    '指标': ['最高精度', '最高Kappa', '最快训练', '最快预测'],
                    '分类器': [
                        df.loc[df['验证集精度'].idxmax(), '分类器名称'],
                        df.loc[df['验证集Kappa'].idxmax(), '分类器名称'],
                        df.loc[df['训练时间(秒)'].idxmin(), '分类器名称'],
                        df.loc[df['预测时间(秒)'].idxmin(), '分类器名称']
                    ],
                    '数值': [
                        f"{df['验证集精度'].max():.4f}",
                        f"{df['验证集Kappa'].max():.4f}",
                        f"{df['训练时间(秒)'].min():.2f}秒",
                        f"{df['预测时间(秒)'].min():.2f}秒"
                    ]
                }
                
                summary_df = pd.DataFrame(summary_data)
                summary_df.to_excel(writer, sheet_name='性能摘要', index=False)
            
            self.log(f"✓ Excel报告已保存: {excel_path}")
            
        except Exception as e:
            self.log(f"Excel导出失败: {str(e)}")
    
    def start_classification(self):
        """Validate the form, confirm slow classifiers, and launch the worker thread.

        Shows an error dialog and returns if the image, training shapefile,
        class-code field or classifier selection is missing. If any selected
        classifier is tagged ``"very_slow"`` in the backend catalogue, the
        user must confirm. On start, the buttons are toggled, the log and
        previous results are cleared, a header is logged, and
        ``run_classification`` runs on a daemon thread.
        """
        required_inputs = (
            (self.image_path, "请选择影像文件！"),
            (self.train_shp_path, "请选择训练样本！"),
            (self.class_attr, "请选择类别编号字段！"),
        )
        for variable, error_text in required_inputs:
            if not variable.get():
                messagebox.showerror("错误", error_text)
                return

        selected_classifiers = [
            code for code, flag in self.classifier_vars.items() if flag.get()
        ]
        if not selected_classifiers:
            messagebox.showerror("错误", "请至少选择一个分类器！")
            return

        # Catalogue entries are 6-tuples: index 1 is the display name and
        # index 5 the speed tag.
        catalog = self.backend.get_all_classifiers()
        very_slow_names = [
            catalog[code][1]
            for code in selected_classifiers
            if code in catalog and catalog[code][5] == "very_slow"
        ]

        if very_slow_names:
            bullet_lines = "".join(f"  • {clf}\n" for clf in very_slow_names)
            warning_msg = "⚠️ 以下分类器预测非常慢:\n" + bullet_lines + "\n是否继续?"
            if not messagebox.askyesno("性能警告", warning_msg, icon='warning'):
                return

        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.is_running = True

        # Reset the log view and results from any previous run.
        self.log_text.delete(1.0, tk.END)
        self.comparison_results = []

        rule = "=" * 80
        header_lines = (
            rule,
            "  遥感影像监督分类系统 v4.1",
            rule,
            f"选择的分类器: {len(selected_classifiers)} 个",
            f"背景值: {self.background_value.get()}",
            "",
        )
        for line in header_lines:
            self.log(line)

        # Bring the log tab (index 0) to the front.
        self.notebook.select(0)

        worker = threading.Thread(
            target=self.run_classification,
            args=(selected_classifiers,),
            daemon=True,
        )
        worker.start()
    
    def stop_classification(self):
        self.is_running = False
        self.log("\n⏸ 用户请求停止...")
        self.status_var.set("已停止")
    
    def run_classification(self, selected_classifiers):
        """执行分类（主流程）"""
        try:
            out_dir = Path(self.output_dir.get())
            out_dir.mkdir(exist_ok=True)
            
            # 读取影像
            self.log(f"📁 读取影像...")
            self.status_var.set("读取影像...")
            img = rxr.open_rasterio(self.image_path.get(), masked=True)
            n_pixels = img.shape[1] * img.shape[2]
            self.log(f"   尺寸: {img.shape[1]}×{img.shape[2]} = {n_pixels:,} 像元")
            
            if not self.is_running:
                return
            
            # 读取类别信息
            self.log(f"\n📊 读取类别信息...")
            class_names, class_colors, _ = self.backend.get_class_info_from_shp(
                self.train_shp_path.get(), 
                self.class_attr.get(), 
                self.name_attr.get()
            )
            self.class_names_dict = class_names
            self.class_colors_dict = class_colors
            self.log(f"   类别: {list(class_names.values())}")
            
            # 提取训练样本
            self.log(f"\n🎯 处理训练样本...")
            self.status_var.set("处理训练样本...")
            train_mask = self.backend.rasterize_samples(
                self.train_shp_path.get(), img, self.class_attr.get()
            )
            
            max_samples = self.max_samples.get() if self.enable_sampling.get() else None
            
            X_train, y_train, n_nan, n_inf, n_sampled = self.backend.extract_samples(
                img, train_mask, 
                ignore_background=self.ignore_background.get(),
                background_value=self.background_value.get(),
                max_samples=max_samples
            )
            
            self.log(f"   训练样本数: {len(y_train):,}")
            if n_nan > 0:
                self.log(f"   └─ 移除NaN: {n_nan:,}")
            if n_sampled > 0:
                self.log(f"   └─ 采样减少: {n_sampled:,}")
            
            if not self.is_running:
                return
            
            # 提取验证样本
            val_exists = os.path.exists(self.val_shp_path.get())
            if val_exists:
                self.log(f"\n✅ 处理验证样本...")
                val_mask = self.backend.rasterize_samples(
                    self.val_shp_path.get(), img, self.class_attr.get()
                )
                
                if self.ignore_background.get():
                    background_mask = self.backend.get_background_mask(
                        img, self.background_value.get()
                    )
                    valid_val = (val_mask > 0) & (~background_mask)
                else:
                    valid_val = val_mask > 0
                
                yv_true = val_mask[valid_val]
                self.log(f"   验证样本数: {len(yv_true):,}")
            
            # 分类器训练和评估
            all_classifiers = self.backend.get_all_classifiers(
                self.n_estimators.get(), 
                fast_mode=self.fast_mode.get(),
                n_train_samples=len(y_train)
            )
            
            comparison_results = []
            total_start_time = time.time()
            best_accuracy = 0
            best_clf_code = None
            
            for i, clf_code in enumerate(selected_classifiers):
                if not self.is_running:
                    break
                
                clf, clf_name, clf_desc, needs_encoding, needs_scaling, speed_tag = all_classifiers[clf_code]
                
                self.log(f"\n{'='*80}")
                self.log(f"[{i+1}/{len(selected_classifiers)}] {clf_name}")
                self.log(f"{'='*80}")
                
                self.status_var.set(f"[{i+1}/{len(selected_classifiers)}] 训练 {clf_name}...")
                
                clf_dir = out_dir / clf_code
                clf_dir.mkdir(exist_ok=True)
                
                try:
                    # 数据预处理
                    label_encoder = None
                    scaler = None
                    X_train_use = X_train.copy()
                    y_train_use = y_train.copy()
                    
                    if needs_encoding:
                        self.log("   🔄 标签编码...")
                        label_encoder = LabelEncoder()
                        y_train_use = label_encoder.fit_transform(y_train)
                    
                    if needs_scaling:
                        self.log("   📏 特征缩放...")
                        scaler = StandardScaler()
                        X_train_use = scaler.fit_transform(X_train_use)
                    
                    # 训练
                    self.log("   🔨 训练中...")
                    train_start = time.time()
                    clf.fit(X_train_use, y_train_use)
                    train_time = time.time() - train_start
                    self.log(f"   ✓ 训练完成: {train_time:.2f}秒")
                    
                    # 训练集精度
                    y_train_pred = clf.predict(X_train_use)
                    
                    if label_encoder is not None:
                        y_train_pred = label_encoder.inverse_transform(y_train_pred)
                    
                    train_metrics = self.backend.calculate_metrics(y_train, y_train_pred)
                    self.log(f"   📈 训练集 - 精度: {train_metrics['overall_accuracy']:.4f}")
                    
                    if not self.is_running:
                        break
                    
                    # 预测整幅影像
                    self.log("   🗺️  预测影像...")
                    self.status_var.set(f"[{i+1}/{len(selected_classifiers)}] 预测 {clf_name}...")
                    
                    pred_start = time.time()
                    classified_path = clf_dir / f"classified_{clf_code}.tif"
                    
                    def update_progress(progress):
                        self.progress_var.set(progress)
                    
                    self.backend.predict_by_block(
                        clf, img, classified_path, 
                        block_size=self.block_size.get(),
                        ignore_background=self.ignore_background.get(),
                        background_value=self.background_value.get(),
                        progress_callback=update_progress,
                        label_encoder=label_encoder,
                        scaler=scaler
                    )
                    
                    pred_time = time.time() - pred_start
                    self.log(f"   ✓ 预测完成: {pred_time:.2f}秒")
                    
                    # 验证集精度
                    val_metrics = {'overall_accuracy': np.nan, 'kappa': np.nan}
                    yv_pred = None
                    
                    if val_exists:
                        with rxr.open_rasterio(classified_path) as pred_img:
                            pred_arr = pred_img.values.squeeze()
                        
                        yv_pred = pred_arr[valid_val]
                        val_metrics = self.backend.calculate_metrics(yv_true, yv_pred)
                        self.log(f"   📊 验证集 - 精度: {val_metrics['overall_accuracy']:.4f}")
                        
                        # 记录最佳分类器
                        if val_metrics['overall_accuracy'] > best_accuracy:
                            best_accuracy = val_metrics['overall_accuracy']
                            best_clf_code = clf_code
                            self.best_result_path = classified_path
                            self.current_y_true = yv_true
                            self.current_y_pred = yv_pred
                    
                    # 记录结果
                    result = {
                        '分类器代码': clf_code,
                        '分类器名称': clf_name,
                        '训练集精度': train_metrics['overall_accuracy'],
                        '训练集Kappa': train_metrics['kappa'],
                        '验证集精度': val_metrics['overall_accuracy'],
                        '验证集Kappa': val_metrics['kappa'],
                        '训练时间(秒)': train_time,
                        '预测时间(秒)': pred_time,
                    }
                    comparison_results.append(result)
                    self.comparison_results = comparison_results
                    
                    # 实时更新图表
                    self.root.after(0, self.update_accuracy_plot)
                    self.root.after(0, self.update_time_plot)
                    
                    self.log(f"   ✅ {clf_name} 完成!")
                    
                except Exception as e:
                    self.log(f"   ❌ {clf_name} 失败: {str(e)}")
                    continue
                
                self.progress_var.set((i + 1) / len(selected_classifiers) * 100)
            
            # 生成报告
            if comparison_results and self.is_running:
                total_time = time.time() - total_start_time
                
                self.log(f"\n{'='*80}")
                self.log("📝 生成报告...")
                
                comparison_df = pd.DataFrame(comparison_results)
                
                # 保存CSV
                comparison_df.to_csv(out_dir / "classifier_comparison.csv", 
                                   index=False, encoding='utf-8-sig')
                
                # 导出Excel
                self.export_to_excel(out_dir)
                
                # 文字报告
                with open(out_dir / "comparison_summary.txt", 'w', encoding='utf-8') as f:
                    f.write("遥感影像分类器性能对比报告\n")
                    f.write("="*70 + "\n\n")
                    f.write(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
                    f.write(f"影像: {img.shape[1]}×{img.shape[2]}\n")
                    f.write(f"训练样本: {len(y_train):,}\n")
                    f.write(f"成功: {len(comparison_results)}/{len(selected_classifiers)}\n")
                    f.write(f"总耗时: {total_time/60:.1f} 分钟\n\n")
                    
                    sorted_df = comparison_df.sort_values('验证集精度', ascending=False)
                    f.write("验证集精度排名:\n")
                    f.write("-"*70 + "\n")
                    for idx, (_, row) in enumerate(sorted_df.iterrows(), 1):
                        f.write(f"{idx}. {row['分类器名称']:15s} - "
                               f"精度: {row['验证集精度']:.4f}\n")
                
                # 更新混淆矩阵
                if self.current_y_true is not None and self.current_y_pred is not None:
                    val_classes = sorted(np.unique(self.current_y_true))
                    val_class_names = [class_names.get(c, f'类别_{c}') for c in val_classes]
                    self.root.after(0, lambda: self.update_confusion_matrix(
                        self.current_y_true, self.current_y_pred, val_class_names
                    ))
                
                # 更新结果预览
                if self.best_result_path:
                    self.root.after(0, lambda: self.update_result_preview(
                        self.image_path.get(), self.best_result_path, 
                        class_names, class_colors
                    ))
                
                self.log("✅ 所有任务完成!")
                self.log(f"⏱️  总耗时: {total_time/60:.1f} 分钟")
                
                best_clf = comparison_df.loc[comparison_df['验证集精度'].idxmax()]
                self.log(f"\n🏆 最佳: {best_clf['分类器名称']} ({best_clf['验证集精度']:.4f})")
                
                self.status_var.set(f"✅ 完成! 最佳: {best_clf['分类器名称']}")
                
                # 切换到精度对比标签页
                self.root.after(0, lambda: self.notebook.select(1))
                
                messagebox.showinfo("任务完成", 
                    f"🎉 分类任务完成!\n\n"
                    f"✅ 成功: {len(comparison_results)}/{len(selected_classifiers)}\n"
                    f"🏆 最佳: {best_clf['分类器名称']} ({best_clf['验证集精度']:.4f})\n"
                    f"📊 结果已导出为Excel和CSV")
            
        except Exception as e:
            self.log(f"\n❌ 错误: {str(e)}")
            import traceback
            self.log(traceback.format_exc())
            messagebox.showerror("错误", f"发生错误:\n{str(e)}")
            self.status_var.set("❌ 错误")
        
        finally:
            self.start_btn.config(state=tk.NORMAL)
            self.stop_btn.config(state=tk.DISABLED)
            self.progress_var.set(0)
            self.is_running = False
    
    def open_result_dir(self):
        """Open the configured output directory in the platform file manager.

        The path is read from ``self.output_dir.get()``. A warning dialog is
        shown instead when the field is blank or the path is not an existing
        directory. If the platform launcher cannot be started (for example,
        ``xdg-open`` is not installed), an error dialog is shown rather than
        letting the ``OSError`` escape into the Tk callback.
        """
        import platform
        import subprocess

        raw_path = self.output_dir.get()
        # Path("") resolves to ".", which always exists; a blank field must be
        # treated as "no directory chosen" rather than opening the CWD.
        # is_dir() (not exists()) so a plain file is not handed to the launcher.
        if not raw_path.strip() or not Path(raw_path).is_dir():
            messagebox.showwarning("警告", "结果目录不存在！")
            return

        out_dir = Path(raw_path)
        system = platform.system()
        try:
            if system == "Windows":
                os.startfile(out_dir)  # Windows-only API
            elif system == "Darwin":
                subprocess.Popen(["open", str(out_dir)])
            else:
                subprocess.Popen(["xdg-open", str(out_dir)])
        except OSError as exc:
            # Launcher missing or failed to start: tell the user instead of
            # failing silently inside the Tk event loop.
            messagebox.showerror("错误", f"无法打开目录:\n{exc}")

# ==================== 主程序入口 ====================
def main():
    """Start the application.

    Prints a console banner, creates the Tk root window and the
    ``ClassificationGUI`` on it, writes a welcome/usage message into the
    GUI log, and then blocks in the Tk main loop.
    """
    rule = "=" * 80
    title = "  遥感影像监督分类系统 v4.1 - 专业版"

    # Console banner, emitted before any GUI objects are created.
    for text in (rule, title, rule, "\n正在检查依赖库..."):
        print(text)

    root = tk.Tk()
    app = ClassificationGUI(root)

    # Welcome text for the in-app log panel, in display order.
    welcome_lines = (
        rule,
        title,
        rule,
        "\n主要特性:",
        "  ✓ 自定义背景值输入",
        "  ✓ 字段下拉框自动识别",
        "  ✓ 实时精度对比图表",
        "  ✓ 混淆矩阵可视化",
        "  ✓ 分类结果预览",
        "  ✓ Excel格式报告导出",
        "\n使用流程:",
        "  1. 选择影像和样本文件",
        "  2. 点击'刷新字段列表'选择类别字段",
        "  3. 设置背景值和其他参数",
        "  4. 选择分类器",
        "  5. 点击'开始分类'",
        "  6. 查看右侧实时图表",
        rule,
        "",
    )
    for text in welcome_lines:
        app.log(text)

    print("\n✓ 系统启动成功!")

    root.mainloop()


if __name__ == "__main__":
    main()

  遥感影像监督分类系统 v4.1 - 专业版

正在检查依赖库...
✓ XGBoost 可用
✓ LightGBM 可用

✓ 系统启动成功!
