# In [None]:
# 模型训练

# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import pandas as pd
from osgeo import gdal
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score, classification_report
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import logging


# Root folder of the training samples; the label CSVs and the DEM table are
# resolved relative to it.
BASE_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Train_Samples"
TRAIN_CSV = os.path.join(BASE_DIR, r"label\label_train.CSV")
VAL_CSV = os.path.join(BASE_DIR, r"label\label_val.CSV")
DEM_CSV = os.path.join(BASE_DIR, r"DEM.CSV")

# Destination for the log file, tuning tables, figures and the saved model.
OUTPUT_TRAIN_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Exp2_RF_KG_postpro\Model_train_result"

# Knowledge-graph post-processing configuration.
# LABELS_DICT maps a class id to its 'lower-upper' altitude range string,
# which get_altitude_probability() splits on '-' and parses as integers.
# NOTE(review): classes 6 and 7 share the same range - confirm this is intended.
LABELS_DICT = {
    0: '2000-2300', 1: '1000-1900', 2: '1500-2300', 3: '750-1300',
    4: '2650-3000', 5: '3000-3400', 6: '3400-3777', 7: '3400-3777'
}
NUM_CLASSES = 8
# Samples whose highest model probability is below this threshold are fused
# with the altitude prior; all other samples keep the model output unchanged.
CONFIDENCE_THRESHOLD = 0.5
# Fusion weights for the model probabilities and the altitude prior.
W_MODEL = 0.6
W_ALTITUDE = 0.4
# Small constant added to the ramp denominators in get_altitude_probability().
EPSILON = 1e-5

# exist_ok=True makes the creation idempotent and avoids the gap between a
# separate existence check and the makedirs call.
os.makedirs(OUTPUT_TRAIN_DIR, exist_ok=True)

# Logging configuration: every record goes both to train_log.log inside the
# output directory and to standard output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(OUTPUT_TRAIN_DIR, 'train_log.log'), encoding='utf-8'),
        logging.StreamHandler(sys.stdout)
    ]
)


def read_tif_as_features(tif_path):
    """Summarise a raster file as a flat vector of per-band statistics.

    The file is opened with GDAL and read into an array. A single-band (2-D)
    array is promoted to 3-D so that every input is processed band by band.
    For each band six NaN-aware statistics are appended, in this order:
    mean, standard deviation, maximum, minimum, 25th and 75th percentile.

    Args:
        tif_path: Path of the raster file to read.

    Returns:
        A 1-D numpy array with six values per band, or None when the file
        cannot be opened or any error occurs while reading it (the error is
        logged).
    """
    try:
        dataset = gdal.Open(tif_path)
        if dataset is None:
            logging.error(f"无法打开文件: {tif_path}")
            return None

        raster = dataset.ReadAsArray()
        if raster.ndim == 2:
            # Single band: add a leading axis so the loop below sees one band.
            raster = raster[np.newaxis, :, :]

        feature_values = []
        for layer in raster:
            flat = layer.flatten()
            feature_values += [
                np.nanmean(flat),
                np.nanstd(flat),
                np.nanmax(flat),
                np.nanmin(flat),
                np.nanpercentile(flat, 25),
                np.nanpercentile(flat, 75),
            ]

        return np.array(feature_values)
    except Exception as e:
        # Best effort: a failing file is reported and skipped by the caller.
        logging.error(f"读取TIF出错 {tif_path}: {e}")
        return None

def get_altitude_probability(altitude, class_id):
    """Return the altitude-based membership score of a class.

    The score is trapezoidal around the class range stored in LABELS_DICT
    ('lower-upper'): 1.0 inside the range, rising or falling linearly over a
    100-unit margin on either side, and 0.0 beyond the margins.

    Args:
        altitude: Altitude of the sample; anything accepted by float().
        class_id: Class id used as the key into LABELS_DICT.

    Returns:
        A float between 0.0 and 1.0. 0.0 is returned when the altitude cannot
        be converted to a number or when class_id is not in LABELS_DICT.
    """
    try:
        altitude = float(altitude)
    except (TypeError, ValueError):
        # float(None) and other non-numeric objects raise TypeError rather
        # than ValueError; both mean "no usable altitude".
        return 0.0

    if class_id not in LABELS_DICT:
        return 0.0

    range_str = LABELS_DICT[class_id]
    lb_str, ub_str = range_str.split('-')
    lb, ub = int(lb_str), int(ub_str)
    # Outer limits of the ramps on both sides of the range.
    a, b = lb - 100, ub + 100

    if altitude < a:
        return 0.0
    elif a <= altitude < lb:
        return (altitude - a) / (lb - a + EPSILON)
    elif lb <= altitude <= ub:
        return 1.0
    elif ub < altitude <= b:
        return (b - altitude) / (b - ub + EPSILON)
    else:
        # Above the upper ramp, or NaN (every comparison above is False).
        return 0.0

def apply_kg_post_processing(pred_probs, dem_values):
    """Fuse low-confidence model probabilities with an altitude prior.

    Rows whose highest probability is below CONFIDENCE_THRESHOLD are replaced
    by the weighted sum W_MODEL * model + W_ALTITUDE * altitude prior,
    renormalised to sum to one when the sum is positive. All other rows are
    left as predicted by the model.

    Args:
        pred_probs: 2-D array of class probabilities, one row per sample.
        dem_values: Sequence of altitudes, indexed like the rows of pred_probs.

    Returns:
        A tuple (final_preds, processed_probs): the argmax class index of
        every row and the possibly fused probability array.
    """
    adjusted = pred_probs.copy()

    for row_idx, row_probs in enumerate(pred_probs):
        if not (np.max(row_probs) < CONFIDENCE_THRESHOLD):
            # Confident enough: keep the model output for this sample.
            continue

        elevation = dem_values[row_idx]
        prior = np.zeros(NUM_CLASSES)
        for class_id in range(NUM_CLASSES):
            prior[class_id] = get_altitude_probability(elevation, class_id)

        blended = W_MODEL * row_probs + W_ALTITUDE * prior
        total = np.sum(blended)
        if total > 0:
            blended = blended / total
        adjusted[row_idx] = blended

    return np.argmax(adjusted, axis=1), adjusted

def load_dataset(csv_path, dem_df):
    """Load features, labels and DEM values for the samples listed in a CSV.

    The CSV has no header row; its two columns are read as image path and
    integer label. A sample is kept only when a DEM entry with the same file
    name exists in dem_df and its image can be turned into features by
    read_tif_as_features().

    Args:
        csv_path: Path of the label CSV.
        dem_df: DataFrame with the columns 'ImageName' and 'DEM'.

    Returns:
        A tuple (X, y, dems, paths): feature array, label array, DEM array
        and the list of image paths, all restricted to the kept samples and
        in CSV order.
    """
    df = pd.read_csv(csv_path, header=None, names=['path', 'label'])

    # Build the file-name -> DEM lookup once instead of re-scanning the whole
    # DEM table for every sample. setdefault keeps the first entry when a
    # file name occurs more than once. Matching on the file name also covers
    # an exact full-path match, so a single lookup is sufficient.
    dem_by_name = {}
    for image_name, dem_value in zip(dem_df['ImageName'], dem_df['DEM']):
        dem_by_name.setdefault(os.path.basename(image_name), dem_value)

    X_list, y_list, dem_list, valid_paths = [], [], [], []
    missing_dem = 0
    unreadable = 0
    logging.info(f"开始加载数据: {csv_path}")

    for raw_path, raw_label in tqdm(zip(df['path'], df['label']), total=len(df)):
        img_path = raw_path.strip()
        label = int(raw_label)

        # Match the DEM entry by file name.
        file_name = os.path.basename(img_path)
        if file_name not in dem_by_name:
            missing_dem += 1
            continue
        dem_val = dem_by_name[file_name]

        # Read the image features; None means the file could not be read.
        img_features = read_tif_as_features(img_path)
        if img_features is None:
            unreadable += 1
            continue

        X_list.append(img_features)
        y_list.append(label)
        dem_list.append(dem_val)
        valid_paths.append(img_path)

    if missing_dem or unreadable:
        # Report dropped samples so a shrunken dataset does not go unnoticed.
        logging.warning(
            f"{csv_path}: 跳过 {missing_dem} 个无DEM记录的样本, "
            f"{unreadable} 个无法读取的影像"
        )

    return np.array(X_list), np.array(y_list), np.array(dem_list), valid_paths


if __name__ == "__main__":
    # Training entry point: load the DEM table and both label sets, tune a
    # random forest with a grid search, save the best model, then evaluate it
    # on the validation set after the altitude-based post-processing.

    # --- 1. Load the DEM table ---
    logging.info("加载 DEM 数据表...")
    if os.path.exists(DEM_CSV):
        dem_df = pd.read_csv(DEM_CSV)
        # Columns are renamed positionally, so the CSV is expected to have
        # exactly two columns (image name, DEM value).
        dem_df.columns = ['ImageName', 'DEM'] 
        # Non-numeric DEM entries become NaN and are then dropped.
        dem_df['DEM'] = pd.to_numeric(dem_df['DEM'], errors='coerce')
        dem_df = dem_df.dropna(subset=['DEM'])
    else:
        logging.error(f"DEM文件不存在: {DEM_CSV}")
        sys.exit(1)
    
    # --- 2. Prepare the data ---
    X_train, y_train, train_dems, _ = load_dataset(TRAIN_CSV, dem_df)
    X_val, y_val, val_dems, val_paths = load_dataset(VAL_CSV, dem_df)
    
    logging.info(f"训练集特征维度: {X_train.shape}, 验证集特征维度: {X_val.shape}")

    # NOTE(review): only the training set is checked for emptiness; an empty
    # validation set would reach predict_proba below - confirm this cannot happen.
    if len(X_train) > 0:
        # --- 3. Hyper-parameter tuning (GridSearchCV) ---
        logging.info("开始进行超参数网格搜索 (已启用防过拟合策略)...")
        
        # Depth and leaf/split sizes are bounded to limit over-fitting.
        param_grid = {
            'n_estimators': [100, 200, 300],
            'max_depth': [8, 10, 15],          
            'min_samples_split': [10, 20],     
            'min_samples_leaf': [4, 8, 10],    
            'max_features': ['sqrt', 'log2']   
        }
        
        # NOTE(review): both the forest and the grid search use n_jobs=-1;
        # confirm that this nesting does not oversubscribe the machine.
        rf = RandomForestClassifier(random_state=2024, class_weight='balanced', n_jobs=-1)
        
        grid_search = GridSearchCV(
            estimator=rf, 
            param_grid=param_grid, 
            cv=3, 
            scoring='accuracy', 
            verbose=1, 
            n_jobs=-1,
            # Train scores are needed for the train-vs-validation curve below.
            return_train_score=True 
        )
        
        grid_search.fit(X_train, y_train)
        
        logging.info(f"最佳参数: {grid_search.best_params_}")
        logging.info(f"最佳交叉验证分数: {grid_search.best_score_:.4f}")
        
        # --- 4. Save the detailed tuning results ---
        results_df = pd.DataFrame(grid_search.cv_results_)
        cols_to_keep = ['params', 'mean_test_score', 'std_test_score', 'mean_train_score', 'std_train_score']
        param_keys = [f'param_{key}' for key in param_grid.keys()]
        # Parameter columns first, then the score columns; missing columns are skipped.
        final_cols = [c for c in param_keys if c in results_df.columns] + [c for c in cols_to_keep if c in results_df.columns]
        results_df[final_cols].to_csv(os.path.join(OUTPUT_TRAIN_DIR, 'grid_search_detailed_metrics.csv'), index=False)
        
        # --- 5. Plot train vs. cross-validation accuracy ---
        plt.figure(figsize=(14, 7))
        # Parameter combinations are ordered by ascending mean CV accuracy.
        sorted_indices = np.argsort(results_df['mean_test_score'])
        sorted_test_scores = results_df['mean_test_score'].values[sorted_indices]
        sorted_train_scores = results_df['mean_train_score'].values[sorted_indices]
        
        x_axis = range(len(sorted_indices))
        plt.plot(x_axis, sorted_train_scores, marker='.', linestyle='-', color='orange', label='Train Accuracy', alpha=0.7)
        plt.plot(x_axis, sorted_test_scores, marker='o', linestyle='-', color='blue', label='CV Validation Accuracy')
        # Shaded band: one standard deviation of the CV score on either side.
        plt.fill_between(x_axis, 
                         sorted_test_scores - results_df['std_test_score'].values[sorted_indices],
                         sorted_test_scores + results_df['std_test_score'].values[sorted_indices], 
                         color='blue', alpha=0.1, label='Val Std Dev')

        plt.xlabel('Parameter Combination Index (Sorted by Val Accuracy)')
        plt.ylabel('Accuracy')
        plt.title('Hyperparameter Tuning: Train vs Validation Accuracy')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_TRAIN_DIR, 'tuning_curve_train_vs_val.png'), dpi=300)
        plt.close()
        
        # --- 6. Fetch and save the best model ---
        best_rf_model = grid_search.best_estimator_
        model_path = os.path.join(OUTPUT_TRAIN_DIR, 'best_rf_model.pkl')
        joblib.dump(best_rf_model, model_path)
        logging.info(f"最佳模型已保存至: {model_path}")

        # --- 7. Evaluate on the separate validation set ---
        logging.info("在验证集上进行评估...")
        
        val_pred_probs = best_rf_model.predict_proba(X_val)
        
        # When fewer than NUM_CLASSES columns are returned, scatter them into a
        # NUM_CLASSES-wide array using the model's class ids as column indices.
        # Assumes the class ids are integers below NUM_CLASSES - TODO confirm.
        if val_pred_probs.shape[1] < NUM_CLASSES:
            full_probs = np.zeros((val_pred_probs.shape[0], NUM_CLASSES))
            known_classes = best_rf_model.classes_
            full_probs[:, known_classes] = val_pred_probs
            val_pred_probs = full_probs
            
        # Only the fused labels are used; the fused probabilities are discarded.
        val_preds_final, _ = apply_kg_post_processing(val_pred_probs, val_dems)
        
        val_result_df = pd.DataFrame({
            'ImagePath': val_paths,
            'True_Label': y_val,
            'Predicted_Label': val_preds_final
        })
        val_result_df.to_csv(os.path.join(OUTPUT_TRAIN_DIR, 'best_model_prediction_details.csv'), index=False)
        
        # NOTE(review): no explicit labels= is passed, so the matrix presumably
        # covers only the classes present in y_val or the predictions - verify
        # that row/column positions are read as intended.
        cm = confusion_matrix(y_val, val_preds_final)
        cm_df = pd.DataFrame(cm)
        cm_df.to_csv(os.path.join(OUTPUT_TRAIN_DIR, 'best_model_confusion_matrix.csv'), index=False)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Validation Confusion Matrix (Best Model + KG)')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(os.path.join(OUTPUT_TRAIN_DIR, 'confusion_matrix.png'))
        plt.close()
        
        # Overall metrics (macro-averaged), written as a one-row CSV.
        metrics_dict = {
            'Best Params': str(grid_search.best_params_),
            'Accuracy': accuracy_score(y_val, val_preds_final),
            'Recall (Macro)': recall_score(y_val, val_preds_final, average='macro', zero_division=0),
            'Precision (Macro)': precision_score(y_val, val_preds_final, average='macro', zero_division=0),
            'F1 Score (Macro)': f1_score(y_val, val_preds_final, average='macro', zero_division=0)
        }
        pd.DataFrame([metrics_dict]).to_csv(os.path.join(OUTPUT_TRAIN_DIR, 'final_metrics_summary.csv'), index=False)
        
        # Per-class report: once as a table for the CSV, once as text for the log.
        class_report_dict = classification_report(y_val, val_preds_final, output_dict=True, zero_division=0)
        class_report_df = pd.DataFrame(class_report_dict).transpose()
        class_report_df.to_csv(os.path.join(OUTPUT_TRAIN_DIR, 'best_model_classification_report_per_class.csv'), index=True)
        
        logging.info("各类型详细指标 (F1, Precision, Recall) 已保存。")
        logging.info("\n" + classification_report(y_val, val_preds_final, zero_division=0))
        
        logging.info("训练结束。")
        
    else:
        # No usable training sample was loaded.
        logging.error("训练集为空，请检查CSV路径或图像路径是否正确。")

# In [None]:
# 模型预测

# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import pandas as pd
from osgeo import gdal
import joblib
import logging
from tqdm import tqdm

# Trained model, folder of images to classify, DEM table and output root.
MODEL_PATH = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Exp2_RF_KG_postpro\Model_train_result\best_rf_model.pkl"
PREDICT_IMAGE_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Predict_Data"
DEM_CSV = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Predict_Data\DEM.CSV"
OUTPUT_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Exp2_RF_KG_postpro\Prediction_Result"

# Sub-folders for the label rasters and the probability rasters.
TIF_LABEL_DIR = os.path.join(OUTPUT_DIR, "Labels_TIF")
TIF_PROB_DIR = os.path.join(OUTPUT_DIR, "Probabilities_TIF")

# os.makedirs also creates OUTPUT_DIR itself as a missing parent, which the
# log file handler below relies on. exist_ok=True makes the creation
# idempotent and avoids a separate existence check.
for d in [TIF_LABEL_DIR, TIF_PROB_DIR]:
    os.makedirs(d, exist_ok=True)

# LABELS_DICT maps a class id to its 'lower-upper' altitude range string,
# which get_altitude_probability() splits on '-' and parses as integers.
# NOTE(review): classes 6 and 7 share the same range - confirm this is intended.
LABELS_DICT = {
    0: '2000-2300', 1: '1000-1900', 2: '1500-2300', 3: '750-1300',
    4: '2650-3000', 5: '3000-3400', 6: '3400-3777', 7: '3400-3777'
}
NUM_CLASSES = 8
# Predictions whose highest model probability is below this threshold are
# fused with the altitude prior.
CONFIDENCE_THRESHOLD = 0.5
# Fusion weights for the model probabilities and the altitude prior.
W_MODEL = 0.6
W_ALTITUDE = 0.4
# Small constant added to the ramp denominators in get_altitude_probability().
EPSILON = 1e-5

# Logging configuration: every record goes both to prediction_log.log inside
# OUTPUT_DIR and to standard output.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers (for example if another cell configured logging earlier in the
# same session) - confirm that prediction_log.log is actually written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(OUTPUT_DIR, 'prediction_log.log'), encoding='utf-8'),
        logging.StreamHandler(sys.stdout)
    ]
)

def save_tif(data, reference_ds, output_path, dtype=gdal.GDT_Byte):
    """Write an array to a GeoTIFF that copies the georeferencing of a dataset.

    Args:
        data: 2-D array (written as one band) or 3-D array whose first axis
            is the band axis.
        reference_ds: Open GDAL dataset whose geotransform and projection are
            copied to the output file.
        output_path: Path of the GeoTIFF to create.
        dtype: GDAL data type of the output bands.
    """
    gtiff = gdal.GetDriverByName('GTiff')

    if data.ndim == 2:
        band_count = 1
        height, width = data.shape
    else:
        band_count, height, width = data.shape

    target = gtiff.Create(output_path, width, height, band_count, dtype)
    target.SetGeoTransform(reference_ds.GetGeoTransform())
    target.SetProjection(reference_ds.GetProjection())

    # With a single band the array is written as it is; otherwise each slice
    # along the first axis goes to its own (1-based) band.
    layers = [data] if band_count == 1 else data
    for band_no, layer in enumerate(layers, start=1):
        target.GetRasterBand(band_no).WriteArray(layer)

    target.FlushCache()
    # Dropping the reference lets GDAL close the file.
    target = None

def get_altitude_probability(altitude, class_id):
    """Return the altitude-based membership score of a class.

    The score is trapezoidal around the class range stored in LABELS_DICT
    ('lower-upper'): 1.0 inside the range, rising or falling linearly over a
    100-unit margin on either side, and 0.0 beyond the margins.

    Args:
        altitude: Altitude of the image; anything accepted by float().
        class_id: Class id used as the key into LABELS_DICT.

    Returns:
        A float between 0.0 and 1.0. 0.0 is returned when the altitude cannot
        be converted to a number or when class_id is not in LABELS_DICT.
    """
    try:
        # Values read from a CSV may be strings or missing; comparing those
        # with the integer bounds below would raise TypeError.
        altitude = float(altitude)
    except (TypeError, ValueError):
        return 0.0

    if class_id not in LABELS_DICT:
        return 0.0

    range_str = LABELS_DICT[class_id]
    lb, ub = map(int, range_str.split('-'))
    # Outer limits of the ramps on both sides of the range.
    a, b = lb - 100, ub + 100

    if altitude < a:
        return 0.0
    elif a <= altitude < lb:
        return (altitude - a) / (lb - a + EPSILON)
    elif lb <= altitude <= ub:
        return 1.0
    elif ub < altitude <= b:
        return (b - altitude) / (b - ub + EPSILON)
    else:
        # Above the upper ramp, or NaN (every comparison above is False).
        return 0.0

def apply_kg_post_processing(pred_probs, altitude):
    """Fuse a low-confidence prediction with the altitude prior.

    When the highest value in pred_probs is below CONFIDENCE_THRESHOLD, the
    first row is replaced by W_MODEL * model + W_ALTITUDE * altitude prior,
    renormalised to sum to one when the sum is positive. Otherwise the model
    probabilities are kept.

    Args:
        pred_probs: 2-D probability array; only its first row is fused and
            returned.
        altitude: Altitude passed to get_altitude_probability().

    Returns:
        A tuple (final_label, probabilities): the argmax index of the first
        row and that row of the possibly fused array.
    """
    top_confidence = np.max(pred_probs)
    adjusted = pred_probs.copy()

    if top_confidence < CONFIDENCE_THRESHOLD:
        prior = np.zeros(NUM_CLASSES)
        for class_id in range(NUM_CLASSES):
            prior[class_id] = get_altitude_probability(altitude, class_id)

        blended = W_MODEL * pred_probs[0] + W_ALTITUDE * prior
        total = np.sum(blended)
        if total > 0:
            blended = blended / total
        adjusted[0] = blended

    winners = np.argmax(adjusted, axis=1)
    return winners[0], adjusted[0]

def _band_statistics(data):
    """Return a (1, 6 * n_bands) array of per-band mean, std, max, min, p25, p75."""
    if data.ndim == 2:
        # Single band: add a leading axis so the loop below sees one band.
        data = data[np.newaxis, :, :]
    stats_features = []
    for band in data:
        pixels = band.flatten()
        stats_features.extend([
            np.nanmean(pixels), np.nanstd(pixels),
            np.nanmax(pixels), np.nanmin(pixels),
            np.nanpercentile(pixels, 25), np.nanpercentile(pixels, 75),
        ])
    return np.array(stats_features).reshape(1, -1)


def run_prediction():
    """Classify every .tif in PREDICT_IMAGE_DIR and write the results.

    For each image that has a DEM entry, the per-band statistics are fed to
    the saved model, the probabilities are post-processed with the altitude
    prior, and two rasters are written: a constant label raster and a
    NUM_CLASSES-band constant probability raster, both with the size and
    georeferencing of the input image. A summary CSV with the raw model
    prediction, the final label and the altitude is written to OUTPUT_DIR.

    Returns None. The function returns early, after logging an error, when
    the model file does not exist.
    """
    logging.info("加载模型与数据...")
    if not os.path.exists(MODEL_PATH):
        logging.error(f"模型不存在: {MODEL_PATH}")
        return
    # NOTE: joblib.load unpickles arbitrary objects; only load model files
    # that come from a trusted source.
    model = joblib.load(MODEL_PATH)

    dem_df = pd.read_csv(DEM_CSV)
    dem_df.columns = ['ImageName', 'DEM']
    dem_df['ImageName'] = dem_df['ImageName'].apply(lambda x: os.path.basename(x))
    # Non-numeric DEM entries become NaN instead of reaching the altitude
    # comparison as strings; a NaN altitude yields an all-zero prior, so the
    # image is still classified from the model probabilities alone.
    dem_df['DEM'] = pd.to_numeric(dem_df['DEM'], errors='coerce')
    dem_map = dict(zip(dem_df['ImageName'], dem_df['DEM']))

    all_files = [f for f in os.listdir(PREDICT_IMAGE_DIR) if f.lower().endswith('.tif')]
    logging.info(f"开始预测 {len(all_files)} 个文件...")

    results_csv = []

    for file_name in tqdm(all_files):
        img_path = os.path.join(PREDICT_IMAGE_DIR, file_name)
        if file_name not in dem_map:
            logging.warning(f"跳过(DEM表中无记录): {file_name}")
            continue

        ds = gdal.Open(img_path)
        if ds is None:
            logging.warning(f"跳过(无法打开文件): {img_path}")
            continue
        rows, cols = ds.RasterYSize, ds.RasterXSize

        features = _band_statistics(ds.ReadAsArray())
        pred_probs = model.predict_proba(features)

        # Take the raw model prediction before any padding: the columns of
        # predict_proba line up with model.classes_ only in this form.
        original_label = model.classes_[np.argmax(pred_probs[0])]

        # When fewer than NUM_CLASSES columns are returned, scatter them into
        # a NUM_CLASSES-wide row using the model's class ids as column indices.
        if pred_probs.shape[1] < NUM_CLASSES:
            full_probs = np.zeros((1, NUM_CLASSES))
            full_probs[0, model.classes_] = pred_probs[0]
            pred_probs = full_probs

        altitude = dem_map[file_name]
        final_label, final_probs = apply_kg_post_processing(pred_probs, altitude)

        # Label raster: every pixel carries the image-level label.
        label_mask = np.full((rows, cols), final_label, dtype=np.uint8)
        label_tif_path = os.path.join(TIF_LABEL_DIR, f"Label_{file_name}")
        save_tif(label_mask, ds, label_tif_path, gdal.GDT_Byte)

        # Probability raster: band c is filled with the probability of class c.
        prob_mask = np.zeros((NUM_CLASSES, rows, cols), dtype=np.float32)
        for c in range(NUM_CLASSES):
            prob_mask[c, :, :] = final_probs[c]
        prob_tif_path = os.path.join(TIF_PROB_DIR, f"Prob_{file_name}")
        save_tif(prob_mask, ds, prob_tif_path, gdal.GDT_Float32)

        results_csv.append({
            'ImageName': file_name,
            'Original_Predict': original_label,
            'Final_Label_KG': final_label,
            'Altitude': altitude
        })

        # Dropping the reference lets GDAL close the input file.
        ds = None

    pd.DataFrame(results_csv).to_csv(os.path.join(OUTPUT_DIR, 'summary_results.csv'), index=False)
    logging.info(f"所有预测已完成。TIF 结果保存在 {OUTPUT_DIR} 的子文件夹中。")

if __name__ == "__main__":
    run_prediction()