In [None]:
# 模型训练

# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import pandas as pd
from osgeo import gdal
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score, classification_report
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import logging

# Path configuration
BASE_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Train_Samples"
TRAIN_CSV = os.path.join(BASE_DIR, "label", "label_train.CSV")
VAL_CSV = os.path.join(BASE_DIR, "label", "label_val.CSV")
OUTPUT_TRAIN_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Exp2_RF\Model_train_result"

# The output folder must exist before the file log handler below opens its file.
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(OUTPUT_TRAIN_DIR, exist_ok=True)

# Log to both a file in the output folder and to stdout.
# force=True: basicConfig does nothing when the root logger already has
# handlers (e.g. this cell is re-run, or another cell configured logging
# first in the same kernel), which would silently send the training log to
# the wrong destination.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(OUTPUT_TRAIN_DIR, 'train_log.log'), encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
    force=True,
)

def read_tif_as_features(tif_path):
    """Read a raster file and summarise each band with six statistics.

    For every band the NaN-ignoring mean, standard deviation, maximum,
    minimum, 25th percentile and 75th percentile are computed (in that
    order) and concatenated band after band.

    Args:
        tif_path: path of the raster file to open with GDAL.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``6 * number_of_bands``, or
        ``None`` when the file cannot be opened or read, or when any error
        occurs while processing it (the error is logged).
    """
    ds = None
    try:
        ds = gdal.Open(tif_path)
        if ds is None:
            logging.error(f"无法打开文件: {tif_path}")
            return None

        data = ds.ReadAsArray()
        if data is None:
            # Report a failed pixel read explicitly instead of letting it
            # surface as an opaque "NoneType has no attribute ndim" error.
            logging.error(f"无法读取像素数据: {tif_path}")
            return None

        # A 2-D result means a single band; add a leading band axis so the
        # loop below handles both cases the same way.
        if data.ndim == 2:
            data = data[np.newaxis, :, :]

        stats_features = []
        for band in data:
            pixels = band.ravel()
            # One call for both quartiles instead of two separate passes.
            f_p25, f_p75 = np.nanpercentile(pixels, [25, 75])
            stats_features.extend([
                np.nanmean(pixels),
                np.nanstd(pixels),
                np.nanmax(pixels),
                np.nanmin(pixels),
                f_p25,
                f_p75,
            ])

        return np.array(stats_features)
    except Exception as e:
        # Best-effort loader: a broken file is logged and skipped by the caller.
        logging.error(f"读取TIF出错 {tif_path}: {e}")
        return None
    finally:
        # Drop the reference so the dataset handle is released on every path.
        ds = None

def load_dataset(csv_path):
    """Build a feature matrix and label vector from a two-column CSV.

    The CSV is read without a header row; the first column is used as the
    image path and the second as the class label. Each image is converted to
    a feature vector with ``read_tif_as_features``.

    Rows are skipped (with a warning) when the path cell is not a string or
    the label cannot be converted to ``int``. Images that cannot be read are
    skipped as well. If the remaining feature vectors do not all have the
    same length, only those with the most common length are kept so that the
    result is a regular 2-D matrix.

    Args:
        csv_path: path of the CSV file listing ``path,label`` rows.

    Returns:
        A tuple ``(X, y, paths)``: ``X`` is the stacked feature array, ``y``
        the integer label array and ``paths`` the list of image paths, all
        in the same row order and restricted to the rows that were kept.
    """
    df = pd.read_csv(csv_path, header=None, names=['path', 'label'])
    X_list, y_list, valid_paths = [], [], []

    logging.info(f"开始加载数据: {csv_path}")

    for row in tqdm(df.itertuples(index=False), total=len(df)):
        # A blank cell or a stray header line must not abort the whole load.
        if not isinstance(row.path, str):
            logging.warning(f"跳过无效行（路径缺失）: {row}")
            continue
        img_path = row.path.strip()

        try:
            label = int(row.label)
        except (TypeError, ValueError):
            logging.warning(f"跳过无效行（标签无法解析）: {img_path}")
            continue

        img_features = read_tif_as_features(img_path)
        if img_features is None:
            continue

        X_list.append(img_features)
        y_list.append(label)
        valid_paths.append(img_path)

    if X_list:
        # Images with a different band count yield vectors of another length;
        # stacking those would produce a ragged array that cannot be fitted.
        dims = np.array([feat.shape[0] for feat in X_list])
        uniq_dims, counts = np.unique(dims, return_counts=True)
        if uniq_dims.size > 1:
            majority_dim = uniq_dims[np.argmax(counts)]
            keep = dims == majority_dim
            logging.warning(
                f"特征维度不一致 {dict(zip(uniq_dims.tolist(), counts.tolist()))}，"
                f"仅保留维度为 {majority_dim} 的样本。"
            )
            X_list = [feat for feat, ok in zip(X_list, keep) if ok]
            y_list = [lab for lab, ok in zip(y_list, keep) if ok]
            valid_paths = [p for p, ok in zip(valid_paths, keep) if ok]

    return np.array(X_list), np.array(y_list), valid_paths

if __name__ == "__main__":
    # Turn every image listed in the two CSV files into one feature vector.
    X_train, y_train, _ = load_dataset(TRAIN_CSV)
    X_val, y_val, val_paths = load_dataset(VAL_CSV)

    logging.info(f"训练集特征维度: {X_train.shape}, 验证集特征维度: {X_val.shape}")

    if len(X_train) > 0:
        logging.info("开始进行超参数网格搜索...")

        # 3 * 3 * 2 * 3 * 2 = 108 parameter combinations.
        param_grid = {
            'n_estimators': [100, 200, 300],
            'max_depth': [8, 10, 15],
            'min_samples_split': [10, 20],
            'min_samples_leaf': [4, 8, 10],
            'max_features': ['sqrt', 'log2'],
        }

        rf = RandomForestClassifier(random_state=2024, class_weight='balanced', n_jobs=-1)

        grid_search = GridSearchCV(
            estimator=rf,
            param_grid=param_grid,
            cv=3,
            scoring='accuracy',
            verbose=1,
            n_jobs=-1,
            return_train_score=True,  # needed for the train-vs-CV curve below
        )

        grid_search.fit(X_train, y_train)
        logging.info(f"最佳参数: {grid_search.best_params_}")

        # Keep the full search table for later inspection.
        results_df = pd.DataFrame(grid_search.cv_results_)
        results_df.to_csv(os.path.join(OUTPUT_TRAIN_DIR, 'grid_search_detailed_metrics.csv'), index=False)

        # Plot train vs. cross-validation accuracy, candidates ordered by
        # increasing CV accuracy.
        order = np.argsort(results_df['mean_test_score'].to_numpy())
        x_axis = np.arange(order.size)
        plt.figure(figsize=(14, 7))
        plt.plot(x_axis, results_df['mean_train_score'].to_numpy()[order], label='Train Accuracy')
        plt.plot(x_axis, results_df['mean_test_score'].to_numpy()[order], label='CV Validation Accuracy')
        plt.xlabel('Parameter combination (sorted by CV accuracy)')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.savefig(os.path.join(OUTPUT_TRAIN_DIR, 'tuning_curve.png'))
        plt.close()

        best_rf_model = grid_search.best_estimator_
        joblib.dump(best_rf_model, os.path.join(OUTPUT_TRAIN_DIR, 'best_rf_model.pkl'))
        logging.info("模型已保存。")

        if len(X_val) == 0:
            # Predicting on an empty array would raise; the model is already
            # saved, so report the problem and skip the evaluation step.
            logging.error("验证集为空，跳过验证评估。")
        else:
            logging.info("在验证集上进行评估...")
            val_preds_final = best_rf_model.predict(X_val)

            # Per-image predictions.
            val_result_df = pd.DataFrame({
                'ImagePath': val_paths,
                'True_Label': y_val,
                'Predicted_Label': val_preds_final,
            })
            val_result_df.to_csv(os.path.join(OUTPUT_TRAIN_DIR, 'best_model_prediction_details.csv'), index=False)

            # Use the sorted union of true and predicted labels both as the
            # matrix order and as tick labels, so the heatmap shows real
            # class ids instead of positional indices.
            cm_labels = np.unique(np.concatenate([y_val, val_preds_final]))
            cm = confusion_matrix(y_val, val_preds_final, labels=cm_labels)
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=cm_labels, yticklabels=cm_labels)
            plt.xlabel('Predicted Label')
            plt.ylabel('True Label')
            plt.savefig(os.path.join(OUTPUT_TRAIN_DIR, 'confusion_matrix.png'))
            plt.close()

            # Summary metrics: previously computed and then discarded; now
            # they are logged and written next to the other outputs.
            metrics_dict = {
                'Accuracy': accuracy_score(y_val, val_preds_final),
                'Precision (Macro)': precision_score(y_val, val_preds_final, average='macro', zero_division=0),
                'Recall (Macro)': recall_score(y_val, val_preds_final, average='macro', zero_division=0),
                'F1 Score (Macro)': f1_score(y_val, val_preds_final, average='macro', zero_division=0),
            }
            pd.DataFrame([metrics_dict]).to_csv(
                os.path.join(OUTPUT_TRAIN_DIR, 'validation_metrics.csv'), index=False
            )
            for metric_name, metric_value in metrics_dict.items():
                logging.info(f"{metric_name}: {metric_value:.4f}")

            logging.info("\n" + classification_report(y_val, val_preds_final, zero_division=0))

        logging.info("训练结束。")
    else:
        logging.error("训练集为空，请检查CSV路径。")

In [None]:
# 模型预测

# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
import pandas as pd
from osgeo import gdal
import joblib
from tqdm import tqdm
import logging

# Path configuration
MODEL_PATH = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Exp2_RF\Model_train_result\best_rf_model.pkl"
PREDICT_PATCH_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Predict_Data"
OUTPUT_DIR = r"D:\Desktop\ResNet50_Paper\Review_comments\Exp\Exp2\Exp2_RF\Prediction_Result"

TIF_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "TIF_Results")
# makedirs also creates OUTPUT_DIR as a parent, which must exist before the
# file log handler below opens its file. exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
os.makedirs(TIF_OUTPUT_DIR, exist_ok=True)

# Log to both a file in the output folder and to stdout.
# force=True: basicConfig does nothing when the root logger already has
# handlers. If the training cell ran first in the same kernel, this call
# would otherwise be ignored and prediction_log.log would never be written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(OUTPUT_DIR, 'prediction_log.log'), encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
    force=True,
)

def read_tif_as_features(tif_path):
    """Extract per-band statistics from a raster file and return its metadata.

    For every band the NaN-ignoring mean, standard deviation, maximum,
    minimum, 25th percentile and 75th percentile are computed (in that
    order) and concatenated band after band.

    Args:
        tif_path: path of the raster file to open with GDAL.

    Returns:
        A tuple ``(features, meta)``. ``features`` is a 1-D
        ``numpy.ndarray`` of length ``6 * number_of_bands``; ``meta`` is a
        dict with the keys ``'geo_transform'``, ``'projection'``,
        ``'width'`` and ``'height'`` read from the dataset. ``(None, None)``
        is returned when the file cannot be opened or read, or when any
        error occurs while processing it (the error is logged).
    """
    ds = None
    try:
        ds = gdal.Open(tif_path)
        if ds is None:
            # Log the skip so unreadable patches are visible in the run log.
            logging.error(f"无法打开文件: {tif_path}")
            return None, None

        meta = {
            'geo_transform': ds.GetGeoTransform(),
            'projection': ds.GetProjection(),
            'width': ds.RasterXSize,
            'height': ds.RasterYSize,
        }

        data = ds.ReadAsArray()
        if data is None:
            # Report a failed pixel read explicitly instead of letting it
            # surface as an opaque "NoneType has no attribute ndim" error.
            logging.error(f"无法读取像素数据: {tif_path}")
            return None, None

        # A 2-D result means a single band; add a leading band axis so the
        # loop below handles both cases the same way.
        if data.ndim == 2:
            data = data[np.newaxis, :, :]

        stats_features = []
        for band in data:
            pixels = band.ravel()
            # One call for both quartiles instead of two separate passes.
            f_p25, f_p75 = np.nanpercentile(pixels, [25, 75])
            stats_features.extend([
                np.nanmean(pixels),
                np.nanstd(pixels),
                np.nanmax(pixels),
                np.nanmin(pixels),
                f_p25,
                f_p75,
            ])

        return np.array(stats_features), meta
    except Exception as e:
        # Best-effort reader: a broken file is logged and skipped by the caller.
        logging.error(f"处理文件 {tif_path} 时出错: {e}")
        return None, None
    finally:
        # Drop the reference so the dataset handle is released on every path.
        ds = None

def save_result_to_tif(output_path, data, meta, band_names=None):
    """Write an array to a georeferenced Float32 GeoTIFF.

    Args:
        output_path: destination file path.
        data: a 3-D array, written as one band per entry along the first
            axis, or any other array, written as a single band.
        meta: dict providing ``'width'``, ``'height'``, ``'geo_transform'``
            and ``'projection'`` for the output file.
        band_names: optional sequence of band descriptions; entry ``i`` is
            attached to band ``i + 1``. Bands without a matching entry get
            no description.

    Raises:
        RuntimeError: if GDAL fails to create the output file.
    """
    layers = data if data.ndim == 3 else [data]
    band_count = len(layers)

    driver = gdal.GetDriverByName('GTiff')
    ds = driver.Create(output_path, meta['width'], meta['height'], band_count, gdal.GDT_Float32)
    if ds is None:
        # Without this check a failed Create surfaces as an AttributeError
        # on None further down, hiding the real cause.
        raise RuntimeError(f"无法创建输出文件: {output_path}")

    try:
        ds.SetGeoTransform(meta['geo_transform'])
        ds.SetProjection(meta['projection'])

        for index, layer in enumerate(layers):
            band = ds.GetRasterBand(index + 1)  # GDAL band numbers start at 1
            band.WriteArray(layer)
            if band_names and index < len(band_names):
                band.SetDescription(str(band_names[index]))

        ds.FlushCache()
    finally:
        # Drop the reference so the file handle is released even when a
        # write fails part-way.
        ds = None

def main():
    """Classify every patch in PREDICT_PATCH_DIR and export the results.

    For each ``.tif`` file (extension matched case-insensitively) the
    statistical features are extracted, the saved model predicts a label and
    class probabilities, and two rasters are written to TIF_OUTPUT_DIR:
    ``<stem>_pred.tif`` filled with the predicted label and
    ``<stem>_probas.tif`` with one constant band per class. A CSV with one
    row per processed patch is written to OUTPUT_DIR.

    Patches that cannot be read, or whose feature count does not match the
    model, are logged and skipped.
    """
    # 1. Load the model
    if not os.path.exists(MODEL_PATH):
        logging.error(f"找不到模型文件: {MODEL_PATH}")
        return

    logging.info("加载训练好的模型...")
    # NOTE(review): joblib.load unpickles arbitrary objects; only point
    # MODEL_PATH at model files from a trusted source.
    model = joblib.load(MODEL_PATH)
    class_names = list(model.classes_)
    logging.info(f"模型识别的类别: {class_names}")
    # Not every estimator/version exposes n_features_in_; skip the check then.
    expected_dim = getattr(model, 'n_features_in_', None)

    # 2. Enumerate the image patches
    if not os.path.isdir(PREDICT_PATCH_DIR):
        logging.error(f"找不到待预测影像目录: {PREDICT_PATCH_DIR}")
        return

    # Sorted so the CSV row order is reproducible between runs.
    tif_files = sorted(
        f for f in os.listdir(PREDICT_PATCH_DIR) if f.lower().endswith('.tif')
    )
    logging.info(f"发现 {len(tif_files)} 个待预测影像块。")

    results_csv = []
    skipped = 0

    # 3. Run inference and save the rasters
    for filename in tqdm(tif_files, desc="Processing"):
        tif_path = os.path.join(PREDICT_PATCH_DIR, filename)
        features, meta = read_tif_as_features(tif_path)

        if features is None:
            skipped += 1
            continue

        if expected_dim is not None and features.size != expected_dim:
            # A patch with another band count would make predict() raise and
            # abort the run before the CSV is written.
            logging.error(
                f"特征维度不匹配，跳过 {filename}: 期望 {expected_dim}，实际 {features.size}"
            )
            skipped += 1
            continue

        features_reshaped = features.reshape(1, -1)
        pred_label = model.predict(features_reshaped)[0]
        pred_probas = model.predict_proba(features_reshaped)[0]

        # Build output names from the stem. The file filter above is
        # case-insensitive, so a case-sensitive str.replace('.tif', ...)
        # would leave e.g. 'X.TIF' unchanged and make both outputs share one
        # path (the second overwriting the first).
        stem = os.path.splitext(filename)[0]
        height, width = meta['height'], meta['width']

        class_map = np.full((height, width), pred_label, dtype=np.float32)
        label_tif_path = os.path.join(TIF_OUTPUT_DIR, stem + '_pred.tif')
        save_result_to_tif(label_tif_path, class_map, meta)

        # One constant band per class, each filled with that class's probability.
        prob_map = np.empty((len(class_names), height, width), dtype=np.float32)
        prob_map[...] = pred_probas[:, np.newaxis, np.newaxis]
        prob_tif_path = os.path.join(TIF_OUTPUT_DIR, stem + '_probas.tif')
        save_result_to_tif(prob_tif_path, prob_map, meta, band_names=class_names)

        res_dict = {
            'FileName': filename,
            'Predicted_Label': pred_label,
            'Max_Confidence': np.max(pred_probas),
        }
        for c_name, prob in zip(class_names, pred_probas):
            res_dict[f'Prob_Class_{c_name}'] = prob

        results_csv.append(res_dict)

    result_df = pd.DataFrame(results_csv)
    csv_output = os.path.join(OUTPUT_DIR, 'research_area_detailed_results.csv')
    result_df.to_csv(csv_output, index=False)

    if skipped:
        logging.warning(f"共跳过 {skipped} 个影像块。")
    logging.info("处理完成！")
    logging.info(f"CSV 结果: {csv_output}")
    logging.info(f"TIF 结果保存在: {TIF_OUTPUT_DIR}")


if __name__ == "__main__":
    main()