# 不同极化、波段模型下的实测数据验证模块(均为Insitu数据)

# 1.SMAPVEX16 Map（Iowa）（数据预处理部分参考matlab的A3_process_SMAPVEX16.m）

matlab中操作步骤：
读取E:\data\VWC\test-VWC\SMAPVEX16_Iowa\mat中的所有mat文件，即SMAPVEX16数据集内容（不要读取文件夹路径中子文件夹的mat），命名规则为YYYYMMDD.mat，YYYYMMDD为年月日字符串。需要进行填充的变量大小均为1800*3600（按照0.1°格网来的，可以先生成中心经纬度格网数据，数据在matlab存储的为北纬89.95至南纬89.95），根据SMAPVEX16数据集的mat文件的LAT、LON变量得知每个元素对应的经纬度后，提取四个邻近位置的0.1°格网的数据点，根据经纬度的距离差进行双线性插值操作。现需要对原先的数据填充以下变量：

①PFT：14个类别，读取文件名为E:\data\ESACCI PFT\Resample\Data\YYYY.mat，YYYY就是SMAPVEX对应的年份（这里由于SMAPVEX数据都是2016年的，故你直接读取2016年即可），读取变量：'water','bare','snowice','built','grassnat','grassman','shrubbd','shrubbe','shrubnd','shrubne','treebd','treebe','treend','treene'，先将每个变量每个元素除以100转为比例值，每个变量均按照前面所说的方法进行插值。


②VOD：6个类别，读取文件名为E:\data\VOD\mat\kuxcVOD\ASC\MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_YYYYMMDD_V0.nc4.mat的文件，读取变量'VOD_Ku_Hpol_Asc', 'VOD_X_Hpol_Asc', 'VOD_C_Hpol_Asc','VOD_Ku_Vpol_Asc', 'VOD_X_Vpol_Asc', 'VOD_C_Vpol_Asc'，并按照前面所说方法对每个变量进行插值。


③LAI：数据为月中数据，涉及逐日插值，所以先读取固定的数据。因为数据跨度为2016年5月28日（使用的插值数据为2016年5月）至2016年8月16日（使用的插值数据为2016年9月的），先读取这5个数据（E:\data\GLASS LAI\mat\0.1Deg\Dataset\YYYY-MM-01.tif.mat,YYYY和MM替换成我所说的月和日即可），可以使用这些数据插值出2016年1月16日至2016年9月14日之间的每一天的情况（月中数据不需要插值）。然后根据SMAPVEX数据的内容，读取对应的插值日期数据后，根据经纬度位置再进行插值。


④（SM暂时使用数据中已有）Hveg数据，'E:\data\CanopyHeight\CH.mat';根据经纬度位置完成插值处理即可。

最后，读取这些插值的数据，分别生成6个掩膜数据：Ku_H_mask、Ku_V_mask、X_H_mask、X_V_mask、C_H_mask、C_V_mask，每个掩膜都读取变量VSM（来自SMAPVEX的mat文件）、前面处理的PFT14个变量和对应的VOD（例如Ku_H_mask就是使用），然后VOD和VSM位置有其一NaN或者PFT的变量'water','bare','snowice','built'加起来的值大于等于0.05，那么这个元素的位置就应该被mask掉，无法用于后续验证。

将上述新变量以及掩膜变量覆盖保存至SMAPVEX16的mat文件中




In [18]:
import os
import numpy as np
import pandas as pd
import joblib
import glob
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import sys
from scipy.io import loadmat

# 设置路径
mat_folder = r"E:\data\VWC\test-VWC\SMAPVEX16_Iowa\mat"

# 获取当前工作目录
script_path = os.getcwd()    
models_dir = os.path.join(script_path, 'models')

# 定义需要读取的基本变量
base_variables = [
    'grassnat', 'grassman', 'shrubbd', 'shrubbe', 'shrubnd', 'shrubne',
    'treebd', 'treebe', 'treend', 'treene', 'LAI', 'VSM', 'VWC'
]

# 查找所有匹配的模型文件
model_files = glob.glob(os.path.join(models_dir, 'RFR_*.pkl'))
models = []

print("发现以下模型文件:")
for model_file in model_files:
    filename = os.path.basename(model_file)
    # 解析模型名称
    parts = filename[:-4].split('_')  # 移除.pkl
    if len(parts) < 4 or parts[0] != "RFR":
        continue
        
    band = parts[1]
    pol = parts[2].replace("pol", "")  # 移除"pol"
    model_type = int(parts[3].replace("Type", ""))
    
    # 只保留类型为1的模型
    if model_type != 1:
        print(f"跳过模型 {filename} (类型 {model_type} != 1)")
        continue
    
    # 添加该模型需要的VOD变量
    vod_var = f"{band.lower()}_vod_{pol}"
    mask_var = f"{band}_{pol}_mask"
    
    # 加载模型
    try:
        model = joblib.load(model_file)
    except Exception as e:
        print(f"加载模型 {filename} 时出错: {str(e)}")
        continue
    
    models.append({
        'name': filename[:-4],
        'band': band,
        'pol': pol,
        'type': model_type,
        'vod_var': vod_var,
        'mask_var': mask_var,
        'model': model,
        'predicted': [],
        'actual': []
    })
    print(f"- {filename} (波段: {band}, 极化: {pol}, 类型: {model_type})")

if not models:
    print("未找到任何类型为1的模型文件!")
    exit()

# 遍历MAT文件夹
mat_files = [f for f in os.listdir(mat_folder) 
            if f.endswith('.mat') and os.path.isfile(os.path.join(mat_folder, f))]

print(f"\n找到 {len(mat_files)} 个MAT文件, 开始处理...")

for mat_file in tqdm(mat_files, desc="处理MAT文件"):
    file_path = os.path.join(mat_folder, mat_file)
    
    # 检查文件是否存在并可读
    if not os.path.exists(file_path):
        print(f"文件 {mat_file} 不存在")
        continue
        
    if not os.access(file_path, os.R_OK):
        print(f"没有权限读取文件 {mat_file}")
        continue
    
    try:
        # 使用scipy加载MAT文件
        mat_data = loadmat(file_path, simplify_cells=True)
        
        # 创建数据字典，方便访问
        data_dict = {}
        for var in base_variables:
            if var in mat_data:
                data_dict[var] = mat_data[var]
            else:
                print(f"变量 {var} 在文件 {mat_file} 中缺失")
        
        # 添加模型需要的变量
        for model in models:
            vod_var = model['vod_var']
            mask_var = model['mask_var']
            
            if vod_var not in data_dict and vod_var in mat_data:
                data_dict[vod_var] = mat_data[vod_var]
            if mask_var not in data_dict and mask_var in mat_data:
                data_dict[mask_var] = mat_data[mask_var]
        
        # 处理每个模型
        for model in models:
            vod_var = model['vod_var']
            mask_var = model['mask_var']
            
            # 检查所需变量是否都存在
            required_vars = [vod_var, mask_var, 'VSM', 'LAI'] + [v for v in base_variables if v != 'VWC']
            missing_vars = [v for v in required_vars if v not in data_dict]
            
            if missing_vars:
                print(f"模型 {model['name']} 在文件 {mat_file} 中缺少变量: {', '.join(missing_vars)}")
                continue
            
            # 检查数组形状
            base_shape = data_dict['VWC'].shape
            for var in required_vars:
                if data_dict[var].shape != base_shape:
                    print(f"警告: 变量 {var} 形状不一致: 期望 {base_shape}, 实际 {data_dict[var].shape}")
            
            # 应用掩膜过滤无效数据
            mask = data_dict[mask_var].astype(bool)
            mask_flat = mask.flatten()
            
            # 获取实际VWC值
            actual_vwc = data_dict['VWC'].flatten()[mask_flat]
            
            # 准备模型输入数据
            lai_data = data_dict['LAI'].flatten()
            lai_data = lai_data[:len(mask_flat)][mask_flat]
            lai_data = np.clip(lai_data, 0, 6) / 6  # 归一化LAI
            
            vod_data = data_dict[vod_var].flatten()
            vod_data = vod_data[:len(mask_flat)][mask_flat]
            vod_data = np.clip(vod_data, 0, 2) / 2  # 归一化VOD
            
            # 获取PFT变量
            pft_vars = {
                'grassman': data_dict['grassman'].flatten(),
                'grassnat': data_dict['grassnat'].flatten(),
                'shrubbd': data_dict['shrubbd'].flatten(),
                'shrubbe': data_dict['shrubbe'].flatten(),
                'shrubnd': data_dict['shrubnd'].flatten(),
                'shrubne': data_dict['shrubne'].flatten(),
                'treebd': data_dict['treebd'].flatten(),
                'treebe': data_dict['treebe'].flatten(),
                'treend': data_dict['treend'].flatten(),
                'treene': data_dict['treene'].flatten()
            }
            
            # 截断到掩膜长度
            for key in pft_vars:
                pft_vars[key] = pft_vars[key][:len(mask_flat)][mask_flat]
            
            # 创建输入数据框
            input_df = pd.DataFrame()
            input_df['VOD'] = vod_data
            input_df['LAI'] = lai_data
            input_df['SM'] = data_dict['VSM'].flatten()[:len(mask_flat)][mask_flat]
            
            # 根据模型类型合成特征 (只处理类型1)
            if model['type'] == 1:
                input_df['Grass_man'] = pft_vars['grassman']
                input_df['Grass_nat'] = pft_vars['grassnat']
                input_df['Shrub_bd'] = pft_vars['shrubbd']
                input_df['Shrub_be'] = pft_vars['shrubbe']
                input_df['Shrub_nd'] = pft_vars['shrubnd']
                input_df['Shrub_ne'] = pft_vars['shrubne']
                input_df['Tree_bd'] = pft_vars['treebd']
                input_df['Tree_be'] = pft_vars['treebe']
                input_df['Tree_nd'] = pft_vars['treend']
                input_df['Tree_ne'] = pft_vars['treene']
                
            # 使用模型预测
            predicted_vwc = model['model'].predict(input_df)
            
            # 收集结果
            model['predicted'].extend(predicted_vwc)
            model['actual'].extend(actual_vwc)
            
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1] if exc_tb else "unknown"
        lineno = exc_tb.tb_lineno if exc_tb else 0
        print(f"\n处理文件 {mat_file} 时出错: {str(e)} 在文件 {fname} 的第 {lineno} 行")

# 保存结果并绘图
output_dir = "smapvex16Iowa_prediction_results"
os.makedirs(output_dir, exist_ok=True)

print("\n处理完成, 开始保存结果和绘图...")

# 配置Matplotlib使用无中文字体的后端
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Tahoma'],
    'axes.unicode_minus': False
})

for model in tqdm(models, desc="生成结果"):
    if not model['predicted']:
        print(f"模型 {model['name']} 没有预测数据，跳过")
        continue
        
    # 转换为NumPy数组
    predicted = np.array(model['predicted'])
    actual = np.array(model['actual'])
    
    # 移除NaN值
    valid_mask = ~np.isnan(actual) & ~np.isnan(predicted)
    if valid_mask.sum() == 0:
        print(f"模型 {model['name']} 没有有效数据")
        continue
        
    actual = actual[valid_mask]
    predicted = predicted[valid_mask]
    
    # 计算样本点数量
    n_points = len(actual)
    print(f"模型 {model['name']} 有效点数量: {n_points}")
    
    # 如果结果太多，随机采样一部分用于绘图
    max_points = 100000
    if n_points > max_points:
        indices = np.random.choice(n_points, max_points, replace=False)
        sampled_actual = actual[indices]
        sampled_predicted = predicted[indices]
    else:
        sampled_actual = actual
        sampled_predicted = predicted
    
    # 保存所有预测结果和实际值到CSV
    result_df = pd.DataFrame({
        'Actual': actual,
        'Predicted': predicted
    })
    result_df.to_csv(os.path.join(output_dir, f"{model['name']}_results.csv"), index=False)
    
    # 计算评估指标
    rmse = np.sqrt(np.mean((actual - predicted) ** 2))
    bias = np.mean(predicted - actual)
    
    print(f"模型 {model['name']} 评估指标:")
    print(f"  RMSE = {rmse:.4f}")
    print(f"  Bias = {bias:.4f}")
    
    try:
        # 绘制散点图 - PNG版本
        plt.figure(figsize=(10, 8))
        hb = plt.hexbin(
            sampled_actual, 
            sampled_predicted, 
            gridsize=100, 
            cmap='viridis', 
            mincnt=1,
            bins='log'  # 使用对数尺度增强可视化
        )
        cb = plt.colorbar(hb, label='Point Density')  # 修改为英文以避免字体问题
        
        # 添加参考线 y=x
        max_val = max(sampled_actual.max(), sampled_predicted.max())
        plt.plot([0, max_val], [0, max_val], 'r--', linewidth=2, label="1:1 Line")

        # 添加统计信息 (使用英文避免字体问题)
        plt.text(
            0.05, 0.95,
            f"n = {n_points:,}\nRMSE = {rmse:.3f}",
            transform=plt.gca().transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

        plt.title(f"VWC Prediction - {model['name']}", fontsize=16)
        plt.xlabel("Actual VWC", fontsize=12)
        plt.ylabel("Predicted VWC", fontsize=12)
        
        # 添加网格线增强可读性
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()
        
        # 保存图像为PNG
        plot_path = os.path.join(output_dir, f"{model['name']}_scatter.png")
        plt.savefig(plot_path, dpi=300)
        plt.close()
        
        print(f"保存PNG结果: {model['name']}_scatter.png")
        
        # 绘制PDF版本
        plt.figure(figsize=(10, 8))
        hb_pdf = plt.hexbin(
            sampled_actual, 
            sampled_predicted, 
            gridsize=200,  # PDF使用更高的分辨率
            cmap='viridis', 
            mincnt=1,
            bins='log'
        )
        cb_pdf = plt.colorbar(hb_pdf, label='Point Density')
        plt.plot([0, max_val], [0, max_val], 'r--', linewidth=2)
        plt.title(f"VWC Prediction - {model['name']}", fontsize=16)
        plt.xlabel("Actual VWC", fontsize=12)
        plt.ylabel("Predicted VWC", fontsize=12)
        
        # PDF版本的统计信息放在底部
        plt.text(
            0.05, 0.05,
            f"n = {n_points:,} | RMSE = {rmse:.3f} | Bias = {bias:.3f}",
            transform=plt.gca().transAxes,
            verticalalignment='bottom',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )
        
        plt.grid(True, linestyle='--', alpha=0.3)
        plt.tight_layout()
        pdf_path = os.path.join(output_dir, f"{model['name']}_scatter.pdf")
        plt.savefig(pdf_path, format='pdf', dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"保存PDF结果: {model['name']}_scatter.pdf")
        
    except Exception as e:
        print(f"为模型 {model['name']} 生成图表时出错: {str(e)}")

print("\n所有处理完成! 结果保存在 'smapvex16Iowa_prediction_results' 文件夹中")

发现以下模型文件:
- RFR_C_Hpol_Type1.pkl (波段: C, 极化: H, 类型: 1)
跳过模型 RFR_C_Hpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_C_Hpol_Type3.pkl (类型 3 != 1)
- RFR_C_HVpol_Type1.pkl (波段: C, 极化: HV, 类型: 1)
跳过模型 RFR_C_HVpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_C_HVpol_Type3.pkl (类型 3 != 1)
- RFR_C_Vpol_Type1.pkl (波段: C, 极化: V, 类型: 1)
跳过模型 RFR_C_Vpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_C_Vpol_Type3.pkl (类型 3 != 1)
- RFR_Ku_Hpol_Type1.pkl (波段: Ku, 极化: H, 类型: 1)
跳过模型 RFR_Ku_Hpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_Ku_Hpol_Type3.pkl (类型 3 != 1)
- RFR_Ku_HVpol_Type1.pkl (波段: Ku, 极化: HV, 类型: 1)
跳过模型 RFR_Ku_HVpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_Ku_HVpol_Type3.pkl (类型 3 != 1)
- RFR_Ku_Vpol_Type1.pkl (波段: Ku, 极化: V, 类型: 1)
跳过模型 RFR_Ku_Vpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_Ku_Vpol_Type3.pkl (类型 3 != 1)
- RFR_X_Hpol_Type1.pkl (波段: X, 极化: H, 类型: 1)
跳过模型 RFR_X_Hpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_X_Hpol_Type3.pkl (类型 3 != 1)
- RFR_X_HVpol_Type1.pkl (波段: X, 极化: HV, 类型: 1)
跳过模型 RFR_X_HVpol_Type2.pkl (类型 2 != 1)
跳过模型 RFR_X_HVpol_Type3.pkl (类型 3 != 1

处理MAT文件:   0%|                                                                              | 0/12 [00:00<?, ?it/s]

模型 RFR_C_HVpol_Type1 在文件 20160528.mat 中缺少变量: c_vod_HV, C_HV_mask
模型 RFR_Ku_HVpol_Type1 在文件 20160528.mat 中缺少变量: ku_vod_HV, Ku_HV_mask
模型 RFR_X_HVpol_Type1 在文件 20160528.mat 中缺少变量: x_vod_HV, X_HV_mask


处理MAT文件:   8%|█████▊                                                                | 1/12 [00:01<00:17,  1.60s/it]


处理文件 20160531.mat 时出错: Found array with 0 sample(s) (shape=(0, 13)) while a minimum of 1 is required by RandomForestRegressor. 在文件 965321359.py 的第 187 行


处理MAT文件:  33%|███████████████████████▎                                              | 4/12 [00:02<00:03,  2.32it/s]

模型 RFR_C_HVpol_Type1 在文件 20160601.mat 中缺少变量: c_vod_HV, C_HV_mask

处理文件 20160601.mat 时出错: Found array with 0 sample(s) (shape=(0, 13)) while a minimum of 1 is required by RandomForestRegressor. 在文件 965321359.py 的第 187 行
模型 RFR_C_HVpol_Type1 在文件 20160602.mat 中缺少变量: c_vod_HV, C_HV_mask

处理文件 20160602.mat 时出错: Found array with 0 sample(s) (shape=(0, 13)) while a minimum of 1 is required by RandomForestRegressor. 在文件 965321359.py 的第 187 行
模型 RFR_C_HVpol_Type1 在文件 20160603.mat 中缺少变量: c_vod_HV, C_HV_mask
模型 RFR_Ku_HVpol_Type1 在文件 20160603.mat 中缺少变量: ku_vod_HV, Ku_HV_mask
模型 RFR_X_HVpol_Type1 在文件 20160603.mat 中缺少变量: x_vod_HV, X_HV_mask


处理MAT文件:  50%|███████████████████████████████████                                   | 6/12 [00:04<00:03,  1.53it/s]

模型 RFR_C_HVpol_Type1 在文件 20160605.mat 中缺少变量: c_vod_HV, C_HV_mask

处理文件 20160605.mat 时出错: Found array with 0 sample(s) (shape=(0, 13)) while a minimum of 1 is required by RandomForestRegressor. 在文件 965321359.py 的第 187 行

处理文件 20160803.mat 时出错: Found array with 0 sample(s) (shape=(0, 13)) while a minimum of 1 is required by RandomForestRegressor. 在文件 965321359.py 的第 187 行

处理文件 20160805.mat 时出错: Found array with 0 sample(s) (shape=(0, 13)) while a minimum of 1 is required by RandomForestRegressor. 在文件 965321359.py 的第 187 行
模型 RFR_C_HVpol_Type1 在文件 20160806.mat 中缺少变量: c_vod_HV, C_HV_mask
模型 RFR_Ku_HVpol_Type1 在文件 20160806.mat 中缺少变量: ku_vod_HV, Ku_HV_mask
模型 RFR_X_HVpol_Type1 在文件 20160806.mat 中缺少变量: x_vod_HV, X_HV_mask


处理MAT文件:  75%|████████████████████████████████████████████████████▌                 | 9/12 [00:05<00:01,  1.94it/s]

模型 RFR_C_HVpol_Type1 在文件 20160813.mat 中缺少变量: c_vod_HV, C_HV_mask
模型 RFR_Ku_HVpol_Type1 在文件 20160813.mat 中缺少变量: ku_vod_HV, Ku_HV_mask
模型 RFR_X_HVpol_Type1 在文件 20160813.mat 中缺少变量: x_vod_HV, X_HV_mask


处理MAT文件:  83%|█████████████████████████████████████████████████████████▌           | 10/12 [00:07<00:01,  1.25it/s]

模型 RFR_C_HVpol_Type1 在文件 20160814.mat 中缺少变量: c_vod_HV, C_HV_mask
模型 RFR_Ku_HVpol_Type1 在文件 20160814.mat 中缺少变量: ku_vod_HV, Ku_HV_mask
模型 RFR_X_HVpol_Type1 在文件 20160814.mat 中缺少变量: x_vod_HV, X_HV_mask


处理MAT文件:  92%|███████████████████████████████████████████████████████████████▎     | 11/12 [00:08<00:01,  1.01s/it]

模型 RFR_C_HVpol_Type1 在文件 20160816.mat 中缺少变量: c_vod_HV, C_HV_mask
模型 RFR_Ku_HVpol_Type1 在文件 20160816.mat 中缺少变量: ku_vod_HV, Ku_HV_mask
模型 RFR_X_HVpol_Type1 在文件 20160816.mat 中缺少变量: x_vod_HV, X_HV_mask


处理MAT文件: 100%|█████████████████████████████████████████████████████████████████████| 12/12 [00:10<00:00,  1.15it/s]



处理完成, 开始保存结果和绘图...


生成结果:   0%|                                                                                  | 0/9 [00:00<?, ?it/s]

模型 RFR_C_Hpol_Type1 有效点数量: 37043
模型 RFR_C_Hpol_Type1 评估指标:
  RMSE = 2.9634
  Bias = 2.8039
保存PNG结果: RFR_C_Hpol_Type1_scatter.png


生成结果:  11%|████████▏                                                                 | 1/9 [00:03<00:26,  3.30s/it]

保存PDF结果: RFR_C_Hpol_Type1_scatter.pdf
模型 RFR_C_HVpol_Type1 没有预测数据，跳过
模型 RFR_C_Vpol_Type1 有效点数量: 6857
模型 RFR_C_Vpol_Type1 评估指标:
  RMSE = 2.7479
  Bias = 2.5967
保存PNG结果: RFR_C_Vpol_Type1_scatter.png


生成结果:  33%|████████████████████████▋                                                 | 3/9 [00:04<00:08,  1.48s/it]

保存PDF结果: RFR_C_Vpol_Type1_scatter.pdf
模型 RFR_Ku_Hpol_Type1 有效点数量: 26135
模型 RFR_Ku_Hpol_Type1 评估指标:
  RMSE = 2.0016
  Bias = 1.8442
保存PNG结果: RFR_Ku_Hpol_Type1_scatter.png


生成结果:  44%|████████████████████████████████▉                                         | 4/9 [00:07<00:09,  1.87s/it]

保存PDF结果: RFR_Ku_Hpol_Type1_scatter.pdf
模型 RFR_Ku_HVpol_Type1 没有预测数据，跳过
模型 RFR_Ku_Vpol_Type1 有效点数量: 24874
模型 RFR_Ku_Vpol_Type1 评估指标:
  RMSE = 2.9737
  Bias = 2.8415
保存PNG结果: RFR_Ku_Vpol_Type1_scatter.png


生成结果:  67%|█████████████████████████████████████████████████▎                        | 6/9 [00:10<00:05,  1.70s/it]

保存PDF结果: RFR_Ku_Vpol_Type1_scatter.pdf
模型 RFR_X_Hpol_Type1 有效点数量: 27239
模型 RFR_X_Hpol_Type1 评估指标:
  RMSE = 2.8751
  Bias = 2.7067
保存PNG结果: RFR_X_Hpol_Type1_scatter.png


生成结果:  78%|█████████████████████████████████████████████████████████▌                | 7/9 [00:12<00:03,  1.86s/it]

保存PDF结果: RFR_X_Hpol_Type1_scatter.pdf
模型 RFR_X_HVpol_Type1 没有预测数据，跳过
模型 RFR_X_Vpol_Type1 有效点数量: 26350
模型 RFR_X_Vpol_Type1 评估指标:
  RMSE = 1.8060
  Bias = 1.6190
保存PNG结果: RFR_X_Vpol_Type1_scatter.png


生成结果: 100%|██████████████████████████████████████████████████████████████████████████| 9/9 [00:15<00:00,  1.73s/it]

保存PDF结果: RFR_X_Vpol_Type1_scatter.pdf

所有处理完成! 结果保存在 'smapvex16Iowa_prediction_results' 文件夹中





# 2.CLASIC07

In [15]:
# 将txt读取为xlsx，并且转化WGS84的经纬度坐标
import pandas as pd
import os
import pyproj
from pyproj import Transformer
import warnings
from datetime import datetime

# 忽略pyproj的警告信息
warnings.filterwarnings("ignore", category=FutureWarning)

def convert_utm_to_wgs84(easting, northing, zone=14, hemisphere='N'):
    """
    将UTM坐标转换为WGS84经纬度
    
    参数:
        easting: UTM东坐标
        northing: UTM北坐标
        zone: UTM区域号 (默认为14)
        hemisphere: 半球标识 ('N' 北半球 或 'S' 南半球)
    
    返回:
        (经度, 纬度) 元组
    """
    # 定义UTM投影系统
    utm_crs = pyproj.CRS(f"+proj=utm +zone={zone} +{hemisphere} +ellps=WGS84 +datum=WGS84 +units=m +no_defs")
    
    # 定义WGS84地理坐标系统
    wgs84_crs = pyproj.CRS("EPSG:4326")  # WGS84
    
    # 创建转换器
    transformer = Transformer.from_crs(utm_crs, wgs84_crs)
    
    # 执行坐标转换 - 返回(经度, 纬度)
    longitude, latitude = transformer.transform(easting, northing)
    
    return longitude, latitude  # 正确返回(经度, 纬度)

def process_vegetation_data(input_file_path):
    # 读取文本文件
    with open(input_file_path, 'r') as file:
        lines = file.readlines()
    
    # 处理数据行（跳过前5行标题）
    processed_data = []
    
    # 从第6行开始处理（索引5），跳过前5行标题
    for line in lines[5:]:
        clean_line = line.strip()
        if not clean_line:  # 跳过空行
            continue
            
        parts = clean_line.split()
        
        # 处理Crop列可能包含空格的情况
        if len(parts) == 10:
            # 标准情况：Crop为单次词
            crop = parts[1]
            date_index = 2
        elif len(parts) == 11:
            # Crop包含空格：如"Cut WW"
            crop = f"{parts[1]} {parts[2]}"
            date_index = 3
        else:
            # 跳过格式不匹配的行
            print(f"跳过格式不匹配的行: {clean_line}")
            continue
        
        # 提取各字段
        field = parts[0]
        date_str = parts[date_index]
        doy = parts[date_index + 1]
        time = parts[date_index + 2]
        vwc = parts[date_index + 3]
        easting = parts[date_index + 6]
        northing = parts[date_index + 7]
        
        # 跳过标题行（如果第一个字段是"Field"）
        if field == "Field":
            print(f"跳过标题行: {clean_line}")
            continue
        
        # 转换日期格式
        try:
            # 解析原始日期格式: "月/日/年"
            date_obj = datetime.strptime(date_str, "%m/%d/%Y")
            # 格式化为新的日期字符串: YYYY-MM-DD
            formatted_date = date_obj.strftime("%Y-%m-%d")
        except ValueError as e:
            print(f"日期格式转换错误: {e} (Field: {field}, Date: {date_str})")
            formatted_date = date_str  # 保留原始日期字符串
        
        # 将UTM坐标转换为WGS84经纬度
        try:
            # 转换为浮点数
            easting_float = float(easting)
            northing_float = float(northing)
            
            # 执行坐标转换 (使用UTM Zone 14N)
            # 正确获取经度和纬度
            longitude, latitude = convert_utm_to_wgs84(easting_float, northing_float, zone=14, hemisphere='N')
            
            # 保留6位小数精度
            longitude = round(longitude, 6)
            latitude = round(latitude, 6)
        except (ValueError, TypeError) as e:
            print(f"坐标转换错误: {e} (Field: {field}, Easting: {easting}, Northing: {northing})")
            longitude, latitude = None, None
        
        processed_data.append([
            field, crop, formatted_date, doy, time, vwc,
            longitude, latitude, easting, northing  # 正确顺序：经度、纬度
        ])
    
    # 创建DataFrame - 修正列名
    columns = [
        "Field", "Crop", "Date", "DOY", "Time (CDT)", 
        "VWC (kg/m²)", "Latitude (WGS84)", "Longitude (WGS84)", 
        "UTM Easting", "UTM Northing"
    ]
    df = pd.DataFrame(processed_data, columns=columns)
    
    # 优化数据类型
    numeric_cols = ["DOY", "VWC (kg/m²)", "Longitude (WGS84)", "Latitude (WGS84)", 
                   "UTM Easting", "UTM Northing"]
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    
    # 创建输出文件路径（同名.xlsx）
    base_name = os.path.splitext(os.path.basename(input_file_path))[0]
    output_dir = os.path.dirname(input_file_path)
    output_file = os.path.join(output_dir, f"{base_name}.xlsx")
    
    # 保存为Excel文件（带自动列宽调整）
    try:
        with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
            df.to_excel(writer, index=False, sheet_name='Vegetation Data')
            
            # 获取工作簿和工作表对象
            workbook = writer.book
            worksheet = writer.sheets['Vegetation Data']
            
            # 设置自动调整列宽
            for i, col in enumerate(df.columns):
                # 获取列的最大宽度
                max_len = max(df[col].astype(str).map(len).max(), len(col)) + 2
                worksheet.set_column(i, i, max_len)
    except Exception as e:
        # 如果xlsxwriter引擎失败，尝试使用openpyxl
        print(f"使用xlsxwriter引擎失败: {e}")
        print("尝试使用openpyxl引擎...")
        df.to_excel(output_file, index=False, sheet_name='Vegetation Data')
        print(f"已使用openpyxl引擎保存文件，但不支持自动调整列宽")
    
    return output_file

# 直接指定文本文件路径
input_file_path = r"E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC.txt"

# 检查文件是否存在
if not os.path.isfile(input_file_path):
    print(f"错误: 文件 '{input_file_path}' 不存在!")
else:
    # 处理数据并保存
    output_file = process_vegetation_data(input_file_path)
    print(f"数据已成功处理并保存至: {output_file}")

数据已成功处理并保存至: E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC.xlsx


In [16]:
# 数据填充
import pandas as pd
import numpy as np
import os
import h5py
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings("ignore")

# 全局变量记录插值详细信息
interpolation_details = []

# ====================== 改进的MAT文件读取函数 ======================
def read_hdf5_mat(file_path, expected_keys=None):
    """读取MATLAB v7.3格式的HDF5文件，优先查找特定变量"""
    try:
        with h5py.File(file_path, 'r') as f:
            data = {}
            
            def visitor_func(name, obj):
                if isinstance(obj, h5py.Dataset):
                    if h5py.check_string_dtype(obj.dtype):
                        value = ''.join(chr(c) for c in obj[:])
                    else:
                        value = np.array(obj)
                    if value.ndim >= 2:
                        value = value.T
                    base_name = name.split('/')[-1]
                    data[base_name] = value
            
            f.visititems(visitor_func)
            
            # 优先查找预期变量
            if expected_keys:
                for key in expected_keys:
                    if key in data:
                        return {key: data[key]}
            
            return data
    except Exception as e:
        print(f"  读取HDF5 MAT文件失败: {str(e)}")
        return {}

# ====================== 改进的双线性插值函数 ======================
def bilinear_interpolation_with_details(lat_grid, lon_grid, target_lat, target_lon, grid_data):
    """
    执行双线性插值并记录详细信息
    :param lat_grid: 网格纬度数组 (1D, 从北向南递减)
    :param lon_grid: 网格经度数组 (1D, 从西向东递增)
    :param target_lat: 目标点纬度
    :param target_lon: 目标点经度
    :param grid_data: 网格数据 (2D数组, 形状为(len(lat_grid), len(lon_grid)))
    :return: 插值值
    """
    global interpolation_details
    
    try:
        # 记录网格形状
        grid_shape = grid_data.shape
        
        # 验证网格尺寸
        if len(lat_grid) != grid_shape[0] or len(lon_grid) != grid_shape[1]:
            print(f"警告: 网格尺寸不匹配! 纬度网格: {len(lat_grid)}, 经度网格: {len(lon_grid)}, 数据形状: {grid_shape}")
            return np.nan
        
        # 查找最近的纬度索引（纬度从北向南递减）
        # 纬度网格: 89.95 (北) -> -89.95 (南)
        lat_idx = np.argmin(np.abs(lat_grid - target_lat))
        
        # 查找最近的经度索引（经度从西向东递增）
        # 经度网格: -179.95 (西) -> 179.95 (东)

        lon_idx = np.argmin(np.abs(lon_grid - target_lon))
        
        # 确定四个角点索引
        # 纬度处理：目标点位于两个纬度网格点之间
        if target_lat > lat_grid[lat_idx]:
            # 目标纬度大于当前网格点纬度（更北）
            if lat_idx > 0:
                lat_idx0 = lat_idx - 1
                lat_idx1 = lat_idx
            else:
                lat_idx0 = lat_idx
                lat_idx1 = lat_idx
        else:
            # 目标纬度小于当前网格点纬度（更南）
            if lat_idx < len(lat_grid) - 1:
                lat_idx0 = lat_idx
                lat_idx1 = lat_idx + 1
            else:
                lat_idx0 = lat_idx
                lat_idx1 = lat_idx
        
        # 经度处理：目标点位于两个经度网格点之间
        if target_lon > lon_grid[lon_idx]:
            # 目标经度大于当前网格点经度（更东）
            if lon_idx < len(lon_grid) - 1:
                lon_idx0 = lon_idx
                lon_idx1 = lon_idx + 1
            else:
                lon_idx0 = lon_idx
                lon_idx1 = lon_idx
        else:
            # 目标经度小于当前网格点经度（更西）
            if lon_idx > 0:
                lon_idx0 = lon_idx - 1
                lon_idx1 = lon_idx
            else:
                lon_idx0 = lon_idx
                lon_idx1 = lon_idx
        
        # 获取四个角点值
        Q00 = grid_data[lat_idx0, lon_idx0]
        Q01 = grid_data[lat_idx0, lon_idx1]
        Q10 = grid_data[lat_idx1, lon_idx0]
        Q11 = grid_data[lat_idx1, lon_idx1]
        
        # 四个角点坐标
        y0 = lat_grid[lat_idx0]
        y1 = lat_grid[lat_idx1]
        x0 = lon_grid[lon_idx0]
        x1 = lon_grid[lon_idx1]
        
        # 如果有NaN，使用最接近的点
        if np.isnan(Q00) or np.isnan(Q01) or np.isnan(Q10) or np.isnan(Q11):
            result = grid_data[lat_idx, lon_idx]
            details = {
                'type': 'nearest',
                'row': lat_idx,
                'col': lon_idx,
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [grid_data[lat_idx, lon_idx]],
                'lat_values': [lat_grid[lat_idx]],
                'lon_values': [lon_grid[lon_idx]]
            }
        else:
            # 双线性插值公式
            dx = (target_lon - x0) / (x1 - x0) if (x1 - x0) != 0 else 0
            dy = (target_lat - y0) / (y1 - y0) if (y1 - y0) != 0 else 0
            result = (1 - dx) * (1 - dy) * Q00 + dx * (1 - dy) * Q01 + (1 - dx) * dy * Q10 + dx * dy * Q11
            
            details = {
                'type': 'bilinear',
                'rows': [lat_idx0, lat_idx0, lat_idx1, lat_idx1],
                'cols': [lon_idx0, lon_idx1, lon_idx0, lon_idx1],
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [Q00, Q01, Q10, Q11],
                'lat_values': [y0, y0, y1, y1],
                'lon_values': [x0, x1, x0, x1]
            }
        
        # 保存插值详细信息
        interpolation_details.append(details)
        return result
    
    except Exception as e:
        print(f"插值错误: {str(e)}")
        return np.nan
# ====================== 主处理函数 ======================
def process_clasic07_data(input_file_path):
    """
    处理CLASIC07数据，执行多种插值操作
    """
    global interpolation_details
    
    try:
        interpolation_details = []  # 重置插值详情
        
        # ========== 1. 读取原始数据 ==========
        print(f"读取原始Excel文件: {input_file_path}")
        df = pd.read_excel(input_file_path)
        
        # 调整列名以匹配您的文件
        df = df.rename(columns={
            'Latitude (WGS84)': 'Latitude',
            'Longitude (WGS84)': 'Longitude'
        })
        
        # 打印列名以验证
        print("数据列名:", df.columns.tolist())
        
        # 定义标准经纬度网格 (0.1°分辨率)
        # 纬度: 北纬89.95°(0) -> 南纬-89.95°(1799)
        lat_grid = np.linspace(89.95, -89.95, 1800)
        
        # 经度: -179.95°(0) -> 179.95°(3599) [根据您的要求]
        lon_grid = np.linspace(-179.95, 179.95, 3600)
        
        print(f"成功读取 {len(df)} 条记录")
        
        # ========== 2. 准备PFT数据 (14个类别) ==========
        pft_file = r"E:\data\ESACCI PFT\Resample\Data\2007.mat"
        if os.path.exists(pft_file):
            print(f"\n处理PFT数据: {pft_file}")
            mat_data = read_hdf5_mat(pft_file)
            
            pft_columns = ['water','bare','snowice','built','grassnat','grassman',
                          'shrubbd','shrubbe','shrubnd','shrubne',
                          'treebd','treebe','treend','treene']
            
            available_pft = [col for col in pft_columns if col in mat_data]
            print(f"  文件中可用的PFT变量: {', '.join(available_pft)}")
            
            # 处理每个可用的PFT类别
            for col in available_pft:
                grid_data = mat_data[col] / 100.0
                df[f'PFT_{col}'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        grid_data
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加列: PFT_{col}")
        else:
            print(f"\n警告: PFT文件不存在 - {pft_file}")
        
        # ========== 3. 处理VOD数据 (7个变量) ==========
        vod_base_dir = r"E:\data\VOD\mat\kuxcVOD\ASC"
        vod_cols = ['SM','ku_vod_H', 'ku_vod_V', 'x_vod_H','x_vod_V', 'c_vod_H','c_vod_V']
        
        for col in vod_cols:
            df[col] = np.nan
        
        print("\n处理VOD数据:")
        
        # 收集所有唯一日期并排序
        unique_dates = sorted(df['Date'].unique())
        vod_files_found = 0
        
        for date in unique_dates:
            # 转换为字符串格式YYYYMMDD
            try:
                if isinstance(date, pd.Timestamp):
                    date_str = date.strftime("%Y%m%d")
                elif isinstance(date, datetime):
                    date_str = date.strftime("%Y%m%d")
                elif isinstance(date, str):
                    date_dt = datetime.strptime(date, "%Y-%m-%d")
                    date_str = date_dt.strftime("%Y%m%d")
                else:
                    date_str = str(date).replace("-", "")[:8]
            except Exception as e:
                print(f"  无法解析日期: {date}, 错误: {e}")
                continue
            
            vod_file = os.path.join(vod_base_dir, f"MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")
            if os.path.exists(vod_file):
                vod_files_found += 1
                print(f"  处理日期: {date_str}, 文件: {os.path.basename(vod_file)}")
                vod_data = read_hdf5_mat(vod_file)
                
                for col in vod_cols:
                    if col in vod_data:
                        grid_data = vod_data[col]
                        mask = df['Date'] == date
                        df.loc[mask, col] = df[mask].apply(
                            lambda row: bilinear_interpolation_with_details(
                                lat_grid, lon_grid, 
                                row['Latitude'], row['Longitude'], 
                                grid_data
                            ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                            else np.nan, axis=1
                        )
                        print(f"    已更新: {col}")
                    else:
                        print(f"    警告: VOD变量 {col} 不存在于文件中")
            else:
                print(f"  警告: VOD文件不存在 - {os.path.basename(vod_file)}")
                
        if vod_files_found == 0:
            print("  警告: 没有找到任何VOD文件，VOD列将保留为空")
        
        # ========== 4. 处理LAI卫星数据 (时间插值) ==========
        print("\n处理LAI卫星数据...")
        df['LAI_Satellite'] = np.nan
        
        # 预期可能的LAI变量名
        expected_lai_keys = ['lai', 'LAI', 'data']
        
        # 定义三个LAI文件
        lai_files = {
            datetime(2007, 5, 15): r"E:\data\GLASS LAI\mat\0.1Deg\Dataset\2007-05-01.tif.mat",
            datetime(2007, 6, 15): r"E:\data\GLASS LAI\mat\0.1Deg\Dataset\2007-06-01.tif.mat",
            datetime(2007, 7, 15): r"E:\data\GLASS LAI\mat\0.1Deg\Dataset\2007-07-01.tif.mat"
        }
        
        # 加载LAI数据
        lai_data = {}
        for date_key, file_path in lai_files.items():
            if os.path.exists(file_path):
                print(f"  加载LAI数据 ({date_key.strftime('%Y-%m-%d')}): {file_path}")
                # 使用增强的MAT文件读取函数，指定预期变量
                file_data = read_hdf5_mat(file_path, expected_keys=expected_lai_keys)
                
                if file_data:
                    # 直接获取LAI数据（优先找到的变量）
                    lai_value = list(file_data.values())[0]
                    print(f"    成功读取LAI变量 '{list(file_data.keys())[0]}'，数据形状: {lai_value.shape}")
                    lai_data[date_key] = lai_value
                else:
                    print("    警告: 文件中未找到预期LAI变量，使用全零数组")
                    lai_data[date_key] = np.zeros((1800, 3600))
            else:
                print(f"  警告: LAI文件不存在 - {file_path}")
                lai_data[date_key] = np.zeros((1800, 3600))
        
        # 定义关键日期
        mid_may = datetime(2007, 5, 15)
        mid_june = datetime(2007, 6, 15)
        mid_july = datetime(2007, 7, 15)
        
        # 处理每个日期的数据
        for date in unique_dates:
            try:
                # 确保日期为datetime对象
                if isinstance(date, pd.Timestamp):
                    date_dt = date.to_pydatetime()
                elif isinstance(date, datetime):
                    date_dt = date
                elif isinstance(date, str):
                    date_dt = datetime.strptime(date, "%Y-%m-%d")
                else:
                    print(f"  无法识别的日期格式: {date}")
                    continue
                
                # 应用LAI插值策略
                if date_dt <= mid_may:
                    # 使用5月15日数据
                    interpolated_lai = lai_data[mid_may]
                    print(f"  处理日期: {date_dt.strftime('%Y-%m-%d')}, 使用5月15日数据")
                elif date_dt == mid_june:
                    # 直接使用6月15日数据
                    interpolated_lai = lai_data[mid_june]
                    print(f"  处理日期: {date_dt.strftime('%Y-%m-%d')}, 使用6月15日数据")
                elif date_dt >= mid_july:
                    # 使用7月15日数据
                    interpolated_lai = lai_data[mid_july]
                    print(f"  处理日期: {date_dt.strftime('%Y-%m-%d')}, 使用7月15日数据")
                elif mid_may < date_dt < mid_june:
                    # 5月16日至6月14日使用5月和6月数据插值
                    weight = (date_dt - mid_may).days / (mid_june - mid_may).days
                    interpolated_lai = (1 - weight) * lai_data[mid_may] + weight * lai_data[mid_june]
                    print(f"  处理日期: {date_dt.strftime('%Y-%m-%d')}, 5-6月插值, 权重: {weight:.2f}")
                elif mid_june < date_dt < mid_july:
                    # 6月16日至7月14日使用6月和7月数据插值
                    weight = (date_dt - mid_june).days / (mid_july - mid_june).days
                    interpolated_lai = (1 - weight) * lai_data[mid_june] + weight * lai_data[mid_july]
                    print(f"  处理日期: {date_dt.strftime('%Y-%m-%d')}, 6-7月插值, 权重: {weight:.2f}")
                else:
                    print(f"  日期 {date_dt.strftime('%Y-%m-%d')} 不在处理范围内")
                    continue
                
                # 计算平均LAI值用于验证
                mean_lai = np.nanmean(interpolated_lai)
                print(f"    平均LAI: {mean_lai:.4f}")
                
                # 应用空间插值
                mask = df['Date'] == date
                df.loc[mask, 'LAI_Satellite'] = df[mask].apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        interpolated_lai
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
            except Exception as e:
                print(f"  处理日期{date}时出错: {str(e)}")
        
        # ========== 5. 处理植被高度数据 ==========
        print("\n处理植被高度数据...")
        df['Hveg'] = np.nan
        
        hveg_file = r"E:\data\CanopyHeight\CH.mat"
        if os.path.exists(hveg_file):
            print(f"  加载植被高度数据: {hveg_file}")
            hveg_data = read_hdf5_mat(hveg_file)
            hveg_key = next(iter(hveg_data)) if hveg_data else None
            
            if hveg_key:
                df['Hveg'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        hveg_data[hveg_key]
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加植被高度列")
            else:
                print(f"  警告: 无法找到Hveg变量")
                df['Hveg'] = np.nan
        else:
            print(f"  警告: Hveg文件不存在 - {hveg_file}")
            df['Hveg'] = np.nan
        
        # ========== 6. 保存结果 ==========
        output_file_path = r"E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx"
        print(f"\n保存结果到: {output_file_path}")
        df.to_excel(output_file_path, index=False)
        
        # 保存插值详细信息到Excel
        if interpolation_details:
            details_df = pd.DataFrame(interpolation_details)
            details_path = r"E:\data\VWC\test-VWC\Insitu CLASIC07\interpolation_details.xlsx"
            details_df.to_excel(details_path, index=False)
            print(f"插值详细信息保存到: {details_path}")
        else:
            print("警告: 没有插值详细信息可保存")
        
        # ========== 7. 统计报告 ==========
        print("\n处理完成!")
        print(f"总记录数: {len(df)}")
        print(f"插值操作次数: {len(interpolation_details)}")
        
        if interpolation_details:
            # 显示前3次插值的详细信息
            print("\n前3次插值的详细信息:")
            for i, detail in enumerate(interpolation_details[:3]):
                print(f"\n插值 #{i+1}")
                print(f"  类型: {detail['type']}")
                print(f"  目标位置: ({detail['target_lat']:.6f}, {detail['target_lon']:.6f})")
                print(f"  网格形状: {detail['grid_shape']}")
                
                if detail['type'] == 'bilinear':
                    print(f"  使用的4个网格点:")
                    for j in range(4):
                        print(f"    点{j+1}: 行 {detail['rows'][j]}, 列 {detail['cols'][j]} - " +
                              f"位置 ({detail['lat_values'][j]:.6f}, {detail['lon_values'][j]:.6f}) - " +
                              f"值: {detail['values'][j]:.6f}")
                else:
                    print(f"  最近邻点: 行 {detail['row']}, 列 {detail['col']} - " +
                          f"位置 ({detail['lat_values'][0]:.6f}, {detail['lon_values'][0]:.6f}) - " +
                          f"值: {detail['values'][0]:.6f}")
        
        return True
        
    except Exception as e:
        print(f"处理过程中出错: {str(e)}")
        import traceback
        print("错误详细信息:")
        print(traceback.format_exc())
        return False

# ========================== 主程序 ==========================
if __name__ == "__main__":
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC.xlsx"
    
    print("="*60)
    print("开始处理CLASIC07数据插值任务")
    print("="*60)
    
    if not os.path.exists(input_file):
        print(f"错误: 输入文件不存在 - {input_file}")
        print(f"请检查路径: {os.path.abspath(input_file)}")
    else:
        print(f"输入文件: {input_file}")
        print(f"输出将保存到: E:\\data\\VWC\\test-VWC\\Insitu CLASIC07\\CL07V_SUM_VEG_CLASIC_ML.xlsx")
        
        success = process_clasic07_data(input_file)
        if success:
            print("\n" + "="*30)
            print("任务成功完成!")
            print("="*30)
        else:
            print("\n" + "="*30)
            print("任务失败，请检查错误信息")
            print("="*30)

开始处理CLASIC07数据插值任务
输入文件: E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC.xlsx
输出将保存到: E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx
读取原始Excel文件: E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC.xlsx
数据列名: ['Field', 'Crop', 'Date', 'DOY', 'Time (CDT)', 'VWC (kg/m²)', 'Latitude', 'Longitude', 'UTM Easting', 'UTM Northing']
成功读取 22 条记录

处理PFT数据: E:\data\ESACCI PFT\Resample\Data\2007.mat
  文件中可用的PFT变量: water, bare, snowice, built, grassnat, grassman, shrubbd, shrubbe, shrubnd, shrubne, treebd, treebe, treend, treene
  已添加列: PFT_water
  已添加列: PFT_bare
  已添加列: PFT_snowice
  已添加列: PFT_built
  已添加列: PFT_grassnat
  已添加列: PFT_grassman
  已添加列: PFT_shrubbd
  已添加列: PFT_shrubbe
  已添加列: PFT_shrubnd
  已添加列: PFT_shrubne
  已添加列: PFT_treebd
  已添加列: PFT_treebe
  已添加列: PFT_treend
  已添加列: PFT_treene

处理VOD数据:
  处理日期: 20070610, 文件: MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_20070610_V0.nc4.mat
    已更新: SM
    已更新: ku_vod_H
    已更新: ku_vod_V
    已更新: x_vod_H
    已更新: x_vod_V
    已

In [16]:
import pandas as pd
import numpy as np
import os
import joblib
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics import mean_squared_error, r2_score

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

# 设置常量
TEST_FILE = r"E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx"
MODEL_DIR = "models"
SAVE_RESULTS = "model_predictions_results_CLASIC07.xlsx"
FIG_DIR = "figures"

# 定义波段和极化组合
BANDS = ['Ku', 'C', 'X']
POLS = ['H', 'V', 'HV']

# 波段颜色定义
BAND_COLORS = {
    'Ku': (253/255, 173/255, 115/255, 0.7),
    'C': (178/255, 125/255, 104/255, 0.7),
    'X': (224/255, 104/255, 46/255, 0.7)
}

# 极化类型标记定义
POL_MARKERS = {
    'H': 's',  # 方形
    'V': '^',  # 三角形
    'HV': 'o'  # 圆形
}

def normalize_LAI(lai_series):
    """对LAI进行归一化"""
    return lai_series.clip(0, 6) / 6

def normalize_VOD(vod_series):
    """对VOD进行归一化"""
    return vod_series.clip(0, 2) / 2

# PFT列名映射字典
PFT_MAPPING = {
    'PFT_grassnat': 'Grass_nat',
    'PFT_grassman': 'Grass_man',
    'PFT_shrubbd': 'Shrub_bd',
    'PFT_shrubbe': 'Shrub_be',
    'PFT_shrubnd': 'Shrub_nd',
    'PFT_shrubne': 'Shrub_ne',
    'PFT_treebd': 'Tree_bd',
    'PFT_treebe': 'Tree_be',
    'PFT_treend': 'Tree_nd',
    'PFT_treene': 'Tree_ne'
}

def get_model_columns(band, pol):
    """获取指定模型所需的列名"""
    base_columns = [
        'VWC (kg/m²)',  # 实际值
        'LAI_Satellite',  # LAI
        'SM'  # 土壤湿度
    ]
    
    # 添加所有PFT列
    base_columns.extend(PFT_MAPPING.keys())
    
    # 根据极化类型添加VOD列
    if pol == 'H':
        return base_columns + [f'{band.lower()}_vod_H']
    elif pol == 'V':
        return base_columns + [f'{band.lower()}_vod_V']
    elif pol == 'HV':
        return base_columns + [f'{band.lower()}_vod_H', f'{band.lower()}_vod_V']

def get_feature_order(pol):
    """获取特征列的顺序（模型期望的列顺序）"""
    base_features = [
        'LAI', 'SM',
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    if pol in ['H', 'V']:
        return ['VOD'] + base_features
    elif pol == 'HV':
        return ['VOD-Hpol', 'VOD-Vpol'] + base_features

def prepare_input_data(df, band, pol):
    """为指定模型准备输入数据"""
    # 创建数据副本
    data = df.copy()
    
    # 1. 归一化处理
    data['LAI'] = normalize_LAI(data['LAI_Satellite'])
    
    # 2. 处理VOD列
    if pol == 'H':
        vod_col = f'{band.lower()}_vod_H'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'V':
        vod_col = f'{band.lower()}_vod_V'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'HV':
        # 重命名列以匹配模型训练时的特征名
        data = data.rename(columns={
            f'{band.lower()}_vod_H': 'VOD-Hpol',
            f'{band.lower()}_vod_V': 'VOD-Vpol'
        })
        # 归一化处理
        data['VOD-Hpol'] = normalize_VOD(data['VOD-Hpol'])
        data['VOD-Vpol'] = normalize_VOD(data['VOD-Vpol'])
    
    # 3. 重命名PFT列为模型期望的名称
    data = data.rename(columns=PFT_MAPPING)
    
    # 4. 按模型要求排序特征列
    feature_order = get_feature_order(pol)
    
    return data[feature_order]

def plot_combined_scatter(actual, predictions_dict):
    """
    绘制组合散点图，包含所有波段和极化类型
    
    参数:
    actual -- 实际值 (Series)
    predictions_dict -- 字典结构: {
        'H': {band: pred_series},
        'V': {band: pred_series},
        'HV': {band: pred_series}
    }
    """
    # 创建图形
    plt.figure(figsize=(10, 10))
    ax = plt.gca()
    
    # 存储所有组合的RMSE值
    rmse_values = {}
    
    # 收集所有数据点
    max_val = 0
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            pred_series = predictions_dict[pol].get(band)
            
            if pred_series is not None and not pred_series.isnull().all():
                # 创建实际值和预测值的临时DF
                temp_df = pd.DataFrame({
                    'actual': actual,
                    'pred': pred_series
                }).dropna()
                
                if not temp_df.empty:
                    # 计算RMSE
                    rmse = np.sqrt(mean_squared_error(temp_df['actual'], temp_df['pred']))
                    rmse_values[f"{band}-{pol}"] = rmse
                    
                    # 更新最大值
                    band_max = max(temp_df['actual'].max(), temp_df['pred'].max())
                    if band_max > max_val:
                        max_val = band_max
                    
                    # 绘制散点
                    plt.scatter(
                        temp_df['actual'], temp_df['pred'], 
                        alpha=0.7, 
                        color=BAND_COLORS[band],
                        marker=POL_MARKERS[pol],
                        s=50,
                        edgecolors='none',
                        zorder=2,
                        label=f"{band}-{pol}"
                    )
    
    # 如果没有数据可绘制，直接返回
    if not rmse_values:
        print("  警告: 没有有效的预测数据!")
        plt.close()
        return
    
    # 添加1:1参考线
    max_val *= 1.05
    plt.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7, zorder=1)
    
    # 设置坐标轴范围
    plt.xlim(0, max_val)
    plt.ylim(0, max_val)
    
    # 设置坐标轴标签
    plt.xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
    plt.ylabel('RF VWC (kg/m²)', fontsize=14, fontweight='bold')
    
    # 设置标题
    plt.title('CLASIC07 Insitu VWC', 
             fontsize=18, fontweight='bold', pad=20)
    
    # # 添加图例
    # plt.legend(loc='lower right', frameon=True, fontsize=10, ncol=3)
    
    # 添加RMSE文本（左上角，3×3网格布局）
    if rmse_values:
        # 设置文本位置
        x_pos = 0.05
        y_pos = 0.95
        
        # 添加标题
        plt.text(x_pos, y_pos, 'RMSE (kg/m²):', 
                 transform=ax.transAxes,
                 fontsize=12,
                 fontweight='bold',
                 verticalalignment='top')
        
        y_pos -= 0.05
        
        # 遍历每个波段
        for band_idx, band in enumerate(BANDS):
            # 遍历每个极化类型
            for pol_idx, pol in enumerate(POLS):
                # 计算位置
                text_x = x_pos + pol_idx * 0.15
                text_y = y_pos - band_idx * 0.08
                
                # 获取RMSE值
                rmse = rmse_values.get(f"{band}-{pol}", None)
                
                if rmse is not None:
                    # 绘制标记
                    plt.scatter(
                        text_x, text_y, 
                        transform=ax.transAxes,
                        marker=POL_MARKERS[pol],
                        color=BAND_COLORS[band],
                        s=80,
                        alpha=0.7
                    )
                    
                    # 添加文本
                    plt.text(
                        text_x + 0.01, text_y, 
                        f"{band}-{pol}: {rmse:.3f}", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center'
                    )
                else:
                    # 添加缺失值标记
                    plt.text(
                        text_x, text_y, 
                        f"{band}-{pol}: N/A", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center',
                        color='gray'
                    )
    
    # 添加网格线
    plt.grid(True, linestyle='--', alpha=0.3, zorder=0)
    
    # 调整布局
    plt.tight_layout()
    
    # 创建保存目录
    os.makedirs(FIG_DIR, exist_ok=True)
    
    # 保存图像
    fig_path = os.path.join(FIG_DIR, 'CLASIC07_VWC_Scatter.png')
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"  组合散点图已保存至: {fig_path}")
    plt.close()

def predict_and_evaluate():
    """主函数：加载所有模型进行预测并评估结果"""
    # 1. 加载测试数据
    print(f"正在加载测试数据: {TEST_FILE}")
    
    # 收集所有可能的列
    all_columns = set(['VWC (kg/m²)', 'LAI_Satellite', 'SM'])
    # 添加所有PFT列
    all_columns.update(PFT_MAPPING.keys())
    # 添加所有VOD列
    for band in BANDS:
        all_columns.add(f'{band.lower()}_vod_H')
        all_columns.add(f'{band.lower()}_vod_V')
    
    # 读取Excel文件
    test_df = pd.read_excel(TEST_FILE, usecols=list(all_columns))
    print(f"加载完成，总样本数: {len(test_df)}")
    
    # 存储所有预测结果
    results = pd.DataFrame(index=test_df.index)
    results['Actual_VWC'] = test_df['VWC (kg/m²)']
    
    # 为每个极化类型存储预测结果
    predictions_by_pol = {
        'H': {band: None for band in BANDS},
        'V': {band: None for band in BANDS},
        'HV': {band: None for band in BANDS}
    }
    
    # 2. 对每个模型进行预测
    for band in BANDS:
        for pol in POLS:
            model_name = f"RFR_{band}_{pol}pol_Type1.pkl"
            model_path = os.path.join(MODEL_DIR, model_name)
            
            print(f"\n处理 {band}-{pol} 模型: {model_name}")
            
            # 准备输入数据
            model_cols = get_model_columns(band, pol)
            model_data = test_df[model_cols].copy()
            
            # 删除缺失值
            clean_data = model_data.dropna()
            print(f"  有效样本数: {len(clean_data)} (删除缺失值后)")
            
            if len(clean_data) == 0:
                print("  警告: 无有效样本可用于此模型!")
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
                continue
            
            # 预处理输入数据
            try:
                X_input = prepare_input_data(clean_data, band, pol)
                
                # 加载模型并进行预测
                if os.path.exists(model_path):
                    model = joblib.load(model_path)
                    predictions = model.predict(X_input)
                    
                    # 存储预测结果
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    results.loc[clean_data.index, f"{band}_{pol}_Predicted"] = predictions
                    
                    # 存储到对应极化类型的字典
                    predictions_by_pol[pol][band] = results[f"{band}_{pol}_Predicted"].copy()
                    
                    # 计算评估指标
                    actual = clean_data['VWC (kg/m²)']
                    rmse = np.sqrt(mean_squared_error(actual, predictions))
                    r2 = r2_score(actual, predictions)
                    print(f"  预测完成 - RMSE: {rmse:.4f}, R²: {r2:.4f}")
                    
                else:
                    print(f"  警告: 未找到模型文件 {model_path}!")
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    predictions_by_pol[pol][band] = None
            except Exception as e:
                import traceback
                print(f"  预测失败: {str(e)}")
                # 打印更详细的错误信息
                traceback.print_exc()
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
    
    # 3. 保存结果
    results.to_excel(SAVE_RESULTS)
    print(f"\n所有预测结果已保存至: {SAVE_RESULTS}")
    
    # 4. 绘制组合散点图
    print("\n正在绘制组合散点图...")
    plot_combined_scatter(
        results['Actual_VWC'], 
        predictions_by_pol
    )
    
    return results, predictions_by_pol

# 执行主函数
if __name__ == "__main__":
    results, predictions_by_pol = predict_and_evaluate()
    print("\n所有处理完成!")

正在加载测试数据: E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx
加载完成，总样本数: 22

处理 Ku-H 模型: RFR_Ku_Hpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.1272, R²: -0.4410

处理 Ku-V 模型: RFR_Ku_Vpol_Type1.pkl
  有效样本数: 7 (删除缺失值后)
  预测完成 - RMSE: 1.8424, R²: -101.4697

处理 Ku-HV 模型: RFR_Ku_HVpol_Type1.pkl
  有效样本数: 7 (删除缺失值后)
  预测完成 - RMSE: 1.4431, R²: -61.8683

处理 C-H 模型: RFR_C_Hpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.4028, R²: -0.8387

处理 C-V 模型: RFR_C_Vpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.5623, R²: -1.0909

处理 C-HV 模型: RFR_C_HVpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.1126, R²: -0.4214

处理 X-H 模型: RFR_X_Hpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.1898, R²: -0.5271

处理 X-V 模型: RFR_X_Vpol_Type1.pkl
  有效样本数: 11 (删除缺失值后)
  预测完成 - RMSE: 1.7433, R²: -3.8789

处理 X-HV 模型: RFR_X_HVpol_Type1.pkl
  有效样本数: 11 (删除缺失值后)
  预测完成 - RMSE: 1.5107, R²: -2.6636

所有预测结果已保存至: model_predictions_results_CLASIC07.xlsx

正在绘制组合散点图...
  组合散点图已保存至: figures\CLASIC07

正在加载测试数据: E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx
加载完成，总样本数: 22

处理 Ku-H 模型: RFR_Ku_Hpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.1272, R²: -0.4410

处理 Ku-V 模型: RFR_Ku_Vpol_Type1.pkl
  有效样本数: 7 (删除缺失值后)
  预测完成 - RMSE: 1.8424, R²: -101.4697

处理 Ku-HV 模型: RFR_Ku_HVpol_Type1.pkl
  有效样本数: 7 (删除缺失值后)
  预测完成 - RMSE: 1.4431, R²: -61.8683

处理 C-H 模型: RFR_C_Hpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.4028, R²: -0.8387

处理 C-V 模型: RFR_C_Vpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.5623, R²: -1.0909

处理 C-HV 模型: RFR_C_HVpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.1126, R²: -0.4214

处理 X-H 模型: RFR_X_Hpol_Type1.pkl
  有效样本数: 17 (删除缺失值后)
  预测完成 - RMSE: 2.1898, R²: -0.5271

处理 X-V 模型: RFR_X_Vpol_Type1.pkl
  有效样本数: 11 (删除缺失值后)
  预测完成 - RMSE: 1.7433, R²: -3.8789

处理 X-HV 模型: RFR_X_HVpol_Type1.pkl
  有效样本数: 11 (删除缺失值后)
  预测完成 - RMSE: 1.5107, R²: -2.6636

所有预测结果已保存至: model_predictions_results.xlsx

正在绘制分组箱线图...
  分组散点图已保存至: figures\CLASIC07_VWC_Grou

# 3.SMAPVEX08 

In [5]:
# SMEX08 VWC Map数据数据处理为tif文件
import os
import struct
import array
from osgeo import gdal, osr
import xml.etree.ElementTree as ET

def parse_envi_header(header_path):
    """解析 ENVI 头文件 (.hdr)"""
    params = {}
    try:
        with open(header_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith(';'):
                    continue
                
                if '=' in line:
                    key, value = line.split('=', 1)
                    key = key.strip()
                    value = value.strip()
                    
                    # 处理特殊字符
                    if '{' in value and '}' in value:
                        value = value.split('{')[1].split('}')[0].strip()
                    
                    # 解析数值
                    if value.replace('.', '', 1).replace('-', '', 1).isdigit():
                        if '.' in value:
                            value = float(value)
                        else:
                            value = int(value)
                    
                    params[key] = value
    except Exception as e:
        print(f"头文件解析错误: {str(e)}")
    
    return {
        'samples': params.get('samples', 8885),
        'lines': params.get('lines', 7956),
        'bands': params.get('bands', 1),
        'data_type': params.get('data type', 4),  # float32
        'byte_order': params.get('byte order', 0),  # 小端
        'header_offset': params.get('header offset', 0),
        'interleave': params.get('interleave', 'bsq'),
        'map_info': params.get('map info', None)
    }

def get_gdal_type(data_type):
    """映射ENVI数据类型到GDAL类型"""
    type_map = {
        1: gdal.GDT_Byte,     # byte
        2: gdal.GDT_Int16,    # int16
        3: gdal.GDT_Int32,    # int32
        4: gdal.GDT_Float32,  # float32
        5: gdal.GDT_Float64,  # float64
        12: gdal.GDT_UInt16,  # uint16
        13: gdal.GDT_UInt32   # uint32
    }
    return type_map.get(data_type, gdal.GDT_Float32)

def get_data_size(data_type):
    """获取每种数据类型的字节大小"""
    size_map = {
        1: 1,  # byte
        2: 2,  # int16
        3: 4,  # int32
        4: 4,  # float32
        5: 8,  # float64
        12: 2, # uint16
        13: 4  # uint32
    }
    return size_map.get(data_type, 4)

def get_struct_format(byte_order, data_type):
    """生成struct格式字符串"""
    endian_char = '<' if byte_order == 0 else '>'  # <小端 >大端
    
    type_char_map = {
        1: 'B',  # 无符号字节
        2: 'h',  # 短整型
        3: 'i',  # 整型
        4: 'f',  # 浮点数
        5: 'd',  # 双精度
        12: 'H', # 无符号短整型
        13: 'I'  # 无符号整型
    }
    
    return endian_char + type_char_map.get(data_type, 'f')

def fst_to_tif_no_numpy(input_fst_path, output_tif_path):
    """不使用NumPy转换ENVI FST文件为GeoTIFF"""
    # 获取相关文件路径
    input_dir = os.path.dirname(input_fst_path)
    input_name = os.path.splitext(os.path.basename(input_fst_path))[0]
    
    header_path = os.path.join(input_dir, f"{input_name}.hdr")
    xml_path = os.path.join(input_dir, f"{input_name}.xml")
    
    # 解析头文件
    header_info = parse_envi_header(header_path)
    
    # 提取关键参数
    samples = header_info['samples']
    lines = header_info['lines']
    bands = header_info['bands']
    data_type = header_info['data_type']
    byte_order = header_info['byte_order']
    header_offset = header_info['header_offset']
    interleave = header_info['interleave'].lower()
    
    # 计算数据大小
    data_size = get_data_size(data_type)
    total_pixels = samples * lines * bands
    total_bytes = total_pixels * data_size
    
    print(f"文件信息:")
    print(f"  尺寸: {samples}×{lines}像素")
    print(f"  波段数: {bands}")
    print(f"  数据类型: {data_type}")
    print(f"  交错方式: {interleave}")
    print(f"  总像素数: {total_pixels:,}")
    print(f"  总字节数: {total_bytes:,}")
    
    # 创建GDAL驱动
    driver = gdal.GetDriverByName('GTiff')
    gdal_type = get_gdal_type(data_type)
    ds = driver.Create(
        output_tif_path,
        samples,
        lines,
        bands,
        gdal_type
    )
    
    if ds is None:
        print(f"无法创建输出文件: {output_tif_path}")
        return False
    
    # 设置地理变换
    x_min = 375325.0
    y_max = 4361460.0
    pixel_width = 10.0
    pixel_height = -10.0  # 负值因为Y轴从北向南
    transform = (x_min, pixel_width, 0, y_max, 0, pixel_height)
    ds.SetGeoTransform(transform)
    
    # 设置投影
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(32618)  # UTM zone 18N, WGS84
    ds.SetProjection(srs.ExportToWkt())
    
    # 准备读取数据
    struct_format = get_struct_format(byte_order, data_type)
    pixel_size = struct.calcsize(struct_format)
    print(f"  每个像素大小: {pixel_size}字节")
    
    # 打开FST文件
    with open(input_fst_path, 'rb') as f:
        # 跳过头部偏移
        f.seek(header_offset)
        
        # 逐行读取数据
        for line_idx in range(lines):
            if line_idx % 100 == 0:
                print(f"  处理进度: {line_idx+1}/{lines}行 ({((line_idx+1)/lines)*100:.1f}%)")
            
            # 根据交错格式读取数据
            band_data = {}
            
            if interleave == 'bsq':
                # 对每个波段顺序读取整行
                for band in range(bands):
                    # 读取整行数据
                    line_bytes = f.read(samples * data_size)
                    
                    if len(line_bytes) != samples * data_size:
                        print(f"错误: 行数据不完整 (预期 {samples*data_size}字节, 实际 {len(line_bytes)}字节)")
                        return False
                    
                    band_data[band] = line_bytes
            
            elif interleave == 'bil':
                # 整行包含所有波段
                line_bytes = f.read(samples * bands * data_size)
                
                if len(line_bytes) != samples * bands * data_size:
                    print(f"错误: 行数据不完整 (预期 {samples*bands*data_size}字节, 实际 {len(line_bytes)}字节)")
                    return False
                
                # 分割到各波段
                for band in range(bands):
                    band_bytes = bytearray()
                    for sample in range(samples):
                        # 提取当前像素在当前波段的数据
                        start = (sample * bands + band) * data_size
                        end = start + data_size
                        band_bytes.extend(line_bytes[start:end])
                    band_data[band] = bytes(band_bytes)
            
            elif interleave == 'bip':
                # 逐像素读取
                line_bytes = f.read(samples * bands * data_size)
                
                if len(line_bytes) != samples * bands * data_size:
                    print(f"错误: 行数据不完整 (预期 {samples*bands*data_size}字节, 实际 {len(line_bytes)}字节)")
                    return False
                
                # 分割到各波段
                for band in range(bands):
                    band_bytes = bytearray()
                    for sample in range(samples):
                        # 提取当前像素在当前波段的数据
                        start = (sample * bands + band) * data_size
                        end = start + data_size
                        band_bytes.extend(line_bytes[start:end])
                    band_data[band] = bytes(band_bytes)
            
            else:
                print(f"错误: 不支持的interleave类型: {interleave}")
                return False
            
            # 写入GDAL
            for band_idx, data_bytes in band_data.items():
                band = ds.GetRasterBand(band_idx + 1)
                band.WriteRaster(0, line_idx, samples, 1, data_bytes)
    
    # 设置元数据
    for band_idx in range(bands):
        band = ds.GetRasterBand(band_idx + 1)
        band.SetDescription(f"Band {band_idx+1}")
        band.SetMetadataItem("DESCRIPTION", "Soil Volumetric Water Content")
        band.SetNoDataValue(0.0)
    
    # 添加全局元数据
    ds.SetMetadata({
        "Source_File": os.path.basename(input_fst_path),
        "Rows": str(lines),
        "Columns": str(samples),
        "Bands": str(bands),
        "Data_Type": str(data_type),
        "Processed_By": "GDAL-only FST Converter"
    })
    
    # 清理
    ds.FlushCache()
    ds = None
    
    print(f"成功创建: {output_tif_path}")
    return True

if __name__ == "__main__":
    # 输入和输出路径
    input_dir = r"E:\data\VWC\test-VWC\SMEX08"
    fst_file = "SV08VWC_vwc.fst"
    input_path = os.path.join(input_dir, fst_file)
    output_path = os.path.join(input_dir, fst_file.replace('.fst', '_gdal.tif'))
    
    print(f"开始转换 FST 到 TIFF (GDAL-only方法)")
    print(f"输入文件: {input_path}")
    print(f"输出文件: {output_path}")
    
    # 确保文件存在
    if not os.path.exists(input_path):
        print(f"错误: 输入文件不存在 - {input_path}")
        exit(1)
    
    # 检查头文件
    header_path = input_path.replace('.fst', '.hdr')
    if not os.path.exists(header_path):
        print(f"错误: ENVI头文件不存在 - {header_path}")
        exit(1)
    
    # 转换文件
    success = fst_to_tif_no_numpy(input_path, output_path)
    
    if success:
        print("\n转换成功完成!")
        print(f"请检查输出文件: {output_path}")
        
        # 验证结果
        if os.path.exists(output_path):
            print("\n输出文件验证:")
            try:
                ds = gdal.Open(output_path)
                if ds:
                    print(f"  尺寸: {ds.RasterXSize}x{ds.RasterYSize}")
                    print(f"  波段数: {ds.RasterCount}")
                    print(f"  数据类型: {gdal.GetDataTypeName(ds.GetRasterBand(1).DataType)}")
                    
                    # 简单数据采样
                    band = ds.GetRasterBand(1)
                    scanline = band.ReadRaster(0, 0, min(10, ds.RasterXSize), 1)
                    values = struct.unpack('f' * min(10, ds.RasterXSize), scanline)
                    print(f"  第一行前10个值: {values}")
                    
                    ds = None
                else:
                    print("  警告: 无法打开输出文件进行验证")
            except Exception as e:
                print(f"  验证时出错: {str(e)}")
        else:
            print("  错误: 输出文件未创建")
    else:
        print("转换过程中出现错误")

开始转换 FST 到 TIFF (GDAL-only方法)
输入文件: E:\data\VWC\test-VWC\SMEX08\SV08VWC_vwc.fst
输出文件: E:\data\VWC\test-VWC\SMEX08\SV08VWC_vwc_gdal.tif
文件信息:
  尺寸: 8885×7956像素
  波段数: 1
  数据类型: 4
  交错方式: bsq
  总像素数: 70,689,060
  总字节数: 282,756,240
  每个像素大小: 4字节
  处理进度: 1/7956行 (0.0%)
  处理进度: 101/7956行 (1.3%)
  处理进度: 201/7956行 (2.5%)
  处理进度: 301/7956行 (3.8%)
  处理进度: 401/7956行 (5.0%)
  处理进度: 501/7956行 (6.3%)
  处理进度: 601/7956行 (7.6%)
  处理进度: 701/7956行 (8.8%)
  处理进度: 801/7956行 (10.1%)
  处理进度: 901/7956行 (11.3%)
  处理进度: 1001/7956行 (12.6%)
  处理进度: 1101/7956行 (13.8%)
  处理进度: 1201/7956行 (15.1%)
  处理进度: 1301/7956行 (16.4%)
  处理进度: 1401/7956行 (17.6%)
  处理进度: 1501/7956行 (18.9%)
  处理进度: 1601/7956行 (20.1%)
  处理进度: 1701/7956行 (21.4%)
  处理进度: 1801/7956行 (22.6%)
  处理进度: 1901/7956行 (23.9%)
  处理进度: 2001/7956行 (25.2%)
  处理进度: 2101/7956行 (26.4%)
  处理进度: 2201/7956行 (27.7%)
  处理进度: 2301/7956行 (28.9%)
  处理进度: 2401/7956行 (30.2%)
  处理进度: 2501/7956行 (31.4%)
  处理进度: 2601/7956行 (32.7%)
  处理进度: 2701/7956行 (33.9%)
  处理进度: 2801/7956行 (3

In [9]:
# 数据填充以收集自变量
import pandas as pd
import numpy as np
import os
import h5py
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings("ignore")

# 全局变量记录插值详细信息
interpolation_details = []

# ====================== 改进的MAT文件读取函数 ======================
def read_hdf5_mat(file_path, expected_keys=None):
    """读取MATLAB v7.3格式的HDF5文件，优先查找特定变量"""
    try:
        with h5py.File(file_path, 'r') as f:
            data = {}
            
            def visitor_func(name, obj):
                if isinstance(obj, h5py.Dataset):
                    if h5py.check_string_dtype(obj.dtype):
                        value = ''.join(chr(c) for c in obj[:])
                    else:
                        value = np.array(obj)
                    if value.ndim >= 2:
                        value = value.T
                    base_name = name.split('/')[-1]
                    data[base_name] = value
            
            f.visititems(visitor_func)
            
            # 优先查找预期变量
            if expected_keys:
                for key in expected_keys:
                    if key in data:
                        return {key: data[key]}
            
            return data
    except Exception as e:
        print(f"  读取HDF5 MAT文件失败: {str(e)}")
        return {}

# ====================== 改进的双线性插值函数 ======================
def bilinear_interpolation_with_details(lat_grid, lon_grid, target_lat, target_lon, grid_data):
    """
    执行双线性插值并记录详细信息
    :param lat_grid: 网格纬度数组 (1D, 从北向南递减)
    :param lon_grid: 网格经度数组 (1D, 从西向东递增)
    :param target_lat: 目标点纬度
    :param target_lon: 目标点经度
    :param grid_data: 网格数据 (2D数组, 形状为(len(lat_grid), len(lon_grid)))
    :return: 插值值
    """
    global interpolation_details
    
    try:
        # 记录网格形状
        grid_shape = grid_data.shape
        
        # 验证网格尺寸
        if len(lat_grid) != grid_shape[0] or len(lon_grid) != grid_shape[1]:
            print(f"警告: 网格尺寸不匹配! 纬度网格: {len(lat_grid)}, 经度网格: {len(lon_grid)}, 数据形状: {grid_shape}")
            return np.nan
        
        # 查找最近的纬度索引（纬度从北向南递减）
        # 纬度网格: 89.95 (北) -> -89.95 (南)
        lat_idx = np.argmin(np.abs(lat_grid - target_lat))
        
        # 查找最近的经度索引（经度从西向东递增）
        # 经度网格: -179.95 (西) -> 179.95 (东)
        lon_idx = np.argmin(np.abs(lon_grid - target_lon))
        
        # 确定四个角点索引
        # 纬度处理：目标点位于两个纬度网格点之间
        if lat_idx == 0:
            lat_idx0, lat_idx1 = 0, 1
        elif lat_idx == len(lat_grid) - 1:
            lat_idx0, lat_idx1 = len(lat_grid) - 2, len(lat_grid) - 1
        else:
            if target_lat > lat_grid[lat_idx]:
                # 目标纬度大于当前网格点纬度（更北）
                if lat_idx > 0:
                    lat_idx0 = lat_idx - 1
                    lat_idx1 = lat_idx
                else:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx
            else:
                # 目标纬度小于当前网格点纬度（更南）
                if lat_idx < len(lat_grid) - 1:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx + 1
                else:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx
        
        # 经度处理：目标点位于两个经度网格点之间
        if lon_idx == 0:
            lon_idx0, lon_idx1 = 0, 1
        elif lon_idx == len(lon_grid) - 1:
            lon_idx0, lon_idx1 = len(lon_grid) - 2, len(lon_grid) - 1
        else:
            if target_lon > lon_grid[lon_idx]:
                # 目标经度大于当前网格点经度（更东）
                if lon_idx < len(lon_grid) - 1:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx + 1
                else:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx
            else:
                # 目标经度小于当前网格点经度（更西）
                if lon_idx > 0:
                    lon_idx0 = lon_idx - 1
                    lon_idx1 = lon_idx
                else:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx
        
        # 获取四个角点值
        Q00 = grid_data[lat_idx0, lon_idx0]
        Q01 = grid_data[lat_idx0, lon_idx1]
        Q10 = grid_data[lat_idx1, lon_idx0]
        Q11 = grid_data[lat_idx1, lon_idx1]
        
        # 四个角点坐标
        y0 = lat_grid[lat_idx0]
        y1 = lat_grid[lat_idx1]
        x0 = lon_grid[lon_idx0]
        x1 = lon_grid[lon_idx1]
        
        # 如果有NaN，使用最接近的点
        if np.isnan(Q00) or np.isnan(Q01) or np.isnan(Q10) or np.isnan(Q11):
            result = grid_data[lat_idx, lon_idx]
            details = {
                'type': 'nearest',
                'row': lat_idx,
                'col': lon_idx,
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [grid_data[lat_idx, lon_idx]],
                'lat_values': [lat_grid[lat_idx]],
                'lon_values': [lon_grid[lon_idx]]
            }
        else:
            # 双线性插值公式
            dx = (target_lon - x0) / (x1 - x0) if (x1 - x0) != 0 else 0
            dy = (target_lat - y0) / (y1 - y0) if (y1 - y0) != 0 else 0
            result = (1 - dx) * (1 - dy) * Q00 + dx * (1 - dy) * Q01 + (1 - dx) * dy * Q10 + dx * dy * Q11
            
            details = {
                'type': 'bilinear',
                'rows': [lat_idx0, lat_idx0, lat_idx1, lat_idx1],
                'cols': [lon_idx0, lon_idx1, lon_idx0, lon_idx1],
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [Q00, Q01, Q10, Q11],
                'lat_values': [y0, y0, y1, y1],
                'lon_values': [x0, x1, x0, x1]
            }
        
        # 保存插值详细信息
        interpolation_details.append(details)
        return result
    
    except Exception as e:
        print(f"插值错误: {str(e)}")
        return np.nan

# ====================== 主处理函数 ======================
def process_smapvex_data(input_file_path):
    """
    处理SMAPVEX数据，执行多种插值操作
    """
    global interpolation_details
    
    try:
        interpolation_details = []  # 重置插值详情
        
        # ========== 1. 读取原始数据 ==========
        print(f"读取原始Excel文件: {input_file_path}")
        df = pd.read_excel(input_file_path)
        
        # 定义标准经纬度网格 (0.1°分辨率)
        # 纬度: 北纬89.95°(0) -> 南纬-89.95°(1799)
        lat_grid = np.linspace(89.95, -89.95, 1800)
        
        # 经度: -179.95°(0) -> 179.95°(3599)
        lon_grid = np.linspace(-179.95, 179.95, 3600)
        
        print(f"成功读取 {len(df)} 条记录")
        
        # ========== 2. 准备PFT数据 (14个类别) ==========
        pft_file = r"E:\data\ESACCI PFT\Resample\Data\2008.mat"
        if os.path.exists(pft_file):
            print(f"\n处理PFT数据: {pft_file}")
            mat_data = read_hdf5_mat(pft_file)
            
            pft_columns = ['water','bare','snowice','built','grassnat','grassman',
                          'shrubbd','shrubbe','shrubnd','shrubne',
                          'treebd','treebe','treend','treene']
            
            available_pft = [col for col in pft_columns if col in mat_data]
            print(f"  文件中可用的PFT变量: {', '.join(available_pft)}")
            
            # 处理每个可用的PFT类别
            for col in available_pft:
                grid_data = mat_data[col] / 100.0
                df[f'PFT_{col}'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        grid_data
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加列: PFT_{col}")
        else:
            print(f"\n警告: PFT文件不存在 - {pft_file}")
        
        # ========== 3. 处理VOD数据 (7个变量) ==========
        vod_base_dir = r"E:\data\VOD\mat\kuxcVOD\ASC"
        vod_cols = ['SM','ku_vod_H', 'ku_vod_V', 'x_vod_H','x_vod_V', 'c_vod_H','c_vod_V']
        
        for col in vod_cols:
            df[col] = np.nan
        
        print("\n处理VOD数据:")
        
        # 收集所有唯一日期并排序
        unique_dates = sorted(df['Date'].unique())
        vod_files_found = 0
        
        for date in unique_dates:
            # 转换为字符串格式YYYYMMDD
            try:
                if isinstance(date, pd.Timestamp):
                    date_str = date.strftime("%Y%m%d")
                else:
                    date_str = datetime.strptime(str(date)[:10], "%Y-%m-%d").strftime("%Y%m%d")
            except:
                print(f"  无法解析日期: {date}")
                continue
            
            vod_file = os.path.join(vod_base_dir, f"MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")
            if os.path.exists(vod_file):
                vod_files_found += 1
                print(f"  处理日期: {date_str}, 文件: {os.path.basename(vod_file)}")
                vod_data = read_hdf5_mat(vod_file)
                
                for col in vod_cols:
                    if col in vod_data:
                        grid_data = vod_data[col]
                        mask = df['Date'] == date
                        df.loc[mask, col] = df[mask].apply(
                            lambda row: bilinear_interpolation_with_details(
                                lat_grid, lon_grid, 
                                row['Latitude'], row['Longitude'], 
                                grid_data
                            ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                            else np.nan, axis=1
                        )
                        print(f"    已更新: {col}")
                    else:
                        print(f"    警告: VOD变量 {col} 不存在于文件中")
            else:
                print(f"  警告: VOD文件不存在 - {os.path.basename(vod_file)}")
                
        if vod_files_found == 0:
            print("  警告: 没有找到任何VOD文件，VOD列将保留为空")
        
        # ========== 4. 处理LAI卫星数据 (时间插值) ==========
        print("\n处理LAI卫星数据...")
        df['LAI_Satellite'] = np.nan
        
        # 预期可能的LAI变量名
        expected_lai_keys = ['lai', 'LAI', 'data']
        
        lai_sep_file = r"E:\data\GLASS LAI\mat\0.1Deg\Dataset\2008-09-01.tif.mat"
        lai_oct_file = r"E:\data\GLASS LAI\mat\0.1Deg\Dataset\2008-10-01.tif.mat"
        
        if os.path.exists(lai_sep_file) and os.path.exists(lai_oct_file):
            print(f"  加载9月LAI数据 (视为9月15日): {lai_sep_file}")
            sep_data = read_hdf5_mat(lai_sep_file, expected_keys=expected_lai_keys)
            
            # 直接获取LAI数据
            lai_sep_data = list(sep_data.values())[0] if sep_data else np.zeros((1800, 3600))
            print(f"    读取的数据形状: {lai_sep_data.shape}")
            
            print(f"  加载10月LAI数据 (视为10月15日): {lai_oct_file}")
            oct_data = read_hdf5_mat(lai_oct_file, expected_keys=expected_lai_keys)
            lai_oct_data = list(oct_data.values())[0] if oct_data else np.zeros((1800, 3600))
            print(f"    读取的数据形状: {lai_oct_data.shape}")
            
            # 定义月中日期
            sep_mid_date = datetime(2008, 9, 15)  # 9月15日
            oct_mid_date = datetime(2008, 10, 15)  # 10月15日
            
            total_days = (oct_mid_date - sep_mid_date).days
            
            # 处理每个日期的数据
            for date in unique_dates:
                try:
                    # 确保日期为datetime对象
                    if isinstance(date, pd.Timestamp):
                        date_dt = date.to_pydatetime()
                    elif isinstance(date, str):
                        date_dt = datetime.strptime(date, "%Y-%m-%d")
                    else:
                        date_dt = date
                    
                    # 计算时间权重
                    if date_dt <= sep_mid_date:
                        weight = 0.0
                    elif date_dt >= oct_mid_date:
                        weight = 1.0
                    else:
                        weight = (date_dt - sep_mid_date).days / total_days
                    
                    # 应用时间插值
                    interpolated_lai = (1 - weight) * lai_sep_data + weight * lai_oct_data
                    
                    mask = df['Date'] == date
                    df.loc[mask, 'LAI_Satellite'] = df[mask].apply(
                        lambda row: bilinear_interpolation_with_details(
                            lat_grid, lon_grid, 
                            row['Latitude'], row['Longitude'], 
                            interpolated_lai
                        ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                        else np.nan, axis=1
                    )
                    
                    # 验证插值结果
                    mean_lai = np.nanmean(interpolated_lai)
                    print(f"  已处理日期: {date_dt.strftime('%Y-%m-%d')}, 权重: {weight:.2f}, 平均LAI: {mean_lai:.4f}")
                except Exception as e:
                    print(f"  处理日期{date}时出错: {str(e)}")
        else:
            print(f"  警告: LAI文件不存在 - {' 或 '.join([lai_sep_file, lai_oct_file])}")
            print("  将使用固定值0作为LAI卫星数据")
            df['LAI_Satellite'] = 0.0
        
        # ========== 5. 处理植被高度数据 ==========
        print("\n处理植被高度数据...")
        df['Hveg'] = np.nan
        
        hveg_file = r"E:\data\CanopyHeight\CH.mat"
        if os.path.exists(hveg_file):
            print(f"  加载植被高度数据: {hveg_file}")
            hveg_data = read_hdf5_mat(hveg_file, expected_keys=['CH', 'ch'])
            
            # 直接获取高度数据
            hveg_key = list(hveg_data.keys())[0] if hveg_data else None
            
            if hveg_key:
                hveg_values = hveg_data[hveg_key]
                df['Hveg'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        hveg_values
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加植被高度列，数据形状: {hveg_values.shape}")
            else:
                print(f"  警告: 无法找到Hveg变量")
                df['Hveg'] = np.nan
        else:
            print(f"  警告: Hveg文件不存在 - {hveg_file}")
            df['Hveg'] = np.nan
        
        # ========== 6. 保存结果 ==========
        output_file_path = r"E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_ML.xlsx"
        print(f"\n保存结果到: {output_file_path}")
        df.to_excel(output_file_path, index=False)
        
        # 保存插值详细信息到Excel
        if interpolation_details:
            details_df = pd.DataFrame(interpolation_details)
            details_path = r"E:\data\VWC\test-VWC\Insitu SMEX08\interpolation_details.xlsx"
            details_df.to_excel(details_path, index=False)
            print(f"插值详细信息保存到: {details_path}")
        else:
            print("警告: 没有插值详细信息可保存")
        
        # ========== 7. 统计报告 ==========
        print("\n处理完成!")
        print(f"总记录数: {len(df)}")
        print(f"插值操作次数: {len(interpolation_details)}")
        
        if interpolation_details:
            # 显示前3次插值的详细信息
            print("\n前3次插值的详细信息:")
            for i, detail in enumerate(interpolation_details[:3]):
                print(f"\n插值 #{i+1}")
                print(f"  类型: {detail['type']}")
                print(f"  目标位置: ({detail['target_lat']:.6f}, {detail['target_lon']:.6f})")
                print(f"  网格形状: {detail['grid_shape']}")
                
                if detail['type'] == 'bilinear':
                    print(f"  使用的4个网格点:")
                    for j in range(4):
                        print(f"    点{j+1}: 行 {detail['rows'][j]}, 列 {detail['cols'][j]} - " +
                              f"位置 ({detail['lat_values'][j]:.6f}, {detail['lon_values'][j]:.6f}) - " +
                              f"值: {detail['values'][j]:.6f}")
                else:
                    print(f"  最近邻点: 行 {detail['row']}, 列 {detail['col']} - " +
                          f"位置 ({detail['lat_values'][0]:.6f}, {detail['lon_values'][0]:.6f}) - " +
                          f"值: {detail['values'][0]:.6f}")
        
        return True
        
    except Exception as e:
        print(f"处理过程中出错: {str(e)}")
        import traceback
        print("错误详细信息:")
        print(traceback.format_exc())
        return False

# ========================== 主程序 ==========================
if __name__ == "__main__":
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_Sum_VEG_SMAPVEX.xlsx"
    
    print("="*60)
    print("开始处理SMAPVEX数据插值任务")
    print("="*60)
    
    if not os.path.exists(input_file):
        print(f"错误: 输入文件不存在 - {input_file}")
        print(f"请检查路径: {os.path.abspath(input_file)}")
    else:
        print(f"输入文件: {input_file}")
        print(f"输出将保存到: E:\\data\\VWC\\test-VWC\\Insitu SMEX08\\processed_SV08V_ML.xlsx")
        
        success = process_smapvex_data(input_file)
        if success:
            print("\n" + "="*30)
            print("任务成功完成!")
            print("="*30)
        else:
            print("\n" + "="*30)
            print("任务失败，请检查错误信息")
            print("="*30)

开始处理SMAPVEX数据插值任务
输入文件: E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_Sum_VEG_SMAPVEX.xlsx
输出将保存到: E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_ML.xlsx
读取原始Excel文件: E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_Sum_VEG_SMAPVEX.xlsx
成功读取 10 条记录

处理PFT数据: E:\data\ESACCI PFT\Resample\Data\2008.mat
  文件中可用的PFT变量: water, bare, snowice, built, grassnat, grassman, shrubbd, shrubbe, shrubnd, shrubne, treebd, treebe, treend, treene
  已添加列: PFT_water
  已添加列: PFT_bare
  已添加列: PFT_snowice
  已添加列: PFT_built
  已添加列: PFT_grassnat
  已添加列: PFT_grassman
  已添加列: PFT_shrubbd
  已添加列: PFT_shrubbe
  已添加列: PFT_shrubnd
  已添加列: PFT_shrubne
  已添加列: PFT_treebd
  已添加列: PFT_treebe
  已添加列: PFT_treend
  已添加列: PFT_treene

处理VOD数据:
  处理日期: 20081002, 文件: MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_20081002_V0.nc4.mat
    已更新: SM
    已更新: ku_vod_H
    已更新: ku_vod_V
    已更新: x_vod_H
    已更新: x_vod_V
    已更新: c_vod_H
    已更新: c_vod_V
  处理日期: 20081004, 文件: MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_20081004_V0.nc4.mat
    已更新: S

In [21]:
# 机器学习结果填充，和实测值对比
import pandas as pd
import numpy as np
import os
import joblib
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics import mean_squared_error, r2_score

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

# 设置常量
TEST_FILE = r"E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_ML.xlsx"
MODEL_DIR = "models"
SAVE_RESULTS = "model_predictions_results_SMEX08.xlsx"
FIG_DIR = "figures"

# 定义波段和极化组合
BANDS = ['Ku', 'C', 'X']
POLS = ['H', 'V', 'HV']

# 波段颜色定义
BAND_COLORS = {
    'Ku': (253/255, 173/255, 115/255, 0.7),
    'C': (178/255, 125/255, 104/255, 0.7),
    'X': (224/255, 104/255, 46/255, 0.7)
}

# 极化类型标记定义
POL_MARKERS = {
    'H': 's',  # 方形
    'V': '^',  # 三角形
    'HV': 'o'  # 圆形
}

def normalize_LAI(lai_series):
    """对LAI进行归一化"""
    return lai_series.clip(0, 6) / 6

def normalize_VOD(vod_series):
    """对VOD进行归一化"""
    return vod_series.clip(0, 2) / 2

# PFT列名映射字典
PFT_MAPPING = {
    'PFT_grassnat': 'Grass_nat',
    'PFT_grassman': 'Grass_man',
    'PFT_shrubbd': 'Shrub_bd',
    'PFT_shrubbe': 'Shrub_be',
    'PFT_shrubnd': 'Shrub_nd',
    'PFT_shrubne': 'Shrub_ne',
    'PFT_treebd': 'Tree_bd',
    'PFT_treebe': 'Tree_be',
    'PFT_treend': 'Tree_nd',
    'PFT_treene': 'Tree_ne'
}

def get_model_columns(band, pol):
    """获取指定模型所需的列名"""
    base_columns = [
        'VWC',  # 实际值
        'LAI',  # LAI
        'SM'  # 土壤湿度
    ]
    
    # 添加所有PFT列
    base_columns.extend(PFT_MAPPING.keys())
    
    # 根据极化类型添加VOD列
    if pol == 'H':
        return base_columns + [f'{band.lower()}_vod_H']
    elif pol == 'V':
        return base_columns + [f'{band.lower()}_vod_V']
    elif pol == 'HV':
        return base_columns + [f'{band.lower()}_vod_H', f'{band.lower()}_vod_V']

def get_feature_order(pol):
    """获取特征列的顺序（模型期望的列顺序）"""
    base_features = [
        'LAI', 'SM',
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    if pol in ['H', 'V']:
        return ['VOD'] + base_features
    elif pol == 'HV':
        return ['VOD-Hpol', 'VOD-Vpol'] + base_features

def prepare_input_data(df, band, pol):
    """为指定模型准备输入数据"""
    # 创建数据副本
    data = df.copy()
    
    # 1. 归一化处理
    data['LAI'] = normalize_LAI(data['LAI'])
    
    # 2. 处理VOD列
    if pol == 'H':
        vod_col = f'{band.lower()}_vod_H'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'V':
        vod_col = f'{band.lower()}_vod_V'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'HV':
        # 重命名列以匹配模型训练时的特征名
        data = data.rename(columns={
            f'{band.lower()}_vod_H': 'VOD-Hpol',
            f'{band.lower()}_vod_V': 'VOD-Vpol'
        })
        # 归一化处理
        data['VOD-Hpol'] = normalize_VOD(data['VOD-Hpol'])
        data['VOD-Vpol'] = normalize_VOD(data['VOD-Vpol'])
    
    # 3. 重命名PFT列为模型期望的名称
    data = data.rename(columns=PFT_MAPPING)
    
    # 4. 按模型要求排序特征列
    feature_order = get_feature_order(pol)
    
    return data[feature_order]

def plot_combined_scatter(actual, predictions_dict):
    """
    绘制组合散点图，包含所有波段和极化类型
    
    参数:
    actual -- 实际值 (Series)
    predictions_dict -- 字典结构: {
        'H': {band: pred_series},
        'V': {band: pred_series},
        'HV': {band: pred_series}
    }
    """
    # 创建图形
    plt.figure(figsize=(10, 10))
    ax = plt.gca()
    
    # 存储所有组合的RMSE值
    rmse_values = {}
    
    # 收集所有数据点
    max_val = 0
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            pred_series = predictions_dict[pol].get(band)
            
            if pred_series is not None and not pred_series.isnull().all():
                # 创建实际值和预测值的临时DF
                temp_df = pd.DataFrame({
                    'actual': actual,
                    'pred': pred_series
                }).dropna()
                
                if not temp_df.empty:
                    # 计算RMSE
                    rmse = np.sqrt(mean_squared_error(temp_df['actual'], temp_df['pred']))
                    rmse_values[f"{band}-{pol}"] = rmse
                    
                    # 更新最大值
                    band_max = max(temp_df['actual'].max(), temp_df['pred'].max())
                    if band_max > max_val:
                        max_val = band_max
                    
                    # 绘制散点
                    plt.scatter(
                        temp_df['actual'], temp_df['pred'], 
                        alpha=0.7, 
                        color=BAND_COLORS[band],
                        marker=POL_MARKERS[pol],
                        s=50,
                        edgecolors='none',
                        zorder=2,
                        label=f"{band}-{pol}"
                    )
    
    # 如果没有数据可绘制，直接返回
    if not rmse_values:
        print("  警告: 没有有效的预测数据!")
        plt.close()
        return
    
    # 添加1:1参考线
    max_val *= 1.05
    plt.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7, zorder=1)
    
    # 设置坐标轴范围
    plt.xlim(0, max_val)
    plt.ylim(0, max_val)
    
    # 设置坐标轴标签
    plt.xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
    plt.ylabel('RF VWC (kg/m²)', fontsize=14, fontweight='bold')
    
    # 设置标题
    plt.title('SMEX08 Insitu VWC', 
             fontsize=18, fontweight='bold', pad=20)
    
    # # 添加图例
    # plt.legend(loc='lower right', frameon=True, fontsize=10, ncol=3)
    
    # 添加RMSE文本（左上角，3×3网格布局）
    if rmse_values:
        # 设置文本位置
        x_pos = 0.05
        y_pos = 0.95
        
        # 添加标题
        plt.text(x_pos, y_pos, 'RMSE (kg/m²):', 
                 transform=ax.transAxes,
                 fontsize=12,
                 fontweight='bold',
                 verticalalignment='top')
        
        y_pos -= 0.05
        
        # 遍历每个波段
        for band_idx, band in enumerate(BANDS):
            # 遍历每个极化类型
            for pol_idx, pol in enumerate(POLS):
                # 计算位置
                text_x = x_pos + pol_idx * 0.15
                text_y = y_pos - band_idx * 0.08
                
                # 获取RMSE值
                rmse = rmse_values.get(f"{band}-{pol}", None)
                
                if rmse is not None:
                    # 绘制标记
                    plt.scatter(
                        text_x, text_y, 
                        transform=ax.transAxes,
                        marker=POL_MARKERS[pol],
                        color=BAND_COLORS[band],
                        s=80,
                        alpha=0.7
                    )
                    
                    # 添加文本
                    plt.text(
                        text_x + 0.01, text_y, 
                        f"{band}-{pol}: {rmse:.3f}", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center'
                    )
                else:
                    # 添加缺失值标记
                    plt.text(
                        text_x, text_y, 
                        f"{band}-{pol}: N/A", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center',
                        color='gray'
                    )
    
    # 添加网格线
    plt.grid(True, linestyle='--', alpha=0.3, zorder=0)
    
    # 调整布局
    plt.tight_layout()
    
    # 创建保存目录
    os.makedirs(FIG_DIR, exist_ok=True)
    
    # 保存图像
    fig_path = os.path.join(FIG_DIR, 'SMEX08_VWC_Scatter.png')
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"  组合散点图已保存至: {fig_path}")
    plt.close()

def predict_and_evaluate():
    """主函数：加载所有模型进行预测并评估结果"""
    # 1. 加载测试数据
    print(f"正在加载测试数据: {TEST_FILE}")
    
    # 收集所有可能的列
    all_columns = set(['VWC', 'LAI', 'SM'])
    # 添加所有PFT列
    all_columns.update(PFT_MAPPING.keys())
    # 添加所有VOD列
    for band in BANDS:
        all_columns.add(f'{band.lower()}_vod_H')
        all_columns.add(f'{band.lower()}_vod_V')
    
    # 读取Excel文件
    test_df = pd.read_excel(TEST_FILE, usecols=list(all_columns))
    print(f"加载完成，总样本数: {len(test_df)}")
    
    # 存储所有预测结果
    results = pd.DataFrame(index=test_df.index)
    results['Actual_VWC'] = test_df['VWC']
    
    # 为每个极化类型存储预测结果
    predictions_by_pol = {
        'H': {band: None for band in BANDS},
        'V': {band: None for band in BANDS},
        'HV': {band: None for band in BANDS}
    }
    
    # 2. 对每个模型进行预测
    for band in BANDS:
        for pol in POLS:
            model_name = f"RFR_{band}_{pol}pol_Type1.pkl"
            model_path = os.path.join(MODEL_DIR, model_name)
            
            print(f"\n处理 {band}-{pol} 模型: {model_name}")
            
            # 准备输入数据
            model_cols = get_model_columns(band, pol)
            model_data = test_df[model_cols].copy()
            
            # 删除缺失值
            clean_data = model_data.dropna()
            print(f"  有效样本数: {len(clean_data)} (删除缺失值后)")
            
            if len(clean_data) == 0:
                print("  警告: 无有效样本可用于此模型!")
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
                continue
            
            # 预处理输入数据
            try:
                X_input = prepare_input_data(clean_data, band, pol)
                
                # 加载模型并进行预测
                if os.path.exists(model_path):
                    model = joblib.load(model_path)
                    predictions = model.predict(X_input)
                    
                    # 存储预测结果
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    results.loc[clean_data.index, f"{band}_{pol}_Predicted"] = predictions
                    
                    # 存储到对应极化类型的字典
                    predictions_by_pol[pol][band] = results[f"{band}_{pol}_Predicted"].copy()
                    
                    # 计算评估指标
                    actual = clean_data['VWC']
                    rmse = np.sqrt(mean_squared_error(actual, predictions))
                    r2 = r2_score(actual, predictions)
                    print(f"  预测完成 - RMSE: {rmse:.4f}, R²: {r2:.4f}")
                    
                else:
                    print(f"  警告: 未找到模型文件 {model_path}!")
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    predictions_by_pol[pol][band] = None
            except Exception as e:
                import traceback
                print(f"  预测失败: {str(e)}")
                # 打印更详细的错误信息
                traceback.print_exc()
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
    
    # 3. 保存结果
    results.to_excel(SAVE_RESULTS)
    print(f"\n所有预测结果已保存至: {SAVE_RESULTS}")
    
    # 4. 绘制组合散点图
    print("\n正在绘制组合散点图...")
    plot_combined_scatter(
        results['Actual_VWC'], 
        predictions_by_pol
    )
    
    return results, predictions_by_pol

# 执行主函数
if __name__ == "__main__":
    results, predictions_by_pol = predict_and_evaluate()
    print("\n所有处理完成!")

正在加载测试数据: E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_ML.xlsx
加载完成，总样本数: 10

处理 Ku-H 模型: RFR_Ku_Hpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.3100, R²: -4.9195

处理 Ku-V 模型: RFR_Ku_Vpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.9155, R²: -7.2834

处理 Ku-HV 模型: RFR_Ku_HVpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.3757, R²: -5.1570

处理 C-H 模型: RFR_C_Hpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.8144, R²: -6.8610

处理 C-V 模型: RFR_C_Vpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.9438, R²: -7.4035

处理 C-HV 模型: RFR_C_HVpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 4.0439, R²: -7.8357

处理 X-H 模型: RFR_X_Hpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.3705, R²: -5.1378

处理 X-V 模型: RFR_X_Vpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.6296, R²: -6.1180

处理 X-HV 模型: RFR_X_HVpol_Type1.pkl
  有效样本数: 10 (删除缺失值后)
  预测完成 - RMSE: 3.5702, R²: -5.8867

所有预测结果已保存至: model_predictions_results_SMEX08.xlsx

正在绘制组合散点图...
  组合散点图已保存至: figures\SMEX08_VWC_Scatter

# 4.NSIDC（SMEX02）

In [29]:
# CLASIC07 050（没有保留的数据行，不使用）
import pandas as pd
import numpy as np
from pyproj import Transformer

# 文件路径
file_path = r"E:\data\VWC\test-VWC\NSIDC_0666\NSIDC0666_matchup_pals_grid_v107_111012_CLASIC07_050.txt"

# 读取数据文件
df = pd.read_csv(file_path, sep='\s+', header=0, na_values=['NaN'])

# 创建日期列 (YYYYMMDD格式)
df['Date'] = df.apply(
    lambda row: f"{int(row['Year'])}{int(row['Month']):02d}{int(row['Day']):02d}", 
    axis=1
)

# 创建UTM到WGS84转换器 (UTM zone 14N - EPSG:32614)
transformer = Transformer.from_crs("EPSG:32614", "EPSG:4326", always_xy=True)

# 坐标转换函数
def convert_utm_to_wgs84(easting, northing):
    try:
        lon, lat = transformer.transform(easting, northing)
        return pd.Series([lon, lat])
    except Exception as e:
        return pd.Series([np.nan, np.nan])

# 应用坐标转换
df[['Longitude', 'Latitude']] = df.apply(
    lambda row: convert_utm_to_wgs84(row['UTM-E'], row['UTM-N']), 
    axis=1
)

print(f"移除前数据行数: {len(df)}")
df = df.dropna(subset=['VWC-Field'])
print(f"移除后数据行数: {len(df)}")

# 打印包含日期列的前5行数据
print("包含日期列的前5行数据：")
print(df[['Date', 'Year', 'Month', 'Day', 'UTM-E', 'UTM-N', 'Longitude', 'Latitude']].head(5).to_string(index=False))
print("\n" + "="*80 + "\n")

# 打印统计信息
print(f"文件总行数: {len(df)}")
print(f"日期范围: {df['Date'].min()} 到 {df['Date'].max()}")
print(f"经纬度范围:")
print(f"  经度: {df['Longitude'].min():.6f}° 到 {df['Longitude'].max():.6f}°")
print(f"  纬度: {df['Latitude'].min():.6f}° 到 {df['Latitude'].max():.6f}°")

移除前数据行数: 1400
移除后数据行数: 0
包含日期列的前5行数据：
Empty DataFrame
Columns: [Date, Year, Month, Day, UTM-E, UTM-N, Longitude, Latitude]
Index: []


文件总行数: 0
日期范围: nan 到 nan
经纬度范围:
  经度: nan° 到 nan°
  纬度: nan° 到 nan°


In [30]:
# CLASIC07 060（没有保留的数据行，不使用）
import pandas as pd
import numpy as np
from pyproj import Transformer

# 文件路径
file_path = r"E:\data\VWC\test-VWC\NSIDC_0666\NSIDC0666_matchup_pals_grid_v107_111012_CLASIC07_060.txt"

# 读取数据文件
df = pd.read_csv(file_path, sep='\s+', header=0, na_values=['NaN'])

# 创建日期列 (YYYYMMDD格式)
df['Date'] = df.apply(
    lambda row: f"{int(row['Year'])}{int(row['Month']):02d}{int(row['Day']):02d}", 
    axis=1
)

# 创建UTM到WGS84转换器 (UTM zone 14N - EPSG:32614)
transformer = Transformer.from_crs("EPSG:32614", "EPSG:4326", always_xy=True)

# 坐标转换函数
def convert_utm_to_wgs84(easting, northing):
    try:
        lon, lat = transformer.transform(easting, northing)
        return pd.Series([lon, lat])
    except Exception as e:
        return pd.Series([np.nan, np.nan])

# 应用坐标转换
df[['Longitude', 'Latitude']] = df.apply(
    lambda row: convert_utm_to_wgs84(row['UTM-E'], row['UTM-N']), 
    axis=1
)

print(f"移除前数据行数: {len(df)}")
df = df.dropna(subset=['VWC-Field'])
print(f"移除后数据行数: {len(df)}")

# 打印包含日期列的前5行数据
print("包含日期列的前5行数据：")
print(df[['Date', 'Year', 'Month', 'Day', 'UTM-E', 'UTM-N', 'Longitude', 'Latitude']].head(5).to_string(index=False))
print("\n" + "="*80 + "\n")

# 打印统计信息
print(f"文件总行数: {len(df)}")
print(f"日期范围: {df['Date'].min()} 到 {df['Date'].max()}")
print(f"经纬度范围:")
print(f"  经度: {df['Longitude'].min():.6f}° 到 {df['Longitude'].max():.6f}°")
print(f"  纬度: {df['Latitude'].min():.6f}° 到 {df['Latitude'].max():.6f}°")

移除前数据行数: 4032
移除后数据行数: 0
包含日期列的前5行数据：
Empty DataFrame
Columns: [Date, Year, Month, Day, UTM-E, UTM-N, Longitude, Latitude]
Index: []


文件总行数: 0
日期范围: nan 到 nan
经纬度范围:
  经度: nan° 到 nan°
  纬度: nan° 到 nan°


In [31]:
# SMAPVEX08（保留的数据行很少，不使用）
import pandas as pd
import numpy as np
from pyproj import Transformer

# 文件路径
file_path = r"E:\data\VWC\test-VWC\NSIDC_0666\NSIDC0666_matchup_pals_grid_v107_111012_SMAPVEX08.txt"

# 读取数据文件
df = pd.read_csv(file_path, sep='\s+', header=0, na_values=['NaN'])

# 创建日期列 (YYYYMMDD格式)
df['Date'] = df.apply(
    lambda row: f"{int(row['Year'])}{int(row['Month']):02d}{int(row['Day']):02d}", 
    axis=1
)

# 创建UTM到WGS84转换器 (UTM zone 18N - EPSG:32618)
transformer = Transformer.from_crs("EPSG:32618", "EPSG:4326", always_xy=True)

# 坐标转换函数
def convert_utm_to_wgs84(easting, northing):
    try:
        lon, lat = transformer.transform(easting, northing)
        return pd.Series([lon, lat])
    except Exception as e:
        return pd.Series([np.nan, np.nan])

# 应用坐标转换
df[['Longitude', 'Latitude']] = df.apply(
    lambda row: convert_utm_to_wgs84(row['UTM-E'], row['UTM-N']), 
    axis=1
)

print(f"移除前数据行数: {len(df)}")
df = df.dropna(subset=['VWC-Field'])
print(f"移除后数据行数: {len(df)}")

# 打印包含日期列的前5行数据
print("包含日期列的前5行数据：")
print(df[['Date', 'Year', 'Month', 'Day', 'UTM-E', 'UTM-N', 'Longitude', 'Latitude']].head(5).to_string(index=False))
print("\n" + "="*80 + "\n")

# 打印统计信息
print(f"文件总行数: {len(df)}")
print(f"日期范围: {df['Date'].min()} 到 {df['Date'].max()}")
print(f"经纬度范围:")
print(f"  经度: {df['Longitude'].min():.6f}° 到 {df['Longitude'].max():.6f}°")
print(f"  纬度: {df['Latitude'].min():.6f}° 到 {df['Latitude'].max():.6f}°")

移除前数据行数: 9800
移除后数据行数: 5
包含日期列的前5行数据：
    Date  Year  Month  Day    UTM-E     UTM-N  Longitude  Latitude
20081002  2008     10    2 413099.2 4319526.1 -76.003847 39.020458
20081002  2008     10    2 418699.2 4316326.1 -75.938785 38.992164
20081004  2008     10    4 416299.2 4315526.1 -75.966397 38.984729
20081006  2008     10    6 412299.2 4317926.1 -76.012882 39.005962
20081006  2008     10    6 414699.2 4317926.1 -75.985167 39.006199


文件总行数: 5
日期范围: 20081002 到 20081006
经纬度范围:
  经度: -76.012882° 到 -75.938785°
  纬度: 38.984729° 到 39.020458°


In [4]:
# SMEX02
import pandas as pd
import numpy as np
from pyproj import Transformer
import os

# 文件路径
file_path = r"E:\data\VWC\test-VWC\NSIDC_0666\NSIDC0666_matchup_pals_grid_v107_111012_SMEX02.txt"

# 读取数据文件
df = pd.read_csv(file_path, sep='\s+', header=0, na_values=['NaN'])

# 创建日期列 (YYYYMMDD格式)
df['Date'] = pd.to_datetime(df.apply(
    lambda row: f"{int(row['Year'])}{int(row['Month']):02d}{int(row['Day']):02d}", 
    axis=1
), format='%Y%m%d')

# 创建UTM到WGS84转换器 (UTM zone 15N - EPSG:32615)
transformer = Transformer.from_crs("EPSG:32615", "EPSG:4326", always_xy=True)

# 坐标转换函数
def convert_utm_to_wgs84(easting, northing):
    try:
        lon, lat = transformer.transform(easting, northing)
        return pd.Series([lon, lat])
    except Exception as e:
        return pd.Series([np.nan, np.nan])

# 应用坐标转换
df[['Longitude', 'Latitude']] = df.apply(
    lambda row: convert_utm_to_wgs84(row['UTM-E'], row['UTM-N']), 
    axis=1
)

print(f"移除前数据行数: {len(df)}")
df = df.dropna(subset=['VWC-Field'])
print(f"移除后数据行数: {len(df)}")

# 打印包含日期列的前5行数据
print("包含日期列的前5行数据：")
print(df[['Date', 'Year', 'Month', 'Day', 'UTM-E', 'UTM-N', 'Longitude', 'Latitude']].head(5).to_string(index=False))
print("\n" + "="*80 + "\n")

# 打印统计信息
print(f"文件总行数: {len(df)}")
print(f"日期范围: {df['Date'].min().strftime('%Y-%m-%d')} 到 {df['Date'].max().strftime('%Y-%m-%d')}")
print(f"经纬度范围:")
print(f"  经度: {df['Longitude'].min():.6f}° 到 {df['Longitude'].max():.6f}°")
print(f"  纬度: {df['Latitude'].min():.6f}° 到 {df['Latitude'].max():.6f}°")

# 保存为xlsx文件
output_path = r"E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V.xlsx"

# 创建输出目录（如果不存在）
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# 保存DataFrame到Excel文件
df.to_excel(
    output_path, 
    index=False,
    sheet_name='SMEX02_VWC',
    engine='openpyxl'
)

print(f"处理后的数据已保存到: {output_path}")

移除前数据行数: 3440
移除后数据行数: 104
包含日期列的前5行数据：
      Date  Year  Month  Day    UTM-E     UTM-N  Longitude  Latitude
2002-06-25  2002      6   25 437375.2 4642453.2 -93.755366 41.931559
2002-06-25  2002      6   25 437375.2 4643253.2 -93.755451 41.938764
2002-06-25  2002      6   25 437375.2 4646453.2 -93.755791 41.967583
2002-06-25  2002      6   25 437375.2 4648053.2 -93.755962 41.981992
2002-06-25  2002      6   25 438975.2 4645653.2 -93.736400 41.960503


文件总行数: 104
日期范围: 2002-06-25 到 2002-07-08
经纬度范围:
  经度: -93.755962° 到 -93.688057°
  纬度: 41.924601° 到 41.982417°
处理后的数据已保存到: E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V.xlsx


In [8]:
# 数据填充以收集自变量
import pandas as pd
import numpy as np
import os
import h5py
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings("ignore")

# 全局变量记录插值详细信息
interpolation_details = []

# ====================== 改进的MAT文件读取函数 ======================
def read_hdf5_mat(file_path, expected_keys=None):
    """读取MATLAB v7.3格式的HDF5文件，优先查找特定变量"""
    try:
        with h5py.File(file_path, 'r') as f:
            data = {}
            
            def visitor_func(name, obj):
                if isinstance(obj, h5py.Dataset):
                    if h5py.check_string_dtype(obj.dtype):
                        value = ''.join(chr(c) for c in obj[:])
                    else:
                        value = np.array(obj)
                    if value.ndim >= 2:
                        value = value.T
                    base_name = name.split('/')[-1]
                    data[base_name] = value
            
            f.visititems(visitor_func)
            
            # 优先查找预期变量
            if expected_keys:
                for key in expected_keys:
                    if key in data:
                        return {key: data[key]}
            
            return data
    except Exception as e:
        print(f"  读取HDF5 MAT文件失败: {str(e)}")
        return {}

# ====================== 改进的双线性插值函数 ======================
def bilinear_interpolation_with_details(lat_grid, lon_grid, target_lat, target_lon, grid_data):
    """
    执行双线性插值并记录详细信息
    :param lat_grid: 网格纬度数组 (1D, 从北向南递减)
    :param lon_grid: 网格经度数组 (1D, 从西向东递增)
    :param target_lat: 目标点纬度
    :param target_lon: 目标点经度
    :param grid_data: 网格数据 (2D数组, 形状为(len(lat_grid), len(lon_grid)))
    :return: 插值值
    """
    global interpolation_details
    
    try:
        # 记录网格形状
        grid_shape = grid_data.shape
        
        # 验证网格尺寸
        if len(lat_grid) != grid_shape[0] or len(lon_grid) != grid_shape[1]:
            print(f"警告: 网格尺寸不匹配! 纬度网格: {len(lat_grid)}, 经度网格: {len(lon_grid)}, 数据形状: {grid_shape}")
            return np.nan
        
        # 查找最近的纬度索引（纬度从北向南递减）
        # 纬度网格: 89.95 (北) -> -89.95 (南)
        lat_idx = np.argmin(np.abs(lat_grid - target_lat))
        
        # 查找最近的经度索引（经度从西向东递增）
        # 经度网格: -179.95 (西) -> 179.95 (东)
        lon_idx = np.argmin(np.abs(lon_grid - target_lon))
        
        # 确定四个角点索引
        # 纬度处理：目标点位于两个纬度网格点之间
        if lat_idx == 0:
            lat_idx0, lat_idx1 = 0, 1
        elif lat_idx == len(lat_grid) - 1:
            lat_idx0, lat_idx1 = len(lat_grid) - 2, len(lat_grid) - 1
        else:
            if target_lat > lat_grid[lat_idx]:
                # 目标纬度大于当前网格点纬度（更北）
                if lat_idx > 0:
                    lat_idx0 = lat_idx - 1
                    lat_idx1 = lat_idx
                else:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx
            else:
                # 目标纬度小于当前网格点纬度（更南）
                if lat_idx < len(lat_grid) - 1:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx + 1
                else:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx
        
        # 经度处理：目标点位于两个经度网格点之间
        if lon_idx == 0:
            lon_idx0, lon_idx1 = 0, 1
        elif lon_idx == len(lon_grid) - 1:
            lon_idx0, lon_idx1 = len(lon_grid) - 2, len(lon_grid) - 1
        else:
            if target_lon > lon_grid[lon_idx]:
                # 目标经度大于当前网格点经度（更东）
                if lon_idx < len(lon_grid) - 1:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx + 1
                else:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx
            else:
                # 目标经度小于当前网格点经度（更西）
                if lon_idx > 0:
                    lon_idx0 = lon_idx - 1
                    lon_idx1 = lon_idx
                else:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx
        
        # 获取四个角点值
        Q00 = grid_data[lat_idx0, lon_idx0]
        Q01 = grid_data[lat_idx0, lon_idx1]
        Q10 = grid_data[lat_idx1, lon_idx0]
        Q11 = grid_data[lat_idx1, lon_idx1]
        
        # 四个角点坐标
        y0 = lat_grid[lat_idx0]
        y1 = lat_grid[lat_idx1]
        x0 = lon_grid[lon_idx0]
        x1 = lon_grid[lon_idx1]
        
        # 如果有NaN，使用最接近的点
        if np.isnan(Q00) or np.isnan(Q01) or np.isnan(Q10) or np.isnan(Q11):
            result = grid_data[lat_idx, lon_idx]
            details = {
                'type': 'nearest',
                'row': lat_idx,
                'col': lon_idx,
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [grid_data[lat_idx, lon_idx]],
                'lat_values': [lat_grid[lat_idx]],
                'lon_values': [lon_grid[lon_idx]]
            }
        else:
            # 双线性插值公式
            dx = (target_lon - x0) / (x1 - x0) if (x1 - x0) != 0 else 0
            dy = (target_lat - y0) / (y1 - y0) if (y1 - y0) != 0 else 0
            result = (1 - dx) * (1 - dy) * Q00 + dx * (1 - dy) * Q01 + (1 - dx) * dy * Q10 + dx * dy * Q11
            
            details = {
                'type': 'bilinear',
                'rows': [lat_idx0, lat_idx0, lat_idx1, lat_idx1],
                'cols': [lon_idx0, lon_idx1, lon_idx0, lon_idx1],
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [Q00, Q01, Q10, Q11],
                'lat_values': [y0, y0, y1, y1],
                'lon_values': [x0, x1, x0, x1]
            }
        
        # 保存插值详细信息
        interpolation_details.append(details)
        return result
    
    except Exception as e:
        print(f"插值错误: {str(e)}")
        return np.nan

# ====================== 主处理函数 ======================
def process_smapvex_data(input_file_path):
    """
    处理SMAPVEX数据，执行多种插值操作
    """
    global interpolation_details
    
    try:
        interpolation_details = []  # 重置插值详情
        
        # ========== 1. 读取原始数据 ==========
        print(f"读取原始Excel文件: {input_file_path}")
        df = pd.read_excel(input_file_path)
        
        # 定义标准经纬度网格 (0.1°分辨率)
        # 纬度: 北纬89.95°(0) -> 南纬-89.95°(1799)
        lat_grid = np.linspace(89.95, -89.95, 1800)
        
        # 经度: -179.95°(0) -> 179.95°(3599)
        lon_grid = np.linspace(-179.95, 179.95, 3600)
        
        print(f"成功读取 {len(df)} 条记录")
        
        # ========== 2. 准备PFT数据 (14个类别) ==========
        pft_file = r"E:\data\ESACCI PFT\Resample\Data\2002.mat"
        if os.path.exists(pft_file):
            print(f"\n处理PFT数据: {pft_file}")
            mat_data = read_hdf5_mat(pft_file)
            
            pft_columns = ['water','bare','snowice','built','grassnat','grassman',
                          'shrubbd','shrubbe','shrubnd','shrubne',
                          'treebd','treebe','treend','treene']
            
            available_pft = [col for col in pft_columns if col in mat_data]
            print(f"  文件中可用的PFT变量: {', '.join(available_pft)}")
            
            # 处理每个可用的PFT类别
            for col in available_pft:
                grid_data = mat_data[col] / 100.0
                df[f'PFT_{col}'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        grid_data
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加列: PFT_{col}")
        else:
            print(f"\n警告: PFT文件不存在 - {pft_file}")
        
        # ========== 3. 处理VOD数据 (7个变量) ==========
        vod_base_dir = r"E:\data\VOD\mat\kuxcVOD\ASC"
        vod_cols = ['sm','ku_vod_H', 'ku_vod_V', 'x_vod_H','x_vod_V', 'c_vod_H','c_vod_V']
        
        for col in vod_cols:
            df[col] = np.nan
        
        print("\n处理VOD数据:")
        
        # 收集所有唯一日期并排序
        unique_dates = sorted(df['Date'].unique())
        vod_files_found = 0
        
        for date in unique_dates:
            # 转换为字符串格式YYYYMMDD
            try:
                if isinstance(date, pd.Timestamp):
                    date_str = date.strftime("%Y%m%d")
                else:
                    date_str = datetime.strptime(str(date)[:10], "%Y-%m-%d").strftime("%Y%m%d")
            except:
                print(f"  无法解析日期: {date}")
                continue
            
            vod_file = os.path.join(vod_base_dir, f"MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")
            if os.path.exists(vod_file):
                vod_files_found += 1
                print(f"  处理日期: {date_str}, 文件: {os.path.basename(vod_file)}")
                vod_data = read_hdf5_mat(vod_file)
                
                for col in vod_cols:
                    vod_var_name = col
                    if col == 'sm':
                        vod_var_name = 'SM'  # 实际文件中的变量名是大写的 SM
                    if vod_var_name in vod_data:
                        grid_data = vod_data[vod_var_name]
                        mask = df['Date'] == date
                        df.loc[mask, col] = df[mask].apply(
                            lambda row: bilinear_interpolation_with_details(
                                lat_grid, lon_grid, 
                                row['Latitude'], row['Longitude'], 
                                grid_data
                            ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                            else np.nan, axis=1
                        )
                        print(f"    已更新: {col}")
                    else:
                        print(f"    警告: VOD变量 {vod_var_name} (映射自 {col}) 不存在于文件中")
            else:
                print(f"  警告: VOD文件不存在 - {os.path.basename(vod_file)}")
                
        if vod_files_found == 0:
            print("  警告: 没有找到任何VOD文件，VOD列将保留为空")
        
        # ========== 4. 处理LAI卫星数据 (时间插值) ==========
        print("\n处理LAI卫星数据...")
        df['LAI_Satellite'] = np.nan
        
        # 预期可能的LAI变量名
        expected_lai_keys = ['lai', 'LAI', 'data']
        
        lai_sep_file = r"E:\data\GLASS LAI\mat\0.1Deg\Dataset\2002-06-01.tif.mat"
        lai_oct_file = r"E:\data\GLASS LAI\mat\0.1Deg\Dataset\2002-07-01.tif.mat"
        
        if os.path.exists(lai_sep_file) and os.path.exists(lai_oct_file):
            print(f"  加载6月LAI数据 (视为6月15日): {lai_sep_file}")
            sep_data = read_hdf5_mat(lai_sep_file, expected_keys=expected_lai_keys)
            
            # 直接获取LAI数据
            lai_sep_data = list(sep_data.values())[0] if sep_data else np.zeros((1800, 3600))
            print(f"    读取的数据形状: {lai_sep_data.shape}")
            
            print(f"  加载7月LAI数据 (视为7月15日): {lai_oct_file}")
            oct_data = read_hdf5_mat(lai_oct_file, expected_keys=expected_lai_keys)
            lai_oct_data = list(oct_data.values())[0] if oct_data else np.zeros((1800, 3600))
            print(f"    读取的数据形状: {lai_oct_data.shape}")
            
            # 定义月中日期
            sep_mid_date = datetime(2002, 6, 15)  # 9月15日
            oct_mid_date = datetime(2002, 7, 15)  # 10月15日
            
            total_days = (oct_mid_date - sep_mid_date).days
            
            # 处理每个日期的数据
            for date in unique_dates:
                try:
                    # 确保日期为datetime对象
                    if isinstance(date, pd.Timestamp):
                        date_dt = date.to_pydatetime()
                    elif isinstance(date, str):
                        date_dt = datetime.strptime(date, "%Y-%m-%d")
                    else:
                        date_dt = date
                    
                    # 计算时间权重
                    if date_dt <= sep_mid_date:
                        weight = 0.0
                    elif date_dt >= oct_mid_date:
                        weight = 1.0
                    else:
                        weight = (date_dt - sep_mid_date).days / total_days
                    
                    # 应用时间插值
                    interpolated_lai = (1 - weight) * lai_sep_data + weight * lai_oct_data
                    
                    mask = df['Date'] == date
                    df.loc[mask, 'LAI_Satellite'] = df[mask].apply(
                        lambda row: bilinear_interpolation_with_details(
                            lat_grid, lon_grid, 
                            row['Latitude'], row['Longitude'], 
                            interpolated_lai
                        ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                        else np.nan, axis=1
                    )
                    
                    # 验证插值结果
                    mean_lai = np.nanmean(interpolated_lai)
                    print(f"  已处理日期: {date_dt.strftime('%Y-%m-%d')}, 权重: {weight:.2f}, 平均LAI: {mean_lai:.4f}")
                except Exception as e:
                    print(f"  处理日期{date}时出错: {str(e)}")
        else:
            print(f"  警告: LAI文件不存在 - {' 或 '.join([lai_sep_file, lai_oct_file])}")
            print("  将使用固定值0作为LAI卫星数据")
            df['LAI_Satellite'] = 0.0
        
        # ========== 5. 处理植被高度数据 ==========
        print("\n处理植被高度数据...")
        df['Hveg'] = np.nan
        
        hveg_file = r"E:\data\CanopyHeight\CH.mat"
        if os.path.exists(hveg_file):
            print(f"  加载植被高度数据: {hveg_file}")
            hveg_data = read_hdf5_mat(hveg_file, expected_keys=['CH', 'ch'])
            
            # 直接获取高度数据
            hveg_key = list(hveg_data.keys())[0] if hveg_data else None
            
            if hveg_key:
                hveg_values = hveg_data[hveg_key]
                df['Hveg'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        hveg_values
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加植被高度列，数据形状: {hveg_values.shape}")
            else:
                print(f"  警告: 无法找到Hveg变量")
                df['Hveg'] = np.nan
        else:
            print(f"  警告: Hveg文件不存在 - {hveg_file}")
            df['Hveg'] = np.nan
        
        # ========== 6. 保存结果 ==========
        output_file_path = r"E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V_ML.xlsx"
        print(f"\n保存结果到: {output_file_path}")
        df.to_excel(output_file_path, index=False)
        
        # 保存插值详细信息到Excel
        if interpolation_details:
            details_df = pd.DataFrame(interpolation_details)
            details_path = r"E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\interpolation_details.xlsx"
            details_df.to_excel(details_path, index=False)
            print(f"插值详细信息保存到: {details_path}")
        else:
            print("警告: 没有插值详细信息可保存")
        
        # ========== 7. 统计报告 ==========
        print("\n处理完成!")
        print(f"总记录数: {len(df)}")
        print(f"插值操作次数: {len(interpolation_details)}")
        
        if interpolation_details:
            # 显示前3次插值的详细信息
            print("\n前3次插值的详细信息:")
            for i, detail in enumerate(interpolation_details[:3]):
                print(f"\n插值 #{i+1}")
                print(f"  类型: {detail['type']}")
                print(f"  目标位置: ({detail['target_lat']:.6f}, {detail['target_lon']:.6f})")
                print(f"  网格形状: {detail['grid_shape']}")
                
                if detail['type'] == 'bilinear':
                    print(f"  使用的4个网格点:")
                    for j in range(4):
                        print(f"    点{j+1}: 行 {detail['rows'][j]}, 列 {detail['cols'][j]} - " +
                              f"位置 ({detail['lat_values'][j]:.6f}, {detail['lon_values'][j]:.6f}) - " +
                              f"值: {detail['values'][j]:.6f}")
                else:
                    print(f"  最近邻点: 行 {detail['row']}, 列 {detail['col']} - " +
                          f"位置 ({detail['lat_values'][0]:.6f}, {detail['lon_values'][0]:.6f}) - " +
                          f"值: {detail['values'][0]:.6f}")
        
        return True
        
    except Exception as e:
        print(f"处理过程中出错: {str(e)}")
        import traceback
        print("错误详细信息:")
        print(traceback.format_exc())
        return False

# ========================== 主程序 ==========================
if __name__ == "__main__":
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V.xlsx"
    
    print("="*60)
    print("开始处理SMEX02数据插值任务")
    print("="*60)
    
    if not os.path.exists(input_file):
        print(f"错误: 输入文件不存在 - {input_file}")
        print(f"请检查路径: {os.path.abspath(input_file)}")
    else:
        print(f"输入文件: {input_file}")
        print(f"输出将保存到: E:\\data\\VWC\\test-VWC\\NSIDC_0666\\SMEX02\\processed_SMEX02V_ML.xlsx")
        
        success = process_smapvex_data(input_file)
        if success:
            print("\n" + "="*30)
            print("任务成功完成!")
            print("="*30)
        else:
            print("\n" + "="*30)
            print("任务失败，请检查错误信息")
            print("="*30)

开始处理SMEX02数据插值任务
输入文件: E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V.xlsx
输出将保存到: E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V_ML.xlsx
读取原始Excel文件: E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V.xlsx
成功读取 104 条记录

处理PFT数据: E:\data\ESACCI PFT\Resample\Data\2002.mat
  文件中可用的PFT变量: water, bare, snowice, built, grassnat, grassman, shrubbd, shrubbe, shrubnd, shrubne, treebd, treebe, treend, treene
  已添加列: PFT_water
  已添加列: PFT_bare
  已添加列: PFT_snowice
  已添加列: PFT_built
  已添加列: PFT_grassnat
  已添加列: PFT_grassman
  已添加列: PFT_shrubbd
  已添加列: PFT_shrubbe
  已添加列: PFT_shrubnd
  已添加列: PFT_shrubne
  已添加列: PFT_treebd
  已添加列: PFT_treebe
  已添加列: PFT_treend
  已添加列: PFT_treene

处理VOD数据:
  处理日期: 20020625, 文件: MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_20020625_V0.nc4.mat
    已更新: sm
    已更新: ku_vod_H
    已更新: ku_vod_V
    已更新: x_vod_H
    已更新: x_vod_V
    已更新: c_vod_H
    已更新: c_vod_V
  处理日期: 20020627, 文件: MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_20020627_V0.nc4.mat
    已更新: sm
    已更新: ku_

In [13]:
# 机器学习结果填充，和实测值对比
import pandas as pd
import numpy as np
import os
import joblib
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics import mean_squared_error, r2_score

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

# 设置常量
TEST_FILE = r"E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V_ML.xlsx"
MODEL_DIR = "models"
SAVE_RESULTS = "model_predictions_results_SMEX02.xlsx"
FIG_DIR = "figures"

# 定义波段和极化组合
BANDS = ['Ku', 'C', 'X']
POLS = ['H', 'V', 'HV']

# 波段颜色定义
BAND_COLORS = {
    'Ku': (253/255, 173/255, 115/255, 0.7),
    'C': (178/255, 125/255, 104/255, 0.7),
    'X': (224/255, 104/255, 46/255, 0.7)
}

# 极化类型标记定义
POL_MARKERS = {
    'H': 's',  # 方形
    'V': '^',  # 三角形
    'HV': 'o'  # 圆形
}

def normalize_LAI(lai_series):
    """对LAI进行归一化"""
    return lai_series.clip(0, 6) / 6

def normalize_VOD(vod_series):
    """对VOD进行归一化"""
    return vod_series.clip(0, 2) / 2

# PFT列名映射字典
PFT_MAPPING = {
    'PFT_grassnat': 'Grass_nat',
    'PFT_grassman': 'Grass_man',
    'PFT_shrubbd': 'Shrub_bd',
    'PFT_shrubbe': 'Shrub_be',
    'PFT_shrubnd': 'Shrub_nd',
    'PFT_shrubne': 'Shrub_ne',
    'PFT_treebd': 'Tree_bd',
    'PFT_treebe': 'Tree_be',
    'PFT_treend': 'Tree_nd',
    'PFT_treene': 'Tree_ne'
}

def get_model_columns(band, pol):
    """获取指定模型所需的列名"""
    base_columns = [
        'VWC-Field',  # 修改点1: VWC改为VWC-Field
        'LAI_Satellite',  # 修改点2: LAI改为LAI_Satellite
        'SM'  # 处理后合并的土壤湿度列
    ]
    
    # 添加所有PFT列
    base_columns.extend(PFT_MAPPING.keys())
    
    # 根据极化类型添加VOD列
    if pol == 'H':
        return base_columns + [f'{band.lower()}_vod_H']
    elif pol == 'V':
        return base_columns + [f'{band.lower()}_vod_V']
    elif pol == 'HV':
        return base_columns + [f'{band.lower()}_vod_H', f'{band.lower()}_vod_V']

def get_feature_order(pol):
    """获取特征列的顺序（模型期望的列顺序）"""
    base_features = [
        'LAI', 'SM',  # 注意：准备时会重命名LAI_Satellite为LAI
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    if pol in ['H', 'V']:
        return ['VOD'] + base_features
    elif pol == 'HV':
        return ['VOD-Hpol', 'VOD-Vpol'] + base_features

def prepare_input_data(df, band, pol):
    """为指定模型准备输入数据"""
    # 创建数据副本
    data = df.copy()
    
    # 1. 归一化处理 - 使用新的LAI_Satellite列
    data['LAI'] = normalize_LAI(data['LAI_Satellite'])  # 修改点4: 使用LAI_Satellite
    
    # 2. 处理VOD列
    if pol == 'H':
        vod_col = f'{band.lower()}_vod_H'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'V':
        vod_col = f'{band.lower()}_vod_V'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'HV':
        # 重命名列以匹配模型训练时的特征名
        data = data.rename(columns={
            f'{band.lower()}_vod_H': 'VOD-Hpol',
            f'{band.lower()}_vod_V': 'VOD-Vpol'
        })
        # 归一化处理
        data['VOD-Hpol'] = normalize_VOD(data['VOD-Hpol'])
        data['VOD-Vpol'] = normalize_VOD(data['VOD-Vpol'])
    
    # 3. 重命名PFT列为模型期望的名称
    data = data.rename(columns=PFT_MAPPING)
    
    # 4. 按模型要求排序特征列
    feature_order = get_feature_order(pol)
    
    # 5. 确保特征顺序与模型训练时一致
    # 获取实际存在的列
    existing_columns = [col for col in feature_order if col in data.columns]
    
    # 检查是否有缺失的特征
    missing_features = set(feature_order) - set(existing_columns)
    if missing_features:
        print(f"  警告: 缺少特征列: {', '.join(missing_features)}")
        # 为缺失的特征添加空列
        for col in missing_features:
            data[col] = np.nan
    
    # 返回按正确顺序排列的特征
    return data[feature_order]

def plot_combined_scatter(actual, predictions_dict):
    """
    绘制组合散点图，包含所有波段和极化类型
    
    参数:
    actual -- 实际值 (Series)
    predictions_dict -- 字典结构: {
        'H': {band: pred_series},
        'V': {band: pred_series},
        'HV': {band: pred_series}
    }
    """
    # 创建图形
    plt.figure(figsize=(10, 10))
    ax = plt.gca()
    
    # 存储所有组合的RMSE值
    rmse_values = {}
    
    # 收集所有数据点
    max_val = 0
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            pred_series = predictions_dict[pol].get(band)
            
            if pred_series is not None and not pred_series.isnull().all():
                # 创建实际值和预测值的临时DF
                temp_df = pd.DataFrame({
                    'actual': actual,
                    'pred': pred_series
                }).dropna()
                
                if not temp_df.empty:
                    # 计算RMSE
                    rmse = np.sqrt(mean_squared_error(temp_df['actual'], temp_df['pred']))
                    rmse_values[f"{band}-{pol}"] = rmse
                    
                    # 更新最大值
                    band_max = max(temp_df['actual'].max(), temp_df['pred'].max())
                    if band_max > max_val:
                        max_val = band_max
                    
                    # 绘制散点
                    plt.scatter(
                        temp_df['actual'], temp_df['pred'], 
                        alpha=0.7, 
                        color=BAND_COLORS[band],
                        marker=POL_MARKERS[pol],
                        s=50,
                        edgecolors='none',
                        zorder=2,
                        label=f"{band}-{pol}"
                    )
    
    # 如果没有数据可绘制，直接返回
    if not rmse_values:
        print("  警告: 没有有效的预测数据!")
        plt.close()
        return
    
    # 添加1:1参考线
    max_val *= 1.05
    plt.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7, zorder=1)
    
    # 设置坐标轴范围
    plt.xlim(0, max_val)
    plt.ylim(0, max_val)
    
    # 设置坐标轴标签
    plt.xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
    plt.ylabel('RF VWC (kg/m²)', fontsize=14, fontweight='bold')
    
    # 设置标题
    plt.title('SMEX08 Insitu VWC', 
             fontsize=18, fontweight='bold', pad=20)
    
    # 添加RMSE文本（左上角，3×3网格布局）
    if rmse_values:
        # 设置文本位置
        x_pos = 0.05
        y_pos = 0.95
        
        # 添加标题
        plt.text(x_pos, y_pos, 'RMSE (kg/m²):', 
                 transform=ax.transAxes,
                 fontsize=12,
                 fontweight='bold',
                 verticalalignment='top')
        
        y_pos -= 0.05
        
        # 遍历每个波段
        for band_idx, band in enumerate(BANDS):
            # 遍历每个极化类型
            for pol_idx, pol in enumerate(POLS):
                # 计算位置
                text_x = x_pos + pol_idx * 0.15
                text_y = y_pos - band_idx * 0.08
                
                # 获取RMSE值
                rmse = rmse_values.get(f"{band}-{pol}", None)
                
                if rmse is not None:
                    # 绘制标记
                    plt.scatter(
                        text_x, text_y, 
                        transform=ax.transAxes,
                        marker=POL_MARKERS[pol],
                        color=BAND_COLORS[band],
                        s=80,
                        alpha=0.7
                    )
                    
                    # 添加文本
                    plt.text(
                        text_x + 0.01, text_y, 
                        f"{band}-{pol}: {rmse:.3f}", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center'
                    )
                else:
                    # 添加缺失值标记
                    plt.text(
                        text_x, text_y, 
                        f"{band}-{pol}: N/A", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center',
                        color='gray'
                    )
    
    # 添加网格线
    plt.grid(True, linestyle='--', alpha=0.3, zorder=0)
    
    # 调整布局
    plt.tight_layout()
    
    # 创建保存目录
    os.makedirs(FIG_DIR, exist_ok=True)
    
    # 保存图像
    fig_path = os.path.join(FIG_DIR, 'SMEX02_VWC_Scatter.png')
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"  组合散点图已保存至: {fig_path}")
    plt.close()

def predict_and_evaluate():
    """主函数：加载所有模型进行预测并评估结果"""
    # 1. 加载测试数据
    print(f"正在加载测试数据: {TEST_FILE}")
    
    # 收集所有可能的列
    all_columns = set(['VWC-Field', 'LAI_Satellite'])  # 修改点3: 更新列名
    
    # 添加土壤湿度相关列 (修改点5: 添加sm列)
    all_columns.update(['SM', 'sm'])
    
    # 添加所有PFT列
    all_columns.update(PFT_MAPPING.keys())
    
    # 添加所有VOD列
    for band in BANDS:
        all_columns.add(f'{band.lower()}_vod_H')
        all_columns.add(f'{band.lower()}_vod_V')
    
    # 读取Excel文件
    test_df = pd.read_excel(TEST_FILE, usecols=list(all_columns))
    print(f"加载完成，总样本数: {len(test_df)}")
    
    # 修改点6: 土壤湿度合并逻辑
    # 优先使用SM，如果SM没有值则使用sm
    if 'SM' in test_df.columns and 'sm' in test_df.columns:
        print("  合并SM和sm列...")
        # 合并土壤湿度：优先使用SM，如果SM没有值则使用sm
        test_df['SM'] = test_df['SM'].combine_first(test_df['sm'])
        # 删除多余的sm列
        test_df = test_df.drop(columns=['sm'])
        print("  合并完成，保留SM列")
    elif 'sm' in test_df.columns:
        print("  只有sm列，重命名为SM")
        test_df = test_df.rename(columns={'sm': 'SM'})
    elif 'SM' in test_df.columns:
        print("  只有SM列，无需处理")
    else:
        print("  警告: 没有找到任何土壤湿度列!")
        # 添加一个空的SM列
        test_df['SM'] = np.nan
    
    # 确保只有一个SM列
    if sum(col == 'SM' for col in test_df.columns) > 1:
        print("  警告: 存在多个SM列，删除重复列")
        # 保留第一个SM列，删除其他SM列
        sm_cols = [col for col in test_df.columns if col == 'SM']
        for col in sm_cols[1:]:
            test_df = test_df.drop(columns=[col])
    # 存储所有预测结果
    results = pd.DataFrame(index=test_df.index)
    results['Actual_VWC'] = test_df['VWC-Field']  # 修改点7: 使用新列名
    
    # 为每个极化类型存储预测结果
    predictions_by_pol = {
        'H': {band: None for band in BANDS},
        'V': {band: None for band in BANDS},
        'HV': {band: None for band in BANDS}
    }
    
    # 2. 对每个模型进行预测
    for band in BANDS:
        for pol in POLS:
            model_name = f"RFR_{band}_{pol}pol_Type1.pkl"
            model_path = os.path.join(MODEL_DIR, model_name)
            
            print(f"\n处理 {band}-{pol} 模型: {model_name}")
            
            # 准备输入数据
            model_cols = get_model_columns(band, pol)
            model_data = test_df[model_cols].copy()
            
            # 修改点8: 在土壤湿度合并后，跳过两个值都没有的行
            # 如果某行的两个土壤湿度都没有值，则不把这行输入至模型
            model_data = model_data.dropna(subset=['SM'])
            
            # 删除其他缺失值
            clean_data = model_data.dropna()
            print(f"  有效样本数: {len(clean_data)} (删除缺失值后)")
            
            if len(clean_data) == 0:
                print("  警告: 无有效样本可用于此模型!")
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
                continue
            
            # 预处理输入数据
            try:
                X_input = prepare_input_data(clean_data, band, pol)
                
                # 加载模型并进行预测
                if os.path.exists(model_path):
                    model = joblib.load(model_path)
                    # 添加特征顺序验证
                    if hasattr(model, 'feature_names_in_'):
                        # 获取模型期望的特征顺序
                        expected_features = model.feature_names_in_
                        
                        # 验证特征顺序
                        if list(X_input.columns) != list(expected_features):
                            print(f"  调整特征顺序以匹配模型期望...")
                            print(f"  当前特征顺序: {list(X_input.columns)}")
                            print(f"  模型期望顺序: {list(expected_features)}")
                            
                            # 调整特征顺序
                            X_input = X_input[expected_features]
                    
                    predictions = model.predict(X_input)
                    
                    # 存储预测结果
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    results.loc[clean_data.index, f"{band}_{pol}_Predicted"] = predictions
                    
                    # 存储到对应极化类型的字典
                    predictions_by_pol[pol][band] = results[f"{band}_{pol}_Predicted"].copy()
                    
                    # 计算评估指标
                    actual = clean_data['VWC-Field']  # 修改点9: 使用新列名
                    rmse = np.sqrt(mean_squared_error(actual, predictions))
                    r2 = r2_score(actual, predictions)
                    print(f"  预测完成 - RMSE: {rmse:.4f}, R²: {r2:.4f}")
                    
                else:
                    print(f"  警告: 未找到模型文件 {model_path}!")
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    predictions_by_pol[pol][band] = None
            except Exception as e:
                import traceback
                print(f"  预测失败: {str(e)}")
                # 打印更详细的错误信息
                traceback.print_exc()
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
    
    # 3. 保存结果
    results.to_excel(SAVE_RESULTS)
    print(f"\n所有预测结果已保存至: {SAVE_RESULTS}")
    
    # 4. 绘制组合散点图
    print("\n正在绘制组合散点图...")
    plot_combined_scatter(
        results['Actual_VWC'], 
        predictions_by_pol
    )
    
    return results, predictions_by_pol

# 执行主函数
if __name__ == "__main__":
    results, predictions_by_pol = predict_and_evaluate()
    print("\n所有处理完成!")

正在加载测试数据: E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V_ML.xlsx
加载完成，总样本数: 104
  合并SM和sm列...
  合并完成，保留SM列

处理 Ku-H 模型: RFR_Ku_Hpol_Type1.pkl
  有效样本数: 61 (删除缺失值后)
  预测完成 - RMSE: 3.4445, R²: -2.9965

处理 Ku-V 模型: RFR_Ku_Vpol_Type1.pkl
  有效样本数: 35 (删除缺失值后)
  预测完成 - RMSE: 5.0307, R²: -6.8198

处理 Ku-HV 模型: RFR_Ku_HVpol_Type1.pkl
  有效样本数: 35 (删除缺失值后)
  预测完成 - RMSE: 3.4030, R²: -2.5782

处理 C-H 模型: RFR_C_Hpol_Type1.pkl
  有效样本数: 35 (删除缺失值后)
  预测完成 - RMSE: 4.2484, R²: -4.5769

处理 C-V 模型: RFR_C_Vpol_Type1.pkl
  有效样本数: 13 (删除缺失值后)
  预测完成 - RMSE: 3.7594, R²: -3.5016

处理 C-HV 模型: RFR_C_HVpol_Type1.pkl
  有效样本数: 13 (删除缺失值后)
  预测完成 - RMSE: 3.2189, R²: -2.3003

处理 X-H 模型: RFR_X_Hpol_Type1.pkl
  有效样本数: 65 (删除缺失值后)
  预测完成 - RMSE: 4.2207, R²: -5.0884

处理 X-V 模型: RFR_X_Vpol_Type1.pkl
  有效样本数: 43 (删除缺失值后)
  预测完成 - RMSE: 3.2505, R²: -2.5602

处理 X-HV 模型: RFR_X_HVpol_Type1.pkl
  有效样本数: 43 (删除缺失值后)
  预测完成 - RMSE: 2.9027, R²: -1.8390

所有预测结果已保存至: model_predictions_results_SMEX02.xlsx

正在绘制组合散点图...
  组合散

# 5.SMAPVEX2016-Manitoba

In [8]:
# 将kml文件的经纬度信息整合至表格中。
import os
import pandas as pd
import geopandas as gpd
import zipfile
from pathlib import Path

# 配置路径
base_path = Path(r"E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba")
output_dir = base_path / "Processed_Results"
output_dir.mkdir(exist_ok=True)

def extract_kmz(kmz_path, output_name):
    """安全解压KMZ文件"""
    extract_dir = output_dir / output_name
    extract_dir.mkdir(exist_ok=True)
    
    with zipfile.ZipFile(kmz_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    
    for file in extract_dir.glob("**/*.kml"):
        return file
    raise FileNotFoundError(f"在 {kmz_path} 中未找到KML文件")

def process_field_sites():
    """处理站点数据并添加经纬度"""
    print("步骤1：处理KMZ站点文件...")
    sites_kml = extract_kmz(
        base_path / "SV16M_V_FieldSites.kmz",
        "FieldSites_Extracted"
    )
    
    print("步骤2：读取并转换坐标系...")
    # 读取KML数据
    sites = gpd.read_file(str(sites_kml))
    
    # 转换为文档指定的原始坐标系 (EPSG:3158)
    sites_nad83 = sites.to_crs("EPSG:3158")
    
    # 转换为WGS84 (EPSG:4326)
    sites_wgs = sites_nad83.to_crs("EPSG:4326")
    
    print("步骤3：提取经纬度...")
    # 创建包含坐标的DataFrame
    sites_coords = pd.DataFrame({
        'Name': sites_wgs['Name'],
        'Longitude': sites_wgs.geometry.x,
        'Latitude': sites_wgs.geometry.y
    })
    
    print("步骤4：处理CSV站点数据...")
    # 扫描所有CSV文件
    csv_files = [
        "SV16M_V_CropHeight_Vers3.csv",
        "SV16M_V_CropBiomass_Vers4.csv"
    ]
    
    for csv_file in csv_files:
        csv_path = base_path / csv_file
        if not csv_path.exists():
            print(f"警告: 文件不存在 {csv_file}")
            continue
            
        # 读取CSV
        df = pd.read_csv(csv_path)
        
        # 检查是否包含SITE_ID列
        if 'SITE_ID' not in df.columns:
            print(f"跳过: {csv_file} 不包含SITE_ID列")
            continue
            
        print(f"处理: {csv_file}")
        
        # 合并坐标数据
        merged = df.merge(
            sites_coords, 
            left_on='SITE_ID', 
            right_on='Name',
            how='left'
        )
        
        # 移除辅助列
        merged = merged.drop(columns=['Name'], errors='ignore')
        merged['SITE_ID'] = merged['SITE_ID'].astype(str)
        
        # 保存结果
        output_path = output_dir / f"{csv_path.stem}_with_coords.csv"
        merged.to_csv(output_path, index=False, sep='\t')
        print(f"已添加经纬度: {output_path}")
    
    print("处理完成!")

if __name__ == "__main__":
    process_field_sites()

步骤1：处理KMZ站点文件...
步骤2：读取并转换坐标系...
步骤3：提取经纬度...
步骤4：处理CSV站点数据...
处理: SV16M_V_CropHeight_Vers3.csv
已添加经纬度: E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropHeight_Vers3_with_coords.csv
处理: SV16M_V_CropBiomass_Vers4.csv
已添加经纬度: E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords.csv
处理完成!


In [22]:
# 数据填充
import pandas as pd
import numpy as np
import os
import h5py
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings("ignore")

# 全局变量记录插值详细信息
interpolation_details = []

# ====================== 改进的MAT文件读取函数 ======================
def read_hdf5_mat(file_path, expected_keys=None):
    """读取MATLAB v7.3格式的HDF5文件，优先查找特定变量"""
    try:
        with h5py.File(file_path, 'r') as f:
            data = {}
            
            def visitor_func(name, obj):
                if isinstance(obj, h5py.Dataset):
                    if h5py.check_string_dtype(obj.dtype):
                        value = ''.join(chr(c) for c in obj[:])
                    else:
                        value = np.array(obj)
                    if value.ndim >= 2:
                        value = value.T
                    base_name = name.split('/')[-1]
                    data[base_name] = value
            
            f.visititems(visitor_func)
            
            # 优先查找预期变量
            if expected_keys:
                for key in expected_keys:
                    if key in data:
                        return {key: data[key]}
            
            return data
    except Exception as e:
        print(f"  读取HDF5 MAT文件失败: {str(e)}")
        return {}

# ====================== 改进的双线性插值函数 ======================
def bilinear_interpolation_with_details(lat_grid, lon_grid, target_lat, target_lon, grid_data):
    """
    执行双线性插值并记录详细信息
    :param lat_grid: 网格纬度数组 (1D, 从北向南递减)
    :param lon_grid: 网格经度数组 (1D, 从西向东递增)
    :param target_lat: 目标点纬度
    :param target_lon: 目标点经度
    :param grid_data: 网格数据 (2D数组, 形状为(len(lat_grid), len(lon_grid)))
    :return: 插值值
    """
    global interpolation_details
    
    try:
        # 记录网格形状
        grid_shape = grid_data.shape
        
        # 验证网格尺寸
        if len(lat_grid) != grid_shape[0] or len(lon_grid) != grid_shape[1]:
            print(f"警告: 网格尺寸不匹配! 纬度网格: {len(lat_grid)}, 经度网格: {len(lon_grid)}, 数据形状: {grid_shape}")
            return np.nan
        
        # 查找最近的纬度索引（纬度从北向南递减）
        # 纬度网格: 89.95 (北) -> -89.95 (南)
        lat_idx = np.argmin(np.abs(lat_grid - target_lat))
        
        # 查找最近的经度索引（经度从西向东递增）
        # 经度网格: -179.95 (西) -> 179.95 (东)
        lon_idx = np.argmin(np.abs(lon_grid - target_lon))
        
        # 确定四个角点索引
        # 纬度处理：目标点位于两个纬度网格点之间
        if lat_idx == 0:
            lat_idx0, lat_idx1 = 0, 1
        elif lat_idx == len(lat_grid) - 1:
            lat_idx0, lat_idx1 = len(lat_grid) - 2, len(lat_grid) - 1
        else:
            if target_lat > lat_grid[lat_idx]:
                # 目标纬度大于当前网格点纬度（更北）
                if lat_idx > 0:
                    lat_idx0 = lat_idx - 1
                    lat_idx1 = lat_idx
                else:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx
            else:
                # 目标纬度小于当前网格点纬度（更南）
                if lat_idx < len(lat_grid) - 1:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx + 1
                else:
                    lat_idx0 = lat_idx
                    lat_idx1 = lat_idx
        
        # 经度处理：目标点位于两个经度网格点之间
        if lon_idx == 0:
            lon_idx0, lon_idx1 = 0, 1
        elif lon_idx == len(lon_grid) - 1:
            lon_idx0, lon_idx1 = len(lon_grid) - 2, len(lon_grid) - 1
        else:
            if target_lon > lon_grid[lon_idx]:
                # 目标经度大于当前网格点经度（更东）
                if lon_idx < len(lon_grid) - 1:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx + 1
                else:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx
            else:
                # 目标经度小于当前网格点经度（更西）
                if lon_idx > 0:
                    lon_idx0 = lon_idx - 1
                    lon_idx1 = lon_idx
                else:
                    lon_idx0 = lon_idx
                    lon_idx1 = lon_idx
        
        # 获取四个角点值
        Q00 = grid_data[lat_idx0, lon_idx0]
        Q01 = grid_data[lat_idx0, lon_idx1]
        Q10 = grid_data[lat_idx1, lon_idx0]
        Q11 = grid_data[lat_idx1, lon_idx1]
        
        # 四个角点坐标
        y0 = lat_grid[lat_idx0]
        y1 = lat_grid[lat_idx1]
        x0 = lon_grid[lon_idx0]
        x1 = lon_grid[lon_idx1]
        
        # 如果有NaN，使用最接近的点
        if np.isnan(Q00) or np.isnan(Q01) or np.isnan(Q10) or np.isnan(Q11):
            result = grid_data[lat_idx, lon_idx]
            details = {
                'type': 'nearest',
                'row': lat_idx,
                'col': lon_idx,
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [grid_data[lat_idx, lon_idx]],
                'lat_values': [lat_grid[lat_idx]],
                'lon_values': [lon_grid[lon_idx]]
            }
        else:
            # 双线性插值公式
            dx = (target_lon - x0) / (x1 - x0) if (x1 - x0) != 0 else 0
            dy = (target_lat - y0) / (y1 - y0) if (y1 - y0) != 0 else 0
            result = (1 - dx) * (1 - dy) * Q00 + dx * (1 - dy) * Q01 + (1 - dx) * dy * Q10 + dx * dy * Q11
            
            details = {
                'type': 'bilinear',
                'rows': [lat_idx0, lat_idx0, lat_idx1, lat_idx1],
                'cols': [lon_idx0, lon_idx1, lon_idx0, lon_idx1],
                'target_lat': target_lat,
                'target_lon': target_lon,
                'grid_shape': grid_shape,
                'values': [Q00, Q01, Q10, Q11],
                'lat_values': [y0, y0, y1, y1],
                'lon_values': [x0, x1, x0, x1]
            }
        
        # 保存插值详细信息
        interpolation_details.append(details)
        return result
    
    except Exception as e:
        print(f"插值错误: {str(e)}")
        return np.nan

# ====================== 主处理函数 ======================
def process_smapvex16_manitoba(input_file_path):
    """
    处理SMAPVEX16 Manitoba数据，执行多种插值操作
    """
    global interpolation_details
    
    try:
        interpolation_details = []  # 重置插值详情
        
        # ========== 1. 读取原始数据 ==========
        print(f"读取原始CSV文件: {input_file_path}")
        df = pd.read_csv(input_file_path)
        
        # 重命名列以保持一致性
        df.rename(columns={
            'DATE': 'Date',
            'Latitude': 'Latitude',
            'Longitude': 'Longitude'
        }, inplace=True)
        
        # 转换日期格式
        df['Date'] = pd.to_datetime(df['Date'])
        
        print(f"成功读取 {len(df)} 条记录")
        print(f"日期范围: {df['Date'].min().strftime('%Y-%m-%d')} 至 {df['Date'].max().strftime('%Y-%m-%d')}")
        
        # 定义标准经纬度网格 (0.1°分辨率)
        # 纬度: 北纬89.95°(0) -> 南纬-89.95°(1799)
        lat_grid = np.linspace(89.95, -89.95, 1800)
        
        # 经度: -179.95°(0) -> 179.95°(3599)
        lon_grid = np.linspace(-179.95, 179.95, 3600)
        
        # ========== 2. 准备PFT数据 (14个类别) ==========
        pft_file = r"E:\data\ESACCI PFT\Resample\Data\2016.mat"
        if os.path.exists(pft_file):
            print(f"\n处理PFT数据: {pft_file}")
            mat_data = read_hdf5_mat(pft_file)
            
            pft_columns = ['water','bare','snowice','built','grassnat','grassman',
                          'shrubbd','shrubbe','shrubnd','shrubne',
                          'treebd','treebe','treend','treene']
            
            available_pft = [col for col in pft_columns if col in mat_data]
            print(f"  文件中可用的PFT变量: {', '.join(available_pft)}")
            
            # 处理每个可用的PFT类别
            for col in available_pft:
                grid_data = mat_data[col] / 100.0
                df[f'PFT_{col}'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        grid_data
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加列: PFT_{col}")
        else:
            print(f"\n警告: PFT文件不存在 - {pft_file}")
        
        # ========== 3. 处理VOD数据 (7个变量) ==========
        vod_base_dir = r"E:\data\VOD\mat\kuxcVOD\ASC"
        vod_cols = ['SM','ku_vod_H', 'ku_vod_V', 'x_vod_H','x_vod_V', 'c_vod_H','c_vod_V']
        
        for col in vod_cols:
            df[col] = np.nan
        
        print("\n处理VOD数据:")
        
        # 收集所有唯一日期并排序
        unique_dates = sorted(df['Date'].unique())
        vod_files_found = 0
        
        for date in unique_dates:
            # 转换为字符串格式YYYYMMDD
            try:
                if isinstance(date, pd.Timestamp):
                    date_str = date.strftime("%Y%m%d")
                else:
                    date_str = date.strftime("%Y%m%d")
            except:
                print(f"  无法解析日期: {date}")
                continue
            
            # SMAPVEX16 使用 AMSR2 传感器
            vod_file = os.path.join(vod_base_dir, f"MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")
            if not os.path.exists(vod_file):
                # 尝试备用命名格式
                vod_file = os.path.join(vod_base_dir, f"MCCA_AMSR2_010D_{date_str}_V0.nc4.mat")
            
            if os.path.exists(vod_file):
                vod_files_found += 1
                print(f"  处理日期: {date_str}, 文件: {os.path.basename(vod_file)}")
                vod_data = read_hdf5_mat(vod_file)
                
                for col in vod_cols:
                    if col in vod_data:
                        grid_data = vod_data[col]
                        mask = df['Date'] == date
                        df.loc[mask, col] = df[mask].apply(
                            lambda row: bilinear_interpolation_with_details(
                                lat_grid, lon_grid, 
                                row['Latitude'], row['Longitude'], 
                                grid_data
                            ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                            else np.nan, axis=1
                        )
                        print(f"    已更新: {col}")
                    else:
                        print(f"    警告: VOD变量 {col} 不存在于文件中")
            else:
                print(f"  警告: VOD文件不存在 - {os.path.basename(vod_file)}")
                
        if vod_files_found == 0:
            print("  警告: 没有找到任何VOD文件，VOD列将保留为空")
        
        # ========== 4. 处理LAI卫星数据 (动态日期范围) ==========
        print("\n处理LAI卫星数据...")
        df['LAI_Satellite'] = np.nan
        
        # 确定日期范围
        min_date = df['Date'].min().to_pydatetime()
        max_date = df['Date'].max().to_pydatetime()
        
        # 扩展日期范围（前后各加一个月）
        start_month = (min_date - timedelta(days=30)).replace(day=1)
        end_month = (max_date + timedelta(days=30)).replace(day=1)
        
        # 获取所有需要的LAI文件
        lai_files = {}
        current = start_month
        while current <= end_month:
            # 每个月15日作为代表日
            date_key = current.replace(day=15)
            # 文件路径格式: E:\data\GLASS LAI\mat\0.1Deg\Dataset\2016-04-01.tif.mat
            file_path = os.path.join(r"E:\data\GLASS LAI\mat\0.1Deg\Dataset", f"{date_key.strftime('%Y-%m')}-01.tif.mat")
            lai_files[date_key] = file_path
            current = current + timedelta(days=32)  # 移动到下一个月
            current = current.replace(day=1)
        
        print(f"根据数据日期范围 ({min_date.strftime('%Y-%m-%d')} 至 {max_date.strftime('%Y-%m-%d')})")
        print(f"确定需要 {len(lai_files)} 个LAI文件: {', '.join([d.strftime('%Y-%m') for d in lai_files.keys()])}")
        
        # 加载LAI数据
        expected_lai_keys = ['lai', 'LAI', 'data']
        lai_data = {}
        for date_key, file_path in lai_files.items():
            if os.path.exists(file_path):
                print(f"  加载LAI数据 ({date_key.strftime('%Y-%m-%d')}): {file_path}")
                file_data = read_hdf5_mat(file_path, expected_keys=expected_lai_keys)
                
                if file_data:
                    # 直接获取LAI数据
                    lai_value = list(file_data.values())[0]
                    # 验证数据形状
                    if lai_value.shape != (1800, 3600):
                        print(f"    警告: 数据形状异常 {lai_value.shape}, 调整为(1800, 3600)")
                        if lai_value.shape[0] < 1800 or lai_value.shape[1] < 3600:
                            lai_value = np.pad(lai_value, ((0, 1800-lai_value.shape[0]), (0, 3600-lai_value.shape[1])), 'constant')
                        else:
                            lai_value = lai_value[:1800, :3600]
                    
                    print(f"    成功读取LAI变量 '{list(file_data.keys())[0]}'，数据形状: {lai_value.shape}")
                    lai_data[date_key] = lai_value
                else:
                    print("    警告: 文件中未找到预期LAI变量，使用全零数组")
                    lai_data[date_key] = np.zeros((1800, 3600))
            else:
                print(f"  警告: LAI文件不存在 - {file_path}")
                lai_data[date_key] = np.zeros((1800, 3600))
        
        # 排序关键日期
        lai_dates = sorted(lai_data.keys())
        
        # 处理每个日期的数据
        print("\n插值LAI数据到观测点...")
        for date in unique_dates:
            try:
                date_dt = date.to_pydatetime()
                
                # 找到观测日期前后的最近关键日期
                # 找到第一个大于等于观测日期的关键日期索引
                idx = next((i for i, d in enumerate(lai_dates) if d >= date_dt), len(lai_dates))
                
                if idx == 0:
                    # 所有关键日期都在观测日期之后，使用第一个关键日期
                    prev_date = next_date = lai_dates[0]
                    weight = 0.0
                elif idx == len(lai_dates):
                    # 所有关键日期都在观测日期之前，使用最后一个关键日期
                    prev_date = next_date = lai_dates[-1]
                    weight = 0.0
                else:
                    # 获取前后两个关键日期
                    prev_date = lai_dates[idx-1]
                    next_date = lai_dates[idx]
                    
                    # 计算时间权重
                    total_days = (next_date - prev_date).days
                    days_passed = (date_dt - prev_date).days
                    weight = days_passed / total_days if total_days > 0 else 0
                
                # 应用时间插值
                interpolated_lai = (1 - weight) * lai_data[prev_date] + weight * lai_data[next_date]
                
                print(f"  处理日期: {date_dt.strftime('%Y-%m-%d')}, 使用日期: {prev_date.strftime('%Y-%m-%d')} 和 {next_date.strftime('%Y-%m-%d')}, 权重: {weight:.2f}")
                
                # 应用空间插值
                mask = df['Date'] == date
                df.loc[mask, 'LAI_Satellite'] = df[mask].apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        interpolated_lai
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                
                # 验证插值结果
                mean_lai = np.nanmean(interpolated_lai)
                valid_count = df.loc[mask, 'LAI_Satellite'].notna().sum()
                print(f"    平均LAI: {mean_lai:.4f}, 成功插值 {valid_count} 条记录")
            
            except Exception as e:
                print(f"  处理日期{date_dt.strftime('%Y-%m-%d')}时出错: {str(e)}")
                
        # ========== 5. 处理植被高度数据 ==========
        print("\n处理植被高度数据...")
        df['Hveg'] = np.nan
        
        hveg_file = r"E:\data\CanopyHeight\CH.mat"
        if os.path.exists(hveg_file):
            print(f"  加载植被高度数据: {hveg_file}")
            hveg_data = read_hdf5_mat(hveg_file, expected_keys=['CH', 'ch'])
            
            # 直接获取高度数据
            hveg_key = list(hveg_data.keys())[0] if hveg_data else None
            
            if hveg_key:
                hveg_values = hveg_data[hveg_key]
                df['Hveg'] = df.apply(
                    lambda row: bilinear_interpolation_with_details(
                        lat_grid, lon_grid, 
                        row['Latitude'], row['Longitude'], 
                        hveg_values
                    ) if not np.isnan(row['Latitude']) and not np.isnan(row['Longitude']) 
                    else np.nan, axis=1
                )
                print(f"  已添加植被高度列，数据形状: {hveg_values.shape}")
            else:
                print(f"  警告: 无法找到Hveg变量")
                df['Hveg'] = np.nan
        else:
            print(f"  警告: Hveg文件不存在 - {hveg_file}")
            df['Hveg'] = np.nan
        
        # ========== 6. 保存结果 ==========
        output_file_path = r"E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv"
        print(f"\n保存结果到: {output_file_path}")
        df.to_csv(output_file_path, index=False)
        
        # 保存插值详细信息到CSV
        if interpolation_details:
            details_df = pd.DataFrame(interpolation_details)
            details_path = r"E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\interpolation_details.csv"
            details_df.to_csv(details_path, index=False)
            print(f"插值详细信息保存到: {details_path}")
        
        # ========== 7. 统计报告 ==========
        print("\n处理完成!")
        print(f"总记录数: {len(df)}")
        print(f"插值操作次数: {len(interpolation_details)}")
        print(f"填充的LAI值数量: {df['LAI_Satellite'].notna().sum()}")
        print(f"填充的VOD值数量: {df['SM'].notna().sum() if 'SM' in df.columns else 0}")
        
        return True
        
    except Exception as e:
        print(f"处理过程中出错: {str(e)}")
        import traceback
        print("错误详细信息:")
        print(traceback.format_exc())
        return False

# ========================== 主程序 ==========================
if __name__ == "__main__":
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords.csv"
    
    print("="*60)
    print("开始处理SMAPVEX16 Manitoba数据插值任务")
    print("="*60)
    
    if not os.path.exists(input_file):
        print(f"错误: 输入文件不存在 - {input_file}")
        print(f"请检查路径: {os.path.abspath(input_file)}")
    else:
        print(f"输入文件: {input_file}")
        print(f"输出将保存到: E:\\data\\VWC\\test-VWC\\Insitu SMAPVEX16 Manitoba\\Processed_Results\\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv")
        
        success = process_smapvex16_manitoba(input_file)
        if success:
            print("\n" + "="*30)
            print("任务成功完成!")
            print("="*30)
        else:
            print("\n" + "="*30)
            print("任务失败，请检查错误信息")
            print("="*30)

开始处理SMAPVEX16 Manitoba数据插值任务
输入文件: E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords.csv
输出将保存到: E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv
读取原始CSV文件: E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords.csv
成功读取 1400 条记录
日期范围: 2016-06-13 至 2016-07-21

处理PFT数据: E:\data\ESACCI PFT\Resample\Data\2016.mat
  文件中可用的PFT变量: water, bare, snowice, built, grassnat, grassman, shrubbd, shrubbe, shrubnd, shrubne, treebd, treebe, treend, treene
  已添加列: PFT_water
  已添加列: PFT_bare
  已添加列: PFT_snowice
  已添加列: PFT_built
  已添加列: PFT_grassnat
  已添加列: PFT_grassman
  已添加列: PFT_shrubbd
  已添加列: PFT_shrubbe
  已添加列: PFT_shrubnd
  已添加列: PFT_shrubne
  已添加列: PFT_treebd
  已添加列: PFT_treebe
  已添加列: PFT_treend
  已添加列: PFT_treene

处理VOD数据:
  处理日期: 20160613, 文件: MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_20160613_V0.nc4.mat
    已更新: SM
    已更新: ku_vod_H
    已更新: 

In [22]:
# 机器学习结果填充，和实测值对比
import pandas as pd
import numpy as np
import os
import joblib
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics import mean_squared_error, r2_score

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

# 设置常量 - 更新为CSV文件路径
TEST_FILE = r"E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv"
MODEL_DIR = "models"
SAVE_RESULTS = "model_predictions_results_SMAPVEX16.xlsx"
FIG_DIR = "figures"

# 定义波段和极化组合
BANDS = ['Ku', 'C', 'X']
POLS = ['H', 'V', 'HV']

# 波段颜色定义
BAND_COLORS = {
    'Ku': (253/255, 173/255, 115/255, 0.7),
    'C': (178/255, 125/255, 104/255, 0.7),
    'X': (224/255, 104/255, 46/255, 0.7)
}

# 极化类型标记定义
POL_MARKERS = {
    'H': 's',  # 方形
    'V': '^',  # 三角形
    'HV': 'o'  # 圆形
}

def normalize_LAI(lai_series):
    """对LAI进行归一化"""
    return lai_series.clip(0, 6) / 6

def normalize_VOD(vod_series):
    """对VOD进行归一化"""
    return vod_series.clip(0, 2) / 2

# PFT列名映射字典
PFT_MAPPING = {
    'PFT_grassnat': 'Grass_nat',
    'PFT_grassman': 'Grass_man',
    'PFT_shrubbd': 'Shrub_bd',
    'PFT_shrubbe': 'Shrub_be',
    'PFT_shrubnd': 'Shrub_nd',
    'PFT_shrubne': 'Shrub_ne',
    'PFT_treebd': 'Tree_bd',
    'PFT_treebe': 'Tree_be',
    'PFT_treend': 'Tree_nd',
    'PFT_treene': 'Tree_ne'
}

def get_model_columns(band, pol):
    """获取指定模型所需的列名"""
    base_columns = [
        'PLANT_WATER_CONTENT_AREA',  # 实测VWC列名
        'LAI_Satellite',             # LAI列名
        'SM'                         # 土壤湿度
    ]
    
    # 添加所有PFT列
    base_columns.extend(PFT_MAPPING.keys())
    
    # 根据极化类型添加VOD列
    if pol == 'H':
        return base_columns + [f'{band.lower()}_vod_H']
    elif pol == 'V':
        return base_columns + [f'{band.lower()}_vod_V']
    elif pol == 'HV':
        return base_columns + [f'{band.lower()}_vod_H', f'{band.lower()}_vod_V']

def get_feature_order(pol):
    """获取特征列的顺序（模型期望的列顺序）"""
    base_features = [
        'LAI', 'SM',
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    if pol in ['H', 'V']:
        return ['VOD'] + base_features
    elif pol == 'HV':
        return ['VOD-Hpol', 'VOD-Vpol'] + base_features

def prepare_input_data(df, band, pol):
    """为指定模型准备输入数据"""
    # 创建数据副本
    data = df.copy()
    
    # 1. 归一化处理
    data['LAI'] = normalize_LAI(data['LAI_Satellite'])
    
    # 2. 处理VOD列
    if pol == 'H':
        vod_col = f'{band.lower()}_vod_H'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'V':
        vod_col = f'{band.lower()}_vod_V'
        data['VOD'] = normalize_VOD(data[vod_col])
    elif pol == 'HV':
        # 重命名列以匹配模型训练时的特征名
        data = data.rename(columns={
            f'{band.lower()}_vod_H': 'VOD-Hpol',
            f'{band.lower()}_vod_V': 'VOD-Vpol'
        })
        # 归一化处理
        data['VOD-Hpol'] = normalize_VOD(data['VOD-Hpol'])
        data['VOD-Vpol'] = normalize_VOD(data['VOD-Vpol'])
    
    # 3. 重命名PFT列为模型期望的名称
    data = data.rename(columns=PFT_MAPPING)
    
    # 4. 按模型要求排序特征列
    feature_order = get_feature_order(pol)
    
    return data[feature_order]

def plot_combined_scatter(actual, predictions_dict):
    """
    绘制组合散点图，包含所有波段和极化类型
    
    参数:
    actual -- 实际值 (Series)
    predictions_dict -- 字典结构: {
        'H': {band: pred_series},
        'V': {band: pred_series},
        'HV': {band: pred_series}
    }
    """
    # 创建图形
    plt.figure(figsize=(10, 10))
    ax = plt.gca()
    
    # 存储所有组合的RMSE值
    rmse_values = {}
    
    # 收集所有数据点
    max_val = 0
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            pred_series = predictions_dict[pol].get(band)
            
            if pred_series is not None and not pred_series.isnull().all():
                # 创建实际值和预测值的临时DF
                temp_df = pd.DataFrame({
                    'actual': actual,
                    'pred': pred_series
                }).dropna()
                
                if not temp_df.empty:
                    # 计算RMSE
                    rmse = np.sqrt(mean_squared_error(temp_df['actual'], temp_df['pred']))
                    rmse_values[f"{band}-{pol}"] = rmse
                    
                    # 更新最大值
                    band_max = max(temp_df['actual'].max(), temp_df['pred'].max())
                    if band_max > max_val:
                        max_val = band_max
                    
                    # 绘制散点
                    plt.scatter(
                        temp_df['actual'], temp_df['pred'], 
                        alpha=0.7, 
                        color=BAND_COLORS[band],
                        marker=POL_MARKERS[pol],
                        s=50,
                        edgecolors='none',
                        zorder=2,
                        label=f"{band}-{pol}"
                    )
    
    # 如果没有数据可绘制，直接返回
    if not rmse_values:
        print("  警告: 没有有效的预测数据!")
        plt.close()
        return
    
    # 添加1:1参考线
    max_val *= 1.05
    plt.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7, zorder=1)
    
    # 设置坐标轴范围
    plt.xlim(0, max_val)
    plt.ylim(0, max_val)
    
    # 设置坐标轴标签
    plt.xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
    plt.ylabel('RF VWC (kg/m²)', fontsize=14, fontweight='bold')
    
    # 设置标题
    plt.title('SMAPVEX16 Insitu VWC', 
             fontsize=18, fontweight='bold', pad=20)
    
    # 添加RMSE文本（左上角，3×3网格布局）
    if rmse_values:
        # 设置文本位置
        x_pos = 0.05
        y_pos = 0.95
        
        # 计算RMSE表格的边界
        min_x = x_pos
        max_x = x_pos + 0.3  # 3列 * 0.1
        min_y = y_pos - 0.05 - 3 * 0.08  # 3行 * 0.08
        max_y = y_pos
        
        # 添加白色背景矩形框
        rect = plt.Rectangle(
            (min_x - 0.01, min_y - 0.01),  # 左下角位置
            max_x - min_x + 0.02,          # 宽度
            max_y - min_y + 0.02,          # 高度
            transform=ax.transAxes,
            facecolor='white',
            edgecolor='black',
            linewidth=1.5,
            zorder=3
        )
        ax.add_patch(rect)
        
        # 添加标题
        plt.text(x_pos, y_pos, 'RMSE (kg/m²):', 
                 transform=ax.transAxes,
                 fontsize=12,
                 fontweight='bold',
                 verticalalignment='top',
                 zorder=4)
        
        y_pos -= 0.05
        
        # 添加列标题（极化类型）
        for pol_idx, pol in enumerate(POLS):
            plt.text(
                x_pos + 0.05 + pol_idx * 0.1, y_pos,
                f"{pol}-pol",
                transform=ax.transAxes,
                fontsize=10,
                fontweight='bold',
                verticalalignment='center',
                horizontalalignment='center',
                zorder=4
            )
        
        # 添加行标题（波段）
        for band_idx, band in enumerate(BANDS):
            plt.text(
                x_pos - 0.02, y_pos - (band_idx + 1) * 0.08,
                f"{band} Band:",
                transform=ax.transAxes,
                fontsize=10,
                fontweight='bold',
                verticalalignment='center',
                horizontalalignment='right',
                zorder=4
            )
        
        # 遍历每个波段
        for band_idx, band in enumerate(BANDS):
            # 遍历每个极化类型
            for pol_idx, pol in enumerate(POLS):
                # 计算位置
                text_x = x_pos + 0.05 + pol_idx * 0.1
                text_y = y_pos - (band_idx + 1) * 0.08
                
                # 获取RMSE值
                rmse = rmse_values.get(f"{band}-{pol}", None)
                
                if rmse is not None:
                    # 绘制标记
                    plt.scatter(
                        text_x, text_y, 
                        transform=ax.transAxes,
                        marker=POL_MARKERS[pol],
                        color=BAND_COLORS[band],
                        s=80,
                        alpha=0.7,
                        zorder=4
                    )
                    
                    # 添加RMSE值
                    plt.text(
                        text_x + 0.03, text_y, 
                        f"{rmse:.3f}", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center',
                        zorder=4
                    )
                else:
                    plt.text(
                        text_x, text_y, 
                        "N/A", 
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight='bold',
                        verticalalignment='center',
                        color='gray',
                        zorder=4
                    )
    
    # 添加网格线
    plt.grid(True, linestyle='--', alpha=0.3, zorder=0)
    
    # 调整布局
    plt.tight_layout()
    
    # 创建保存目录
    os.makedirs(FIG_DIR, exist_ok=True)
    
    # 保存图像
    fig_path = os.path.join(FIG_DIR, 'SMAPVEX16_VWC_Scatter.png')
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"  组合散点图已保存至: {fig_path}")
    plt.close()

def predict_and_evaluate():
    """主函数：加载所有模型进行预测并评估结果"""
    # 1. 加载测试数据
    print(f"正在加载测试数据: {TEST_FILE}")
    
    # 收集所有可能的列
    all_columns = set([
        'PLANT_WATER_CONTENT_AREA',  # 实测VWC
        'LAI_Satellite',             # LAI
        'SM'                         # 土壤湿度
    ])
    # 添加所有PFT列
    all_columns.update(PFT_MAPPING.keys())
    # 添加所有VOD列
    for band in BANDS:
        all_columns.add(f'{band.lower()}_vod_H')
        all_columns.add(f'{band.lower()}_vod_V')
    
    # 读取CSV文件
    test_df = pd.read_csv(TEST_FILE, usecols=list(all_columns))
    print(f"加载完成，总样本数: {len(test_df)}")
    
    # 存储所有预测结果
    results = pd.DataFrame(index=test_df.index)
    results['Actual_VWC'] = test_df['PLANT_WATER_CONTENT_AREA']
    
    # 为每个极化类型存储预测结果
    predictions_by_pol = {
        'H': {band: None for band in BANDS},
        'V': {band: None for band in BANDS},
        'HV': {band: None for band in BANDS}
    }
    
    # 2. 对每个模型进行预测
    for band in BANDS:
        for pol in POLS:
            model_name = f"RFR_{band}_{pol}pol_Type1.pkl"
            model_path = os.path.join(MODEL_DIR, model_name)
            
            print(f"\n处理 {band}-{pol} 模型: {model_name}")
            
            # 准备输入数据
            model_cols = get_model_columns(band, pol)
            model_data = test_df[model_cols].copy()
            
            # 删除缺失值
            clean_data = model_data.dropna()
            print(f"  有效样本数: {len(clean_data)} (删除缺失值后)")
            
            if len(clean_data) == 0:
                print("  警告: 无有效样本可用于此模型!")
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
                continue
            
            # 预处理输入数据
            try:
                X_input = prepare_input_data(clean_data, band, pol)
                
                # 加载模型并进行预测
                if os.path.exists(model_path):
                    model = joblib.load(model_path)
                    predictions = model.predict(X_input)
                    
                    # 存储预测结果
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    results.loc[clean_data.index, f"{band}_{pol}_Predicted"] = predictions
                    
                    # 存储到对应极化类型的字典
                    predictions_by_pol[pol][band] = results[f"{band}_{pol}_Predicted"].copy()
                    
                    # 计算评估指标
                    actual = clean_data['PLANT_WATER_CONTENT_AREA']
                    rmse = np.sqrt(mean_squared_error(actual, predictions))
                    r2 = r2_score(actual, predictions)
                    print(f"  预测完成 - RMSE: {rmse:.4f}, R²: {r2:.4f}")
                    
                else:
                    print(f"  警告: 未找到模型文件 {model_path}!")
                    results[f"{band}_{pol}_Predicted"] = np.nan
                    predictions_by_pol[pol][band] = None
            except Exception as e:
                import traceback
                print(f"  预测失败: {str(e)}")
                # 打印更详细的错误信息
                traceback.print_exc()
                results[f"{band}_{pol}_Predicted"] = np.nan
                predictions_by_pol[pol][band] = None
    
    # 3. 保存结果
    results.to_excel(SAVE_RESULTS)
    print(f"\n所有预测结果已保存至: {SAVE_RESULTS}")
    
    # 4. 绘制组合散点图
    print("\n正在绘制组合散点图...")
    plot_combined_scatter(
        results['Actual_VWC'], 
        predictions_by_pol
    )
    
    return results, predictions_by_pol

# 执行主函数
if __name__ == "__main__":
    results, predictions_by_pol = predict_and_evaluate()
    print("\n所有处理完成!")

正在加载测试数据: E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv
加载完成，总样本数: 1400

处理 Ku-H 模型: RFR_Ku_Hpol_Type1.pkl
  有效样本数: 1375 (删除缺失值后)
  预测完成 - RMSE: 4.1583, R²: -8.3293

处理 Ku-V 模型: RFR_Ku_Vpol_Type1.pkl
  有效样本数: 1080 (删除缺失值后)
  预测完成 - RMSE: 5.7038, R²: -16.7305

处理 Ku-HV 模型: RFR_Ku_HVpol_Type1.pkl
  有效样本数: 1080 (删除缺失值后)
  预测完成 - RMSE: 5.2502, R²: -14.0223

处理 C-H 模型: RFR_C_Hpol_Type1.pkl
  有效样本数: 1375 (删除缺失值后)
  预测完成 - RMSE: 4.8618, R²: -11.7531

处理 C-V 模型: RFR_C_Vpol_Type1.pkl
  有效样本数: 855 (删除缺失值后)
  预测完成 - RMSE: 4.8582, R²: -13.7392

处理 C-HV 模型: RFR_C_HVpol_Type1.pkl
  有效样本数: 855 (删除缺失值后)
  预测完成 - RMSE: 4.6950, R²: -12.7658

处理 X-H 模型: RFR_X_Hpol_Type1.pkl
  有效样本数: 1375 (删除缺失值后)
  预测完成 - RMSE: 5.0755, R²: -12.8986

处理 X-V 模型: RFR_X_Vpol_Type1.pkl
  有效样本数: 1312 (删除缺失值后)
  预测完成 - RMSE: 4.0062, R²: -7.4932

处理 X-HV 模型: RFR_X_HVpol_Type1.pkl
  有效样本数: 1312 (删除缺失值后)
  预测完成 - RMSE: 4.0201, R²: -7.5520

所有预测结果已保存至: model_prediction

# SMEX02+CLASIC07+SMEX08+SMAPVEX16——处理为像元数据来检测，消除像元内出现的强烈的不一致性

In [1]:
# 整合以前处理的数据，计算平均值
import pandas as pd
import numpy as np
from pathlib import Path

# 常量定义
LAT_MIN, LAT_MAX = -89.95, 89.95  # 纬度范围
LON_MIN, LON_MAX = -179.95, 179.95  # 经度范围
RESOLUTION = 0.1  # 像元分辨率

def calculate_grid_index(lat, lon):
    """计算给定经纬度对应的像元行列索引"""
    # 检查NaN值并跳过
    if pd.isna(lat) or pd.isna(lon):
        return (None, None)
    
    # 确保纬度在有效范围内
    if lat < LAT_MIN or lat > LAT_MAX:
        # 调整超出范围的纬度
        lat = max(min(lat, LAT_MAX), LAT_MIN)
    
    # 计算纬度索引（行）
    lat_index = int(np.round((LAT_MAX - lat) / RESOLUTION))
    lat_index = max(0, min(lat_index, 1799))  # 确保在0~1799范围内
    
    # 调整经度范围 (0-360)
    if lon < LON_MIN:
        lon += 360
    elif lon > LON_MAX:
        lon -= 360
    
    # 计算经度索引（列）
    lon_index = int(np.round((lon - LON_MIN) / RESOLUTION)) % 3600
    
    return lat_index, lon_index

def get_center_coordinates(row, col):
    """根据行列索引计算像元中心经纬度"""
    center_lat = LAT_MAX - row * RESOLUTION
    center_lon = LON_MIN + col * RESOLUTION
    
    # 确保经度在[-180, 180]范围内
    if center_lon > 180:
        center_lon -= 360
    elif center_lon < -180:
        center_lon += 360
        
    return center_lat, center_lon

def process_file(file_path, date_col, lat_col, lon_col, dataset_name):
    """处理单个文件：读取数据、计算像元、聚合平均值"""
    try:
        print(f"\n开始处理: {dataset_name}")
        # 读取文件
        suffix = Path(file_path).suffix.lower()
        if suffix == '.csv':
            df = pd.read_csv(file_path)
        elif suffix in ('.xlsx', '.xls'):
            df = pd.read_excel(file_path)
        else:
            raise ValueError(f"不支持的格式: {suffix}")
        
        print(f"原始数据行数: {len(df)}")
        
        # 重命名列以统一处理
        df = df.rename(columns={
            lat_col: 'Latitude',
            lon_col: 'Longitude',
            date_col: 'Date'
        })
        
        # ======== 新增：日期格式统一转换为YYYYMMDD ========
        try:
            # 尝试转换为datetime对象
            df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
            # 格式化为YYYYMMDD字符串
            df['Date'] = df['Date'].dt.strftime('%Y%m%d')
            print("日期格式成功转换为YYYYMMDD")
        except Exception as date_err:
            print(f"日期转换异常: {str(date_err)}")
            print("尝试直接处理原始格式...")
            # 备份原始日期列
            df['Date_orig'] = df['Date']
            # 移除非数字字符并取前8位
            df['Date'] = (
                df['Date'].astype(str)
                .str.replace(r'[^\d]', '', regex=True)  # 移除非数字
                .str[:8]  # 取前8位数字
            )
            # 检查转换后的格式有效性
            invalid_mask = df['Date'].str.len() != 8
            if invalid_mask.any():
                invalid_count = invalid_mask.sum()
                print(f"警告: {invalid_count}行日期格式无效 (非8位数字)")
                # 恢复无法转换的原始值
                df.loc[invalid_mask, 'Date'] = df.loc[invalid_mask, 'Date_orig']
            df.drop(columns='Date_orig', inplace=True)
        # ======== 日期处理结束 ========
        
        # 检查并处理缺失值 (包含转换后的日期)
        initial_size = len(df)
        df = df.dropna(subset=['Latitude', 'Longitude', 'Date'])
        na_count = initial_size - len(df)
        if na_count > 0:
            print(f"移除 {na_count} 行缺失经纬度或日期的数据")
        if len(df) == 0:
            print("警告: 删除缺失值后数据集为空!")
            return pd.DataFrame()
        
        # 计算像元行列索引
        print("计算网格索引...")
        grid_indices = df.apply(
            lambda x: calculate_grid_index(x['Latitude'], x['Longitude']), 
            axis=1
        )
        
        # 创建包含索引的DataFrame
        indices_df = pd.DataFrame(grid_indices.tolist(), columns=['row', 'col'], index=df.index)
        
        # 处理无效索引
        valid_indices = indices_df.notna().all(axis=1)
        invalid_count = len(indices_df) - valid_indices.sum()
        
        if invalid_count > 0:
            print(f"警告: {invalid_count} 行有无效经纬度，将被移除")
            df = df[valid_indices].copy()
            indices_df = indices_df[valid_indices]
        
        # 分配行列索引
        df['row'] = indices_df['row']
        df['col'] = indices_df['col']
        
        # 删除原始经纬度列
        df.drop(columns=['Latitude', 'Longitude'], inplace=True)
        
        # 按日期和像元分组求平均值
        print("分组求平均值...")
        grouped = df.groupby(['Date', 'row', 'col']).mean(numeric_only=True).reset_index()
        
        # 添加像元中心经纬度
        print("添加像元中心坐标...")
        center_coords = grouped.apply(
            lambda x: get_center_coordinates(x['row'], x['col']), 
            axis=1
        )
        # 修改列名以明确表示这是网格中心
        grouped[['Center_Latitude', 'Center_Longitude']] = pd.DataFrame(
            center_coords.tolist(), 
            columns=['Center_Latitude', 'Center_Longitude'], 
            index=grouped.index
        )
        
        # 确保所有计算列都在数据框中
        grouped = grouped.reindex(columns=['Date', 'row', 'col', 'Center_Latitude', 'Center_Longitude'] + 
                                [col for col in grouped.columns if col not in ['Date', 'row', 'col', 'Center_Latitude', 'Center_Longitude']])
        
        print(f"处理完成, 有效数据行数: {len(grouped)}")
        return grouped
    
    except Exception as e:
        print(f"处理 {dataset_name} 时发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
        print(f"文件路径: {file_path}")
        print(f"日期列名: {date_col}, 纬度列名: {lat_col}, 经度列名: {lon_col}")
        return pd.DataFrame()  # 返回空DataFrame避免中断

# 文件配置列表
files_config = [
    {
        'name': 'SMEX02',
        'path': r'E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V.xlsx',
        'date_col': 'Date',
        'lat_col': 'Latitude',
        'lon_col': 'Longitude'
    },
    {
        'name': 'CLASIC07',
        'path': r'E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC.xlsx',
        'date_col': 'Date',
        'lat_col': 'Latitude (WGS84)',
        'lon_col': 'Longitude (WGS84)'
    },
    {
        'name': 'SMEX08',
        'path': r'E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_Sum_VEG_SMAPVEX.xlsx',
        'date_col': 'Date',
        'lat_col': 'Latitude',
        'lon_col': 'Longitude'
    },
    {
        'name': 'SMAPVEX16',
        'path': r'E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords.csv',
        'date_col': 'DATE',
        'lat_col': 'Latitude',
        'lon_col': 'Longitude'
    }
]

# 主处理过程
if __name__ == "__main__":
    output_path = r'E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel.xlsx'
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)  # 确保目录存在
    
    with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
        for config in files_config:
            processed_df = process_file(
                config['path'],
                config['date_col'],
                config['lat_col'],
                config['lon_col'],
                config['name']
            )
            
            if not processed_df.empty:
                processed_df.to_excel(writer, sheet_name=config['name'], index=False)
                print(f"成功保存 {config['name']} 数据到Sheet\n")
            else:
                print(f"{config['name']} 无有效数据可保存\n")
    
    print(f"处理完成! 结果已保存至: {output_path}")


开始处理: SMEX02
原始数据行数: 104
日期格式成功转换为YYYYMMDD
计算网格索引...
分组求平均值...
添加像元中心坐标...
处理完成, 有效数据行数: 16
成功保存 SMEX02 数据到Sheet


开始处理: CLASIC07
原始数据行数: 22
日期格式成功转换为YYYYMMDD
计算网格索引...
分组求平均值...
添加像元中心坐标...
处理完成, 有效数据行数: 18
成功保存 CLASIC07 数据到Sheet


开始处理: SMEX08
原始数据行数: 10
日期格式成功转换为YYYYMMDD
计算网格索引...
分组求平均值...
添加像元中心坐标...
处理完成, 有效数据行数: 6
成功保存 SMEX08 数据到Sheet


开始处理: SMAPVEX16
原始数据行数: 1400
日期格式成功转换为YYYYMMDD
移除 25 行缺失经纬度或日期的数据
计算网格索引...
分组求平均值...
添加像元中心坐标...
处理完成, 有效数据行数: 115
成功保存 SMAPVEX16 数据到Sheet

处理完成! 结果已保存至: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel.xlsx


In [1]:
# 数据填充
import pandas as pd
import numpy as np
from pathlib import Path
import os
import h5py
from datetime import datetime, timedelta
import calendar
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

# 常量定义
TOTAL_ROWS = 1800  # 纬度方向像元数
TOTAL_COLS = 3600  # 经度方向像元数
VOD_VARIABLES = ['SM', 'ku_vod_H', 'ku_vod_V', 'x_vod_H', 'x_vod_V', 'c_vod_H', 'c_vod_V']
PFT_VARIABLES = ['water', 'bare', 'snowice', 'built', 'grassnat', 'grassman', 
                 'shrubbd', 'shrubbe', 'shrubnd', 'shrubne', 'treebd', 'treebe', 'treend', 'treene']

def read_mat_file(file_path, variable_names, silent=False):
    """
    读取MAT文件并返回所需变量的数据矩阵
    
    参数:
    file_path (str): MAT文件路径
    variable_names (list): 要读取的变量名列表
    silent (bool): 是否静默处理错误
    
    返回:
    dict: 包含变量名及其对应的矩阵数据
    """
    try:
        # 尝试使用h5py读取v7.3格式
        with h5py.File(file_path, 'r') as f:
            data = {}
            for var in variable_names:
                if var in f:
                    dataset = f[var]
                    # 如果数据是引用类型，获取实际数据
                    if isinstance(dataset, h5py.Reference):
                        dataset = f[dataset]
                    # 确保是二维数组
                    if len(dataset.shape) == 2:
                        matrix = dataset[()]
                        # 检查矩阵方向是否需要转置
                        if matrix.shape == (TOTAL_ROWS, TOTAL_COLS):
                            data[var] = matrix
                        elif matrix.shape == (TOTAL_COLS, TOTAL_ROWS):
                            data[var] = matrix.T
                        else:
                            # 尝试重塑为正确形状
                            try:
                                data[var] = matrix.reshape(TOTAL_ROWS, TOTAL_COLS)
                            except:
                                data[var] = np.full((TOTAL_ROWS, TOTAL_COLS), np.nan)
                    else:
                        data[var] = np.full((TOTAL_ROWS, TOTAL_COLS), np.nan)
            return data
    except Exception as e:
        if not silent:
            print(f"警告: 读取文件 {file_path} 时出错: {str(e)}")
        return None

def safe_date_to_str(date_val):
    """安全地将日期值转换为YYYYMMDD格式的字符串"""
    if pd.isna(date_val):
        return ""
    
    # 处理不同的日期格式
    if isinstance(date_val, datetime):
        return date_val.strftime('%Y%m%d')
    elif isinstance(date_val, np.datetime64):
        return pd.to_datetime(date_val).strftime('%Y%m%d')
    elif isinstance(date_val, (int, float)):
        # 数字日期 (如20220715.0)
        date_str = str(int(date_val))
        return date_str[:8] if len(date_str) > 8 else date_str.zfill(8)
    else:
        # 字符串日期
        date_str = str(date_val).replace('-', '').replace('/', '').replace(' ', '')
        return date_str[:8] if len(date_str) > 8 else date_str.zfill(8)

def calculate_lai_weight(date_str):
    """计算LAI插值权重（修正版）"""
    if len(date_str) != 8 or not date_str.isdigit():
        return None, None, 0.0
    
    try:
        year = int(date_str[:4])
        month = int(date_str[4:6])
        day = int(date_str[6:8])
    except:
        return None, None, 0.0
    
    # 处理无效日期
    try:
        current_date = datetime(year, month, day)
    except ValueError:
        # 处理无效日期（如2月31日）
        if month == 2 and day > 28:
            day = 28
        elif day > 30 and month in [4, 6, 9, 11]:
            day = 30
        elif day > 31 and month in [1, 3, 5, 7, 8, 10, 12]:
            day = 31
            
        try:
            current_date = datetime(year, month, day)
        except:
            return None, None, 0.0
    
    # 确定正确的月份对
    if day < 15:
        # 如果日期在15日之前，使用前一个月和当前月
        prev_month = month - 1
        prev_year = year
        if prev_month == 0:
            prev_month = 12
            prev_year = year - 1
        
        prev_month_mid = datetime(prev_year, prev_month, 15)
        current_month_mid = datetime(year, month, 15)
        
        total_days = (current_month_mid - prev_month_mid).days
        days_passed = (current_date - prev_month_mid).days
    else:
        # 如果日期在15日或之后，使用当前月和下一月
        current_month_mid = datetime(year, month, 15)
        
        next_month = month + 1
        next_year = year
        if next_month > 12:
            next_month = 1
            next_year += 1
        next_month_mid = datetime(next_year, next_month, 15)
        
        total_days = (next_month_mid - current_month_mid).days
        days_passed = (current_date - current_month_mid).days
    
    if total_days <= 0:
        weight = 0.0
    else:
        weight = max(0.0, min(1.0, days_passed / total_days))
    
    # 返回月份对和权重
    if day < 15:
        return (prev_year, prev_month), (year, month), weight
    else:
        return (year, month), (next_year, next_month), weight

def process_dataset(df, sheet_name):
    """
    处理单个数据集，添加额外的卫星数据列
    
    参数:
    df (pd.DataFrame): 输入数据集
    sheet_name (str): sheet名称
    
    返回:
    pd.DataFrame: 处理后的数据集
    """
    print(f"\n开始处理 {sheet_name} 数据集...")
    print(f"原始数据行数: {len(df)}")
    
    # ===================================================================
    # 步骤1: 添加VOD及SM数据
    # ===================================================================
    print("添加VOD及SM数据...")
    
    # 准备结果列
    for var in VOD_VARIABLES:
        col_name = 'SM_Satellite' if var == 'SM' else var
        df[col_name] = np.nan
    
    # 创建缓存以提高性能
    date_vod_map = {}
    
    # 处理每个日期
    dates = df['Date'].apply(safe_date_to_str).unique()
    for date_str in dates:
        if not date_str or len(date_str) != 8:
            continue
            
        try:
            year = int(date_str[:4])
            
            # 确定文件路径
            if year <= 2012:
                file_path = f"E:\\data\\VOD\\mat\\kuxcVOD\\ASC\\MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat"
            else:
                file_path = f"E:\\data\\VOD\\mat\\kuxcVOD\\ASC\\MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat"
            
            # 检查并读取文件
            if os.path.exists(file_path):
                vod_data = read_mat_file(file_path, VOD_VARIABLES, silent=True)
                if vod_data:
                    date_vod_map[date_str] = vod_data
        except Exception as e:
            print(f"处理日期 {date_str} 的VOD文件时出错: {str(e)}")
    
    # 填充数据
    for i in df.index:
        date_str = safe_date_to_str(df.at[i, 'Date'])
        if not date_str or date_str not in date_vod_map:
            continue
            
        vod_data = date_vod_map[date_str]
        row_index = int(df.at[i, 'row'])
        col_index = int(df.at[i, 'col'])
        
        for var in VOD_VARIABLES:
            col_name = 'SM_Satellite' if var == 'SM' else var
            matrix = vod_data.get(var)
            if matrix is not None and not np.isnan(matrix[row_index, col_index]):
                df.at[i, col_name] = matrix[row_index, col_index]
    
    # ===================================================================
    # 步骤2: 添加PFT数据
    # ===================================================================
    print("添加PFT数据...")
    
    # 准备结果列
    for var in PFT_VARIABLES:
        df[var] = np.nan
    
    # 获取所有年份
    years = set()
    for date_str in df['Date'].apply(safe_date_to_str):
        if len(date_str) >= 4 and date_str[:4].isdigit():
            years.add(int(date_str[:4]))
    
    # 创建缓存以提高性能
    year_pft_map = {}
    
    # 处理每个年份
    for year in years:
        file_path = f"E:\\data\\ESACCI PFT\\Resample\\Data\\{year}.mat"
        if os.path.exists(file_path):
            pft_data = read_mat_file(file_path, PFT_VARIABLES, silent=True)
            if pft_data:
                year_pft_map[year] = pft_data
    
    # 填充数据
    for i in df.index:
        date_str = safe_date_to_str(df.at[i, 'Date'])
        if not date_str or len(date_str) < 4:
            continue
            
        try:
            year = int(date_str[:4])
            pft_data = year_pft_map.get(year)
            
            if pft_data is None:
                continue
                
            row_index = int(df.at[i, 'row'])
            col_index = int(df.at[i, 'col'])
            
            for var in PFT_VARIABLES:
                matrix = pft_data.get(var)
                if matrix is not None and not np.isnan(matrix[row_index, col_index]):
                    df.at[i, var] = matrix[row_index, col_index]
        except Exception as e:
            print(f"处理行 {i} 的PFT数据时出错: {str(e)}")
    
    # ===================================================================
    # 步骤3: 添加LAI数据
    # ===================================================================
    print("添加LAI数据...")
    df['LAI_Satellite'] = np.nan
    
    # 创建缓存以提高性能
    lai_cache = {}
    
    # 处理每个日期
    for i in df.index:
        date_str = safe_date_to_str(df.at[i, 'Date'])
        if not date_str or len(date_str) != 8:
            continue
            
        # 计算权重和月份
        prev_month, next_month, weight = calculate_lai_weight(date_str)
        if prev_month is None:
            continue
            
        # 检查并读取当前月份文件
        lai1 = np.nan
        file1_path = f"E:\\data\\GLASS LAI\\mat\\0.1Deg\\Dataset\\{prev_month[0]:04d}-{prev_month[1]:02d}-01.tif.mat"
    
       
        if file1_path in lai_cache:
            lai1 = lai_cache[file1_path]
        elif os.path.exists(file1_path):
            lai_data1 = read_mat_file(file1_path, ['lai'], silent=True)
            if lai_data1 and 'lai' in lai_data1:
                matrix = lai_data1['lai']
                row_index = int(df.at[i, 'row'])
                col_index = int(df.at[i, 'col'])
                
                try:
                    lai1 = matrix[row_index, col_index]
                    lai_cache[file1_path] = lai1
                except:
                    lai_cache[file1_path] = np.nan
            else:
                lai_cache[file1_path] = np.nan
                lai1 = np.nan
        else:
            lai1 = np.nan
            
        # 检查并读取下个月份文件
        lai2 = np.nan
        file2_path = f"E:\\data\\GLASS LAI\\mat\\0.1Deg\\Dataset\\{next_month[0]:04d}-{next_month[1]:02d}-01.tif.mat"
        
        if file2_path in lai_cache:
            lai2 = lai_cache[file2_path]
        elif os.path.exists(file2_path):
            lai_data2 = read_mat_file(file2_path, ['lai'], silent=True)
            if lai_data2 and 'lai' in lai_data2:
                matrix = lai_data2['lai']
                row_index = int(df.at[i, 'row'])
                col_index = int(df.at[i, 'col'])
                
                try:
                    lai2 = matrix[row_index, col_index]
                    lai_cache[file2_path] = lai2
                except:
                    lai_cache[file2_path] = np.nan
            else:
                lai_cache[file2_path] = np.nan
                lai2 = np.nan
        else:
            lai2 = np.nan
            
        # 线性插值计算最终LAI值
        if not np.isnan(lai1) and not np.isnan(lai2):
            # 使用权重进行线性插值
            lai_final = (1 - weight) * lai1 + weight * lai2
        elif not np.isnan(lai1):
            lai_final = lai1
        elif not np.isnan(lai2):
            lai_final = lai2
        else:
            lai_final = np.nan
            
        df.at[i, 'LAI_Satellite'] = lai_final
    
    # ===================================================================
    # 步骤4: 添加Hveg数据
    # ===================================================================
    print("添加Hveg数据...")
    df['Hveg_Satellite'] = np.nan
    
    hveg_file = "E:\\data\\CanopyHeight\\CH.mat"
    if os.path.exists(hveg_file):
        hveg_data = read_mat_file(hveg_file, ['Hveg'])
        if hveg_data and 'Hveg' in hveg_data:
            matrix = hveg_data['Hveg']
            
            # 填充数据
            for i in df.index:
                row_index = int(df.at[i, 'row'])
                col_index = int(df.at[i, 'col'])
                
                try:
                    df.at[i, 'Hveg_Satellite'] = matrix[row_index, col_index]
                except (IndexError, ValueError):
                    # 保留NaN值
                    pass
    
    print(f"处理完成, 最终数据行数: {len(df)}")
    return df

# 主处理过程
if __name__ == "__main__":
    input_dir = r'E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16'
    input_file = os.path.join(input_dir, 'InsituData_Pixel.xlsx')
    output_file = os.path.join(input_dir, 'InsituData_Pixel_ML.xlsx')
    
    # 确保输出目录存在
    os.makedirs(input_dir, exist_ok=True)
    
    # 处理每个sheet
    sheet_names = ['SMEX02', 'CLASIC07', 'SMEX08', 'SMAPVEX16']
    output_dfs = {}
    
    # 读取输入Excel文件
    for sheet in sheet_names:
        print(f"\n{'='*50}")
        print(f"处理数据集: {sheet}")
        try:
            df = pd.read_excel(input_file, sheet_name=sheet, engine='openpyxl')
            
            # 确保有足够的行
            if len(df) == 0:
                print(f"警告: {sheet} 中没有数据")
                output_dfs[sheet] = pd.DataFrame()
                continue
            
            processed_df = process_dataset(df, sheet)
            output_dfs[sheet] = processed_df
            
            # 添加虚拟行避免保存错误
            if processed_df.empty:
                # 创建至少一行数据防止ExcelWriter错误
                processed_df = pd.DataFrame(columns=df.columns)
                processed_df.loc[0] = [None] * len(processed_df.columns)
                
        except Exception as e:
            print(f"处理 {sheet} 时出错: {str(e)}")
            # 创建空DataFrame但有列名防止保存错误
            try:
                df = pd.read_excel(input_file, sheet_name=sheet, nrows=0, engine='openpyxl')
                output_dfs[sheet] = df
            except:
                output_dfs[sheet] = pd.DataFrame(columns=['Date', 'row', 'col'])
    
    # 保存结果
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        for sheet, df in output_dfs.items():
            if not df.empty:
                print(f"保存 '{sheet}' 到Excel文件 ({len(df)} 行)")
                df.to_excel(writer, sheet_name=sheet, index=False)
            else:
                # 创建空但有列名的sheet
                print(f"{sheet} 无有效数据，创建空工作表")
                empty_df = pd.DataFrame(columns=df.columns)
                empty_df.to_excel(writer, sheet_name=sheet, index=False)
    
    print(f"\n{'='*50}")
    print(f"处理完成! 结果已保存至: {output_file}")
    if os.path.exists(output_file):
        print(f"文件大小: {os.path.getsize(output_file)/1024/1024:.2f} MB")
    print("="*50)


处理数据集: SMEX02

开始处理 SMEX02 数据集...
原始数据行数: 16
添加VOD及SM数据...
添加PFT数据...
添加LAI数据...
添加Hveg数据...
处理完成, 最终数据行数: 16

处理数据集: CLASIC07

开始处理 CLASIC07 数据集...
原始数据行数: 18
添加VOD及SM数据...
添加PFT数据...
添加LAI数据...
添加Hveg数据...
处理完成, 最终数据行数: 18

处理数据集: SMEX08
处理 SMEX08 时出错: Worksheet named 'SMEX08' not found

处理数据集: SMAPVEX16

开始处理 SMAPVEX16 数据集...
原始数据行数: 115
添加VOD及SM数据...
添加PFT数据...
添加LAI数据...
添加Hveg数据...
处理完成, 最终数据行数: 115
保存 'SMEX02' 到Excel文件 (16 行)
保存 'CLASIC07' 到Excel文件 (18 行)
SMEX08 无有效数据，创建空工作表
保存 'SMAPVEX16' 到Excel文件 (115 行)

处理完成! 结果已保存至: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx
文件大小: 0.04 MB


In [10]:
##  添加迁移学习提升精度
# 散点图（4个数据画在一块，写出n，按照波段-极化组合绘制为3*3）
# 点形状及颜色：
# SMEX02：*；CLASIC07：^；SMEX08：+；SMAPVEX16：o

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import matplotlib.font_manager as fm
import joblib
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import mean_squared_error

# 常量定义
BANDS = ['Ku', 'X', 'C']
POLS = ['H', 'V', 'HV']
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']
VWC_COLUMNS = {
    'SMEX02': 'VWC-Field',
    'CLASIC07': 'VWC (kg/m²)',
    'SMAPVEX08': 'VWC',
    'SMAPVEX16': 'PLANT_WATER_CONTENT_AREA'
}

# 标记和颜色设置
MARKER_STYLES = {
    'SMEX02': {'marker': 'x', 'color': '#F8766D'},
    'CLASIC07': {'marker': '^', 'facecolor': 'none', 'edgecolor': '#00BFC4'},
    'SMAPVEX08': {'marker': '+', 'color': '#C77CFF'},
    'SMAPVEX16': {'marker': 'o', 'facecolor': 'none', 'edgecolor': '#7CAE00'}
}

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

def load_and_preprocess_data(file_path):
    """
    加载并预处理Excel文件中的所有sheet
    
    参数:
    file_path (str): Excel文件路径
    
    返回:
    dict: 包含预处理后数据的字典，键为sheet名称
    """
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    for sheet in SHEET_NAMES:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 替换SM_Satellite和LAI_Satellite（如果存在地面实测数据）
            if 'SM' in df.columns:
                mask = df['SM'].notna()
                df.loc[mask, 'SM_Satellite'] = df.loc[mask, 'SM']
                print(f"    替换了 {mask.sum()} 行SM_Satellite数据")
            
            if 'LAI' in df.columns:
                mask = df['LAI'].notna()
                df.loc[mask, 'LAI_Satellite'] = df.loc[mask, 'LAI']
                print(f"    替换了 {mask.sum()} 行LAI_Satellite数据")
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def get_features_for_model(band, pol):
    """
    根据波段和极化类型获取特征列表（使用模型训练时的名称）
    
    参数:
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    list: 特征列名列表
    """
    # 使用模型训练时的特征名称
    features = [
        'LAI',  # 注意：训练时使用"LAI"而不是"LAI_Satellite"
        'SM',   # 注意：训练时使用"SM"而不是"SM_Satellite"
        'Grass_man', 
        'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    # 添加VOD特征 - 根据模型类型
    if pol == 'H' or pol == 'V':
        # 单极化模型使用"VOD"
        features.append('VOD')
    elif pol == 'HV':
        # 双极化模型使用"VOD-Hpol"和"VOD-Vpol"
        features.extend(['VOD-Hpol', 'VOD-Vpol'])
    
    return features

##  新增迁移学习相关参数
FINE_TUNE_PARAMS = {
    'n_estimators': 50,        # 微调时新增树的数量（保持原有树不变）
    'max_depth': None,          # 新树的最大深度
    'max_samples': 0.8,         # 每个新树的样本采样比例
    'random_state': 42,         # 随机种子
    'n_jobs': -1,               # 使用所有CPU核心
}

def load_model_with_finetuning(model_path):
    """
    加载模型并为迁移学习做准备
    返回值：原始模型、是否支持迁移学习的标志
    """
    if not os.path.exists(model_path):
        print(f"  模型文件不存在: {model_path}")
        return None, False
        
    try:
        model = joblib.load(model_path)
        print(f"  成功加载原始模型，支持迁移学习: {hasattr(model, 'fit')}")
        return model, hasattr(model, 'fit')  # 仅支持可训练的模型
    except Exception as e:
        print(f"  加载模型失败: {str(e)}")
        return None, False

def predict_vwc(data_dict, band, pol):
    """
    使用指定模型预测VWC，包括特征归一化和迁移学习
    
    参数:
    data_dict (dict): 包含所有sheet数据的字典
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    dict: 包含每个sheet预测结果的字典
    """
    # 迁移学习配置
    FINE_TUNE_PARAMS = {
        'n_estimators': 50,        # 微调时新增树的数量
        'max_depth': 8,            # 新树的最大深度
        'max_samples': 0.8,        # 每个新树的样本采样比例
        'random_state': 42,        # 随机种子
        'n_jobs': -1,              # 使用所有CPU核心
    }
    
    # 加载模型
    model_path = f"models/RFR_{band}_{pol}pol_Type1.pkl"
    print(f"加载模型: {model_path}")
    
    if not os.path.exists(model_path):
        print(f"  模型文件不存在: {model_path}")
        return {}
    
    try:
        model = joblib.load(model_path)
        # 打印模型训练时的特征名称
        if hasattr(model, 'feature_names_in_'):
            print(f"  模型训练特征: {list(model.feature_names_in_)}")
    except Exception as e:
        print(f"  加载模型失败: {str(e)}")
        return {}
    
    # 获取特征列表
    features = get_features_for_model(band, pol)
    
    # 存储预测结果
    predictions = {}
    # 存储所有数据集特征统计信息
    feature_ranges = {}
    
    # 第一步：收集所有数据集的特征数据用于微调
    finetune_data = {}
    
    for sheet, df in data_dict.items():
        if df.empty:
            continue
            
        # 创建特征映射
        feature_mapping = {}
        for feature in features:
            if feature == 'VOD':
                if pol == 'H':
                    if band == 'Ku':
                        feature_mapping['ku_vod_H'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_H'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_H'] = 'VOD'
                elif pol == 'V':
                    if band == 'Ku':
                        feature_mapping['ku_vod_V'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_V'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_V'] = 'VOD'
            elif feature == 'VOD-Hpol':
                if band == 'Ku':
                    feature_mapping['ku_vod_H'] = 'VOD-Hpol'
                elif band == 'X':
                    feature_mapping['x_vod_H'] = 'VOD-Hpol'
                elif band == 'C':
                    feature_mapping['c_vod_H'] = 'VOD-Hpol'
            elif feature == 'VOD-Vpol':
                if band == 'Ku':
                    feature_mapping['ku_vod_V'] = 'VOD-Vpol'
                elif band == 'X':
                    feature_mapping['x_vod_V'] = 'VOD-Vpol'
                elif band == 'C':
                    feature_mapping['c_vod_V'] = 'VOD-Vpol'
            else:
                if feature == 'LAI':
                    feature_mapping['LAI_Satellite'] = 'LAI'
                elif feature == 'SM':
                    feature_mapping['SM_Satellite'] = 'SM'
                elif feature == 'Grass_man':
                    feature_mapping['grassman'] = 'Grass_man'
                elif feature == 'Grass_nat':
                    feature_mapping['grassnat'] = 'Grass_nat'
                elif feature == 'Shrub_bd':
                    feature_mapping['shrubbd'] = 'Shrub_bd'
                elif feature == 'Shrub_be':
                    feature_mapping['shrubbe'] = 'Shrub_be'
                elif feature == 'Shrub_nd':
                    feature_mapping['shrubnd'] = 'Shrub_nd'
                elif feature == 'Shrub_ne':
                    feature_mapping['shrubne'] = 'Shrub_ne'
                elif feature == 'Tree_bd':
                    feature_mapping['treebd'] = 'Tree_bd'
                elif feature == 'Tree_be':
                    feature_mapping['treebe'] = 'Tree_be'
                elif feature == 'Tree_nd':
                    feature_mapping['treend'] = 'Tree_nd'
                elif feature == 'Tree_ne':
                    feature_mapping['treene'] = 'Tree_ne'
        
        # 检查是否包含所有必要特征
        missing_features = []
        for data_feature in feature_mapping.keys():
            if data_feature not in df.columns:
                missing_features.append(data_feature)
        
        if missing_features:
            print(f"  {sheet} 缺少特征: {', '.join(missing_features)}")
            continue
        
        # 准备数据（使用重命名的特征）
        X = df[list(feature_mapping.keys())].copy()
        X.columns = [feature_mapping[col] for col in X.columns]
        
        # 确保特征顺序与模型期望一致
        if hasattr(model, 'feature_names_in_'):
            X = X[list(model.feature_names_in_)]
        
        # 记录原始特征范围（归一化前）
        if sheet not in feature_ranges:
            feature_ranges[sheet] = {}
        
        for col in X.columns:
            if col not in feature_ranges[sheet]:
                min_val = X[col].min()
                max_val = X[col].max()
                feature_ranges[sheet][col] = (min_val, max_val)
                print(f"  {sheet} {col} 原始范围: [{min_val:.4f}, {max_val:.4f}]")
        
        # 应用归一化处理
        vod_features = ['VOD', 'VOD-Hpol', 'VOD-Vpol']
        for vod_feature in vod_features:
            if vod_feature in X.columns:
                X[vod_feature] = X[vod_feature].clip(0, 2) / 2.0
                # 记录归一化后范围
                min_val = X[vod_feature].min()
                max_val = X[vod_feature].max()
                print(f"  {sheet} {vod_feature} 归一化后范围: [{min_val:.4f}, {max_val:.4f}]")
        
        if 'LAI' in X.columns:
            X['LAI'] = X['LAI'].clip(0, 6) / 6.0
            min_val = X['LAI'].min()
            max_val = X['LAI'].max()
            print(f"  {sheet} LAI 归一化后范围: [{min_val:.4f}, {max_val:.4f}]")
        
        pft_features = [
            'Grass_man', 'Grass_nat',
            'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
            'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
        ]
        
        for pft_feature in pft_features:
            if pft_feature in X.columns:
                X[pft_feature] = X[pft_feature] / 100.0
                min_val = X[pft_feature].min()
                max_val = X[pft_feature].max()
                print(f"  {sheet} {pft_feature} 归一化后范围: [{min_val:.4f}, {max_val:.4f}]")
        
        # 移除缺失值
        initial_count = len(X)
        X = X.dropna()
        removed_count = initial_count - len(X)
        if removed_count > 0:
            print(f"  {sheet} 移除了 {removed_count} 行包含缺失值的数据")
        
        if X.empty:
            print(f"  {sheet} 无有效数据可用于预测")
            continue
        
        y = df.loc[X.index, VWC_COLUMNS[sheet]]
        
        # 存储用于迁移学习的完整数据集
        finetune_data[sheet] = {
            'X': X,
            'y': y
        }
    
    # 打印整体特征范围摘要
    print("\n===== 输入特征范围摘要 =====")
    for sheet, ranges in feature_ranges.items():
        print(f"\n数据集: {sheet}")
        for feature, (min_val, max_val) in ranges.items():
            print(f"  {feature}: [{min_val:.4f}, {max_val:.4f}]")
    print("===========================\n")
    
    # 第二步：使用原始模型进行预测
    print("===== 使用原始模型预测 =====")
    for sheet, data in finetune_data.items():
        X = data['X']
        y = data['y']
        
        y_pred = model.predict(X)
        predictions[sheet] = {
            'actual': y,
            'predicted': y_pred,
            'finetuned': False,  # 标记为原始模型预测
            'source': sheet,
            'row': data['X'].index,
            'lat': df.loc[X.index, 'Latitude'] if 'Latitude' in df.columns else None,
            'lon': df.loc[X.index, 'Longitude'] if 'Longitude' in df.columns else None,
            'date': df.loc[X.index, 'Date'] if 'Date' in df.columns else None
        }
        
        rmse = calculate_rmse(y, y_pred)
        print(f"  {sheet} 原始模型预测RMSE: {rmse:.4f}")
    
    # 第三步：应用迁移学习（微调）
    def finetune_model(original_model, X_finetune, y_finetune):
        """在目标数据集上微调模型"""
        from copy import deepcopy
        
        # 创建新模型继承原模型参数
        new_model = deepcopy(original_model)
        
        # 获取原始模型的bootstrap设置
        bootstrap_setting = original_model.get_params().get('bootstrap', True)
        
        # 确保参数兼容性
        max_samples_setting = FINE_TUNE_PARAMS['max_samples'] if bootstrap_setting else None
        
        # 修改参数进行增量训练
        new_model.set_params(
            warm_start=True,
            n_estimators=original_model.n_estimators + FINE_TUNE_PARAMS['n_estimators'],
            max_depth=FINE_TUNE_PARAMS['max_depth'],
            bootstrap=bootstrap_setting,  # 保持原始设置
            max_samples=max_samples_setting,  # 只在启用bootstrap时设置
            random_state=FINE_TUNE_PARAMS['random_state'],
            n_jobs=FINE_TUNE_PARAMS['n_jobs']
        )
        
        print(f"    微调模型: 新增{FINE_TUNE_PARAMS['n_estimators']}棵树，使用{len(X_finetune)}个样本")
        new_model.fit(X_finetune, y_finetune)
        
        return new_model
    
    # 检查是否有足够的数据进行微调
    finetune_samples = sum(len(data['X']) for data in finetune_data.values())
    
    if finetune_samples >= 20:  # 至少有20个样本才进行微调
        print("\n===== 应用迁移学习 (微调) =====")
        print(f"  合并数据量: {finetune_samples}个样本")
        
        # 合并所有数据集进行微调
        X_finetune = pd.concat([data['X'] for data in finetune_data.values()])
        y_finetune = pd.concat([data['y'] for data in finetune_data.values()])
        
        # 微调模型
        finetuned_model = finetune_model(model, X_finetune, y_finetune)
        
        # 第四步：使用微调后的模型进行预测
        print("\n===== 使用微调模型预测 =====")
        for sheet, data in finetune_data.items():
            X = data['X']
            y = data['y']
            
            y_pred_finetuned = finetuned_model.predict(X)
            
            # 更新预测结果
            predictions[sheet] = {
                'actual': y,
                'predicted': y_pred_finetuned,
                'finetuned': True,  # 标记为微调模型预测
                'source': sheet,
                'row': data['X'].index,
                'lat': df.loc[X.index, 'Latitude'] if 'Latitude' in df.columns else None,
                'lon': df.loc[X.index, 'Longitude'] if 'Longitude' in df.columns else None,
                'date': df.loc[X.index, 'Date'] if 'Date' in df.columns else None
            }
            
            rmse = calculate_rmse(y, y_pred_finetuned)
            rmse_diff = calculate_rmse(y, predictions[sheet].get('predicted_original', y_pred_finetuned)) - rmse
            print(f"  {sheet} 微调模型预测RMSE: {rmse:.4f} (改进: {rmse_diff:.4f})")
    else:
        print(f"  样本不足({finetune_samples}<20)，跳过迁移学习")
    
    return predictions

def create_scatter_plots(all_predictions):
    """
    创建3x3散点子图，只显示微调模型结果
    """
    print("创建散点图...")
    
    # 创建图形
    fig = plt.figure(figsize=(18, 18))
    gs = gridspec.GridSpec(3, 3, figure=fig)
    
    # 设置全局标题
    fig.suptitle('VWC预测结果（迁移学习模型）', fontsize=24, fontweight='bold', y=0.95)
    
    # 遍历所有波段和极化组合
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            ax = fig.add_subplot(gs[i, j])
            
            # 获取当前组合的预测结果
            predictions = all_predictions.get((band, pol), {})
            
            # 收集所有数据点
            all_actual = []
            all_predicted = []
            
            # 绘制每个sheet的数据点（只显示微调模型结果）
            for sheet in SHEET_NAMES:
                if sheet in predictions and predictions[sheet]['finetuned']:
                    data = predictions[sheet]
                    actual = data['actual']
                    predicted = data['predicted']
                    
                    # 添加到总集合
                    all_actual.extend(actual)
                    all_predicted.extend(predicted)
                    
                    # 获取该数据集的标记样式
                    style = MARKER_STYLES[sheet]
                    
                    # 创建统一的标记参数
                    scatter_params = {
                        'marker': style['marker'],
                        's': 60,  # 点大小
                        'alpha': 0.8,
                    }
                    
                    # 为特殊数据集设置空心点
                    if sheet in ['CLASIC07', 'SMAPVEX16']:
                        scatter_params.update({
                            'facecolor': 'none',
                            'edgecolor': style.get('edgecolor', 'black')
                        })
                    else:
                        scatter_params['color'] = style.get('color', 'black')
                    
                    # 绘制点
                    ax.scatter(actual, predicted, **scatter_params)
            
            # 如果没有数据，跳过
            if not all_actual:
                ax.text(0.5, 0.5, '无数据', 
                        horizontalalignment='center', 
                        verticalalignment='center', 
                        transform=ax.transAxes,
                        fontsize=16)
                ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
                continue
            
            # 计算整体RMSE
            rmse = calculate_rmse(np.array(all_actual), np.array(all_predicted))
            
            # 添加1:1参考线
            max_val = max(max(all_actual), max(all_predicted)) * 1.05
            ax.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7)
            
            # 设置坐标轴范围
            ax.set_xlim(0, max_val)
            ax.set_ylim(0, max_val)
            
            # 设置坐标轴标签
            if i == 2:  # 最后一行
                ax.set_xlabel('实测VWC (kg/m²)', fontsize=14, fontweight='bold')
            if j == 0:  # 第一列
                ax.set_ylabel('预测VWC (kg/m²)', fontsize=14, fontweight='bold')
            
            # 添加标题和RMSE
            ax.set_title(f"{band}-{pol}波段", fontsize=16, fontweight='bold')
            ax.text(0.05, 0.95, f"RMSE: {rmse:.3f} kg/m²", 
                    transform=ax.transAxes,
                    fontsize=16,
                    fontweight='bold',
                    verticalalignment='top')
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
    
    # 添加图例（只显示数据集）
    handles = []
    labels = []
    
    for sheet in SHEET_NAMES:
        style = MARKER_STYLES[sheet]
        if sheet in ['CLASIC07', 'SMAPVEX16']:
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w',
                                     markerfacecolor='none',
                                     markeredgecolor=style.get('edgecolor', 'black'),
                                     markersize=10,
                                     markeredgewidth=1.0))
        else:
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w', 
                                     markerfacecolor=style.get('color', 'black'),
                                     markersize=10))
        labels.append(sheet)
    
    fig.legend(handles, labels, 
               loc='lower center', 
               ncol=len(SHEET_NAMES), 
               fontsize=12,
               frameon=True,
               fancybox=True,
               shadow=True,
               bbox_to_anchor=(0.5, 0.02))
    
    # 调整布局
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    
    # 保存图像
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = "figures/AllSMAPInsituData_VWC_Scatter_FT.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"散点图已保存至: {fig_path}")
    plt.close()

def calculate_rmse(actual, predicted):
    """
    计算RMSE
    
    参数:
    actual (array-like): 实际值
    predicted (array-like): 预测值
    
    返回:
    float: RMSE值
    """
    return np.sqrt(np.mean((actual - predicted)**2))

def save_prediction_details(all_predictions):
    """
    将预测结果保存到Excel文件中
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_file = output_dir / "details.xlsx"
    
    # 创建Excel写入器
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        # 遍历所有波段和极化组合
        for (band, pol), predictions in all_predictions.items():
            if not predictions:
                continue
                
            # 创建当前组合的数据框
            all_data = []
            
            # 收集所有sheet的数据
            for sheet, data in predictions.items():
                # 创建当前sheet的数据框
                sheet_df = pd.DataFrame({
                    'Date': data['date'],
                    # 'Row': data['row'],
                    # 'Col': data['col'],
                    'Latitude': data['lat'],
                    'Longitude': data['lon'],
                    'Actual_VWC': data['actual'],
                    'Predicted_VWC': data['predicted'],
                    'Source': data['source']
                })
                
                # 添加波段和极化信息
                sheet_df['Band'] = band
                sheet_df['Polarization'] = pol
                
                all_data.append(sheet_df)
            
            # 合并所有数据
            if all_data:
                combined_df = pd.concat(all_data, ignore_index=True)
                
                # 保存到Excel
                sheet_name = f"{band}_{pol}"
                combined_df.to_excel(writer, sheet_name=sheet_name, index=False)
                print(f"保存预测结果到: {sheet_name} ({len(combined_df)}行)")
    
    print(f"所有预测结果已保存至: {output_file}")

def main():
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx"
    
    # 加载并预处理数据
    data_dict = load_and_preprocess_data(input_file)
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            print(f"\n处理波段-极化组合: {band}-{pol}")
            predictions = predict_vwc(data_dict, band, pol)
            all_predictions[(band, pol)] = predictions
    
    # 创建散点图
    create_scatter_plots(all_predictions)
    
    # 保存预测结果到Excel
    save_prediction_details(all_predictions)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx
  - SMEX02: 16行
    替换了 14 行SM_Satellite数据
  - CLASIC07: 18行
  - SMAPVEX08: 6行
    替换了 6 行LAI_Satellite数据
  - SMAPVEX16: 115行

处理波段-极化组合: Ku-H
加载模型: models/RFR_Ku_Hpol_Type1.pkl
  模型训练特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  SMEX02 VOD 原始范围: [0.5369, 1.0814]
  SMEX02 LAI 原始范围: [2.7000, 3.4800]
  SMEX02 SM 原始范围: [0.1133, 0.2167]
  SMEX02 Grass_man 原始范围: [93.1111, 99.3827]
  SMEX02 Grass_nat 原始范围: [0.6173, 4.8974]
  SMEX02 Shrub_bd 原始范围: [0.0000, 0.0000]
  SMEX02 Shrub_be 原始范围: [0.0000, 0.0000]
  SMEX02 Shrub_nd 原始范围: [0.0000, 0.0000]
  SMEX02 Shrub_ne 原始范围: [0.0000, 0.0000]
  SMEX02 Tree_bd 原始范围: [0.0000, 0.9398]
  SMEX02 Tree_be 原始范围: [0.0000, 0.0000]
  SMEX02 Tree_nd 原始范围: [0.0000, 0.0000]
  SMEX02 Tree_ne 原始范围: [0.0000, 0.0000]
  SMEX02 VOD 归一化后范围: [0.2684, 0.5407]
  SMEX02 LAI 归一化后范围: [0.4500, 0.5

In [6]:
# 绘制上述4个数据集所填充自变量的情况
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
import matplotlib.font_manager as fm

# 设置全局样式
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
sns.set_style("whitegrid")

# 常量定义
BANDS = ['Ku', 'X', 'C']
POLS = ['H', 'V', 'HV']
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']
VWC_COLUMNS = {
    'SMEX02': 'VWC-Field',
    'CLASIC07': 'VWC (kg/m²)',
    'SMAPVEX08': 'VWC',
    'SMAPVEX16': 'PLANT_WATER_CONTENT_AREA'
}

# 颜色设置
DATASET_COLORS = {
    'SMEX02': '#F8766D',
    'CLASIC07': '#00BFC4',
    'SMAPVEX08': '#C77CFF',
    'SMAPVEX16': '#7CAE00'
}

def load_and_preprocess_data(file_path):
    """
    Load and preprocess all sheets in Excel file
    """
    print(f"Loading file: {file_path}")
    data_dict = {}
    
    for sheet in SHEET_NAMES:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)} rows")
            
            # Replace SM_Satellite and LAI_Satellite if ground measurements exist
            if 'SM' in df.columns:
                mask = df['SM'].notna()
                df.loc[mask, 'SM_Satellite'] = df.loc[mask, 'SM']
                print(f"    Replaced {mask.sum()} rows of SM_Satellite data")
            
            if 'LAI' in df.columns:
                mask = df['LAI'].notna()
                df.loc[mask, 'LAI_Satellite'] = df.loc[mask, 'LAI']
                print(f"    Replaced {mask.sum()} rows of LAI_Satellite data")
            
            # Merge PFT variables
            grass_cols = [col for col in df.columns if 'grass' in col.lower()]
            shrub_cols = [col for col in df.columns if 'shrub' in col.lower()]
            tree_cols = [col for col in df.columns if 'tree' in col.lower()]
            
            if grass_cols:
                df['grass'] = df[grass_cols].sum(axis=1)
            if shrub_cols:
                df['shrub'] = df[shrub_cols].sum(axis=1)
            if tree_cols:
                df['tree'] = df[tree_cols].sum(axis=1)
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  Error loading {sheet}: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def plot_feature_distributions(data_dict):
    """
    Plot distributions of features across datasets
    """
    print("\nPlotting feature distributions...")
    
    # Create output directory
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Key features list - using merged PFT variables
    key_features = [
        'LAI_Satellite', 
        'SM_Satellite',
        'grass',       # Merged grass fraction
        'shrub',       # Merged shrub fraction
        'tree',        # Merged tree fraction
        'ku_vod_H',    # Ku-band H-pol VOD
        'x_vod_H',     # X-band H-pol VOD
        'c_vod_H',     # C-band H-pol VOD
        'VWC_Measured' # Measured VWC
    ]
    
    # Create figure - 3x3 grid
    fig, axes = plt.subplots(3, 3, figsize=(18, 15))
    axes = axes.flatten()
    
    # Create subplot for each feature
    for i, feature in enumerate(key_features):
        if i >= len(axes):  # Prevent index out of range
            break
            
        ax = axes[i]
        all_data = []
        
        # Collect data from all datasets
        for sheet in SHEET_NAMES:
            if sheet in data_dict:
                df = data_dict[sheet]
                
                # Special handling for measured VWC
                if feature == 'VWC_Measured':
                    vwc_col = VWC_COLUMNS.get(sheet)
                    if vwc_col and vwc_col in df.columns:
                        values = df[vwc_col].dropna()
                        if not values.empty:
                            temp_df = pd.DataFrame({
                                'Value': values,
                                'Dataset': sheet
                            })
                            all_data.append(temp_df)
                else:
                    if feature in df.columns:
                        values = df[feature].dropna()
                        if not values.empty:
                            temp_df = pd.DataFrame({
                                'Value': values,
                                'Dataset': sheet
                            })
                            all_data.append(temp_df)
        
        if not all_data:
            ax.text(0.5, 0.5, 'No data', 
                    horizontalalignment='center', 
                    verticalalignment='center', 
                    transform=ax.transAxes,
                    fontsize=12)
            ax.set_title(feature, fontsize=14)
            continue
        
        # Combine all data
        combined_df = pd.concat(all_data, ignore_index=True)
        
        # Plot violin plot
        sns.violinplot(
            x='Dataset', 
            y='Value', 
            data=combined_df,
            ax=ax,
            palette=DATASET_COLORS,
            inner="quartile",  # Show quartiles
            cut=0  # Limit to data range
        )
        
        # Add data points
        sns.stripplot(
            x='Dataset', 
            y='Value', 
            data=combined_df,
            ax=ax,
            color='black',
            alpha=0.3,
            jitter=True
        )
        
        # Set titles and labels
        if feature == 'VWC_Measured':
            ax.set_title('Measured VWC', fontsize=14, fontweight='bold')
            ax.set_ylabel('VWC (kg/m²)')
        else:
            ax.set_title(feature, fontsize=14, fontweight='bold')
            ax.set_ylabel('Value')
        
        ax.set_xlabel('Dataset')
        
        # Rotate x-axis labels
        plt.setp(ax.get_xticklabels(), rotation=45)
    
    # Hide last empty subplot (9th)
    if len(key_features) < len(axes):
        axes[len(key_features)].axis('off')
    
    # Adjust layout
    plt.tight_layout()
    
    # Add main title
    fig.suptitle('Feature Distributions Across Datasets', fontsize=18, fontweight='bold', y=0.98)
    
    # Save figure
    fig_path = output_dir / "Dataset_Feature_Distributions.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"Feature distributions plot saved to: {fig_path}")
    plt.close()

def plot_feature_statistics(data_dict):
    """
    Plot feature statistics comparison
    """
    print("\nPlotting feature statistics...")
    
    # Create output directory
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Key features list - using merged PFT variables
    key_features = [
        'LAI_Satellite', 
        'SM_Satellite',
        'grass',       # Merged grass fraction
        'shrub',       # Merged shrub fraction
        'tree',        # Merged tree fraction
        'ku_vod_H',    # Ku-band H-pol VOD
        'x_vod_H',     # X-band H-pol VOD
        'c_vod_H'      # C-band H-pol VOD
    ]
    
    # Collect statistics
    stats_data = []
    
    for sheet in SHEET_NAMES:
        if sheet in data_dict:
            df = data_dict[sheet]
            for feature in key_features:
                if feature in df.columns:
                    values = df[feature].dropna()
                    if not values.empty:
                        stats_data.append({
                            'Dataset': sheet,
                            'Feature': feature,
                            'Mean': values.mean(),
                            'Median': values.median(),
                            'StdDev': values.std(),
                            'Min': values.min(),
                            'Max': values.max()
                        })
    
    if not stats_data:
        print("No valid statistics data")
        return
    
    # Create DataFrame
    stats_df = pd.DataFrame(stats_data)
    
    # Create figure
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    axes = axes.flatten()
    
    # 1. Mean comparison
    ax = axes[0]
    sns.barplot(
        x='Feature', 
        y='Mean', 
        hue='Dataset', 
        data=stats_df, 
        ax=ax,
        palette=DATASET_COLORS
    )
    ax.set_title('Feature Mean Comparison', fontsize=14, fontweight='bold')
    ax.set_xlabel('Feature')
    ax.set_ylabel('Mean Value')
    ax.legend(title='Dataset')
    
    # 2. Standard deviation comparison
    ax = axes[1]
    sns.barplot(
        x='Feature', 
        y='StdDev', 
        hue='Dataset', 
        data=stats_df, 
        ax=ax,
        palette=DATASET_COLORS
    )
    ax.set_title('Feature Standard Deviation Comparison', fontsize=14, fontweight='bold')
    ax.set_xlabel('Feature')
    ax.set_ylabel('Standard Deviation')
    ax.legend(title='Dataset')
    
    # 3. Range comparison (Max - Min)
    stats_df['Range'] = stats_df['Max'] - stats_df['Min']
    ax = axes[2]
    sns.barplot(
        x='Feature', 
        y='Range', 
        hue='Dataset', 
        data=stats_df, 
        ax=ax,
        palette=DATASET_COLORS
    )
    ax.set_title('Feature Range Comparison', fontsize=14, fontweight='bold')
    ax.set_xlabel('Feature')
    ax.set_ylabel('Range (Max - Min)')
    ax.legend(title='Dataset')
    
    # 4. Median comparison
    ax = axes[3]
    sns.barplot(
        x='Feature', 
        y='Median', 
        hue='Dataset', 
        data=stats_df, 
        ax=ax,
        palette=DATASET_COLORS
    )
    ax.set_title('Feature Median Comparison', fontsize=14, fontweight='bold')
    ax.set_xlabel('Feature')
    ax.set_ylabel('Median Value')
    ax.legend(title='Dataset')
    
    # Adjust layout
    plt.tight_layout()
    
    # Add main title
    fig.suptitle('Feature Statistics Comparison Across Datasets', fontsize=18, fontweight='bold', y=0.98)
    
    # Save figure
    fig_path = output_dir / "Dataset_Feature_Statistics.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"Feature statistics plot saved to: {fig_path}")
    plt.close()

def plot_feature_correlations(data_dict):
    """
    Plot feature correlation heatmaps
    """
    print("\nPlotting feature correlations...")
    
    # Create output directory
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Key features list - using merged PFT variables
    key_features = [
        'LAI_Satellite', 
        'SM_Satellite',
        'grass',       # Merged grass fraction
        'shrub',       # Merged shrub fraction
        'tree',        # Merged tree fraction
        'ku_vod_H',    # Ku-band H-pol VOD
        'x_vod_H',     # X-band H-pol VOD
        'c_vod_H'      # C-band H-pol VOD
    ]
    
    # Create figure
    fig, axes = plt.subplots(2, 2, figsize=(18, 15))
    axes = axes.flatten()
    
    for i, sheet in enumerate(SHEET_NAMES):
        if i >= len(axes):  # Prevent index out of range
            break
            
        if sheet in data_dict:
            df = data_dict[sheet]
            
            # Select features
            features = [f for f in key_features if f in df.columns]
            
            if len(features) < 2:  # Need at least 2 features for correlation
                ax = axes[i]
                ax.text(0.5, 0.5, 'Insufficient features', 
                        horizontalalignment='center', 
                        verticalalignment='center', 
                        transform=ax.transAxes,
                        fontsize=12)
                ax.set_title(sheet, fontsize=14)
                continue
            
            # Calculate correlation matrix
            corr = df[features].corr()
            
            # Plot heatmap
            ax = axes[i]
            sns.heatmap(
                corr, 
                annot=True, 
                fmt=".2f", 
                cmap="coolwarm", 
                vmin=-1, 
                vmax=1, 
                ax=ax,
                cbar_kws={"shrink": 0.7}
            )
            ax.set_title(f'{sheet} - Feature Correlations', fontsize=14, fontweight='bold')
    
    # Adjust layout
    plt.tight_layout()
    
    # Add main title
    fig.suptitle('Feature Correlations Across Datasets', fontsize=18, fontweight='bold', y=0.98)
    
    # Save figure
    fig_path = output_dir / "Dataset_Feature_Correlations.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"Feature correlations plot saved to: {fig_path}")
    plt.close()

def main():
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx"
    
    # 加载并预处理数据
    data_dict = load_and_preprocess_data(input_file)
    
    # 绘制自变量分布图
    plot_feature_distributions(data_dict)
    
    # 绘制特征统计量对比图
    plot_feature_statistics(data_dict)
    
    # 绘制特征相关性热图
    plot_feature_correlations(data_dict)
    
    print("\n分析完成!")

if __name__ == "__main__":
    main()

Loading file: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx
  - SMEX02: 16 rows
    Replaced 14 rows of SM_Satellite data
  - CLASIC07: 18 rows
  - SMAPVEX08: 6 rows
    Replaced 6 rows of LAI_Satellite data
  - SMAPVEX16: 115 rows

Plotting feature distributions...
Feature distributions plot saved to: figures\Dataset_Feature_Distributions.png

Plotting feature statistics...
Feature statistics plot saved to: figures\Dataset_Feature_Statistics.png

Plotting feature correlations...
Feature correlations plot saved to: figures\Dataset_Feature_Correlations.png

分析完成!


# .VWC-Sites（多频多角度数据）----2017\2018

In [7]:
# 多频多角度数据填充
import os
import numpy as np
import pandas as pd
import h5py
from datetime import datetime, timedelta
from pathlib import Path
import warnings
import openpyxl
warnings.filterwarnings('ignore')

def latlon_to_rowcol(lat, lon):
    """将经纬度转换为0.1°栅格的行列号"""
    row = int((89.95 - lat) / 0.1)
    col = int((lon + 179.95) / 0.1)
    return row, col

def get_nearest_lai_files(date):
    """获取指定日期前后两个月的LAI文件路径（精确到每月15日）"""
    # 获取当前日期所在月份的前一个月15日
    prev_month_15 = (date.replace(day=1) - timedelta(days=1)).replace(day=15)
    
    # 获取当前日期所在月份的下一个月15日
    next_month_15 = (date.replace(day=28) + timedelta(days=4)).replace(day=15)
    
    # 构建文件路径
    prev_file = Path(f"E:/data/GLASS LAI/mat/0.1Deg/Dataset/{prev_month_15.strftime('%Y-%m')}-01.tif.mat")
    next_file = Path(f"E:/data/GLASS LAI/mat/0.1Deg/Dataset/{next_month_15.strftime('%Y-%m')}-01.tif.mat")
    
    return prev_file, next_file, prev_month_15, next_month_15

def read_mat_v73(file_path, variable_names):
    """
    读取 v7.3 格式的 .mat 文件
    返回字典：{变量名: 矩阵数据}
    """
    data = {}
    try:
        with h5py.File(file_path, 'r') as f:
            for var in variable_names:
                if var in f:
                    dataset = f[var]
                    # 读取数据（不自动转置）
                    matrix = dataset[()]
                    
                    # 确保数据是二维数组
                    if len(matrix.shape) == 2:
                        # 检查形状是否匹配全局常量（1800×3600）
                        if matrix.shape == (1800, 3600):
                            data[var] = matrix
                        elif matrix.shape == (3600, 1800):
                            # 如果是转置的形状，则手动转置
                            data[var] = matrix.T
                        else:
                            # 尝试重塑为正确形状
                            try:
                                data[var] = matrix.reshape(1800, 3600)
                            except:
                                data[var] = np.full((1800, 3600), np.nan)
                    else:
                        data[var] = np.full((1800, 3600), np.nan)
    except Exception as e:
        print(f"警告: 读取文件 {file_path} 时出错: {str(e)}")
        return None
    
    return data

def process_sheet(df, sheet_name, lat, lon, year):
    """处理单个sheet的数据"""
    # 计算固定位置的栅格行列号
    row, col = latlon_to_rowcol(lat, lon)
    
    # 添加位置信息
    df['Latitude'] = lat
    df['Longitude'] = lon
    df['row'] = row
    df['col'] = col
    
    # 创建日期列 - 使用英文列名
    # 检查并删除有空值的行
    initial_count = len(df)
    df = df.dropna(subset=['Year', 'Month', 'Day'])
    removed_count = initial_count - len(df)
    if removed_count > 0:
        print(f"警告: {sheet_name} 中删除了 {removed_count} 行包含空值的行")
    
    # 转换为整数
    df[['Year', 'Month', 'Day']] = df[['Year', 'Month', 'Day']].astype(int)
    
    # 创建日期列
    df['Date'] = pd.to_datetime(df[['Year', 'Month', 'Day']].astype(str).agg('-'.join, axis=1))
    
    # 准备新列
    vod_columns = ['SM', 'ku_vod_H', 'ku_vod_V', 'x_vod_H', 'x_vod_V', 'c_vod_H', 'c_vod_V']
    pft_columns = ['water', 'bare', 'snowice', 'built', 'grassnat', 'grassman', 
                   'shrubbd', 'shrubbe', 'shrubnd', 'shrubne', 'treebd', 'treebe', 'treend', 'treene']
    lai_column = 'lai'  # LAI变量名
    
    for col_name in vod_columns + pft_columns + ['LAI_Satellite', 'Hveg_Satellite']:
        df[col_name] = np.nan
    
    # 加载PFT数据 (一次性加载全年的)
    pft_file = Path(f"E:/data/ESACCI PFT/Resample/Data/{year}.mat")
    if pft_file.exists():
        pft_data = read_mat_v73(pft_file, pft_columns)
        if pft_data:
            for pft_col in pft_columns:
                if pft_col in pft_data:
                    try:
                        # 直接使用行列索引（不再使用位置索引）
                        df[pft_col] = pft_data[pft_col][row, col]
                    except Exception as e:
                        print(f"处理PFT数据时出错: {str(e)}")
    
    # 加载Hveg数据 (不随时间变化)
    hveg_file = Path("E:/data/CanopyHeight/CH.mat")
    if hveg_file.exists():
        ch_data = read_mat_v73(hveg_file, ['Hveg'])
        if ch_data and 'Hveg' in ch_data:
            try:
                # 直接使用行列索引
                df['Hveg_Satellite'] = ch_data['Hveg'][row, col]
            except Exception as e:
                print(f"处理Hveg数据时出错: {str(e)}")
    
    # 逐行处理VOD和LAI数据
    for idx, row_data in df.iterrows():
        date_str = row_data['Date'].strftime('%Y%m%d')
        year_int = row_data['Date'].year
        
        # 确定VOD文件路径
        if year_int <= 2012:
            vod_file = Path(f"E:/data/VOD/mat/kuxcVOD/ASC/MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")
        else:
            vod_file = Path(f"E:/data/VOD/mat/kuxcVOD/ASC/MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")
        
        # 加载VOD数据
        if vod_file.exists():
            try:
                vod_data = read_mat_v73(vod_file, vod_columns)
                if vod_data:
                    for var in vod_columns:
                        if var in vod_data:
                            try:
                                # 获取特定位置的数值
                                value = vod_data[var][row, col]
                                if not np.isnan(value):
                                    if var == 'SM':
                                        df.at[idx, 'SM_Satellite'] = value
                                    else:
                                        df.at[idx, var] = value
                            except:
                                print(f"提取VOD数据时出错 (文件: {vod_file}, 变量: {var})")
            except Exception as e:
                print(f"加载VOD文件 {vod_file} 时出错: {str(e)}")
        
        # 处理LAI数据（插值）
        prev_file, next_file, prev_date, next_date = get_nearest_lai_files(row_data['Date'])
        if prev_file.exists() and next_file.exists():
            try:
                # 读取前一个月数据
                prev_data = read_mat_v73(prev_file, [lai_column])
                prev_lai = prev_data[lai_column][row, col] if prev_data and lai_column in prev_data else np.nan
                
                # 读取后一个月数据
                next_data = read_mat_v73(next_file, [lai_column])
                next_lai = next_data[lai_column][row, col] if next_data and lai_column in next_data else np.nan
                
                # 计算日期差（精确到天）
                total_days = (next_date - prev_date).days
                current_days = (row_data['Date'] - prev_date).days
                
                # 线性插值
                if total_days > 0 and 0 <= current_days <= total_days:
                    weight = current_days / total_days
                    df.at[idx, 'LAI_Satellite'] = (1 - weight) * prev_lai + weight * next_lai
                else:
                    # 如果日期超出范围，使用最近的一个值
                    if current_days < 0:
                        df.at[idx, 'LAI_Satellite'] = prev_lai
                    else:
                        df.at[idx, 'LAI_Satellite'] = next_lai
            except Exception as e:
                print(f"处理LAI插值失败，日期 {date_str}: {str(e)}")
    
    return df

def process_2017_data():
    """处理2017年的数据"""
    file_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg.xlsx"
    save_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx"
    
    # 创建保存目录
    save_dir = Path(save_path).parent
    save_dir.mkdir(parents=True, exist_ok=True)
    
    # 多伦位置 - 东经116.47，北纬42.18
    lat = 42.18
    lon = 116.47
    
    # 获取所有sheet名称
    xl = pd.ExcelFile(file_path)
    all_sheets = xl.sheet_names
    
    # 排除BuckwheatMeasured
    sheets_to_process = [sheet for sheet in all_sheets if "BuckwheatMeasured" not in sheet]
    
    print(f"将处理以下工作表: {', '.join(sheets_to_process)}")
    
    # 创建一个新的Excel文件
    with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
        # 添加一个空的工作表作为占位符（避免"没有可见工作表"错误）
        pd.DataFrame().to_excel(writer, sheet_name='Placeholder', index=False)
        
        for sheet_name in sheets_to_process:
            try:
                # 跳过首行标题（中文列名）
                df = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=1)
                
                # 处理数据
                print(f"处理 2017: {sheet_name}")
                df_processed = process_sheet(df, sheet_name, lat, lon, year=2017)
                
                # 保存到Excel
                df_processed.to_excel(writer, sheet_name=sheet_name, index=False)
            except Exception as e:
                print(f"处理工作表 {sheet_name} 时出错: {str(e)}")
    
    # 删除占位符工作表
    wb = openpyxl.load_workbook(save_path)
    if 'Placeholder' in wb.sheetnames:
        del wb['Placeholder']
    wb.save(save_path)
    
    print(f"2017年数据处理完成，保存至: {save_path}")

def process_2018_data():
    """处理2018年的数据"""
    file_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC.xlsx"
    save_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx"
    
    # 正蓝旗位置 - 东经115.93，北纬42.04
    lat = 42.04
    lon = 115.93
    
    # 获取所有sheet名称
    xl = pd.ExcelFile(file_path)
    all_sheets = xl.sheet_names
    
    # 2018年只有一个名为GrassVWC的工作表
    sheets_to_process = [sheet for sheet in all_sheets if "GrassVWC" in sheet]
    
    if not sheets_to_process:
        print(f"警告: 在 {file_path} 中未找到名为 'GrassVWC' 的工作表")
        sheets_to_process = all_sheets  # 尝试处理所有工作表
    
    print(f"将处理以下工作表: {', '.join(sheets_to_process)}")
    
    # 创建一个新的Excel文件
    with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
        # 添加一个空的工作表作为占位符（避免"没有可见工作表"错误）
        pd.DataFrame().to_excel(writer, sheet_name='Placeholder', index=False)
        
        for sheet_name in sheets_to_process:
            try:
                # 跳过首行标题（中文列名）
                df = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=1)
                
                # 处理数据
                print(f"处理 2018: {sheet_name}")
                df_processed = process_sheet(df, sheet_name, lat, lon, year=2018)
                
                # 保存到Excel
                df_processed.to_excel(writer, sheet_name=sheet_name, index=False)
            except Exception as e:
                print(f"处理工作表 {sheet_name} 时出错: {str(e)}")
    
    # 删除占位符工作表
    wb = openpyxl.load_workbook(save_path)
    if 'Placeholder' in wb.sheetnames:
        del wb['Placeholder']
    wb.save(save_path)
    
    print(f"2018年数据处理完成，保存至: {save_path}")

def main():
    # 处理2017年数据
    process_2017_data()
    
    # 处理2018年数据
    process_2018_data()

if __name__ == "__main__":
    main()

将处理以下工作表: CornVegMeasured, CornVegFitting, OatVegMeasured, OatVegFitting
处理 2017: CornVegMeasured
警告: CornVegMeasured 中删除了 5 行包含空值的行
处理 2017: CornVegFitting
处理 2017: OatVegMeasured
处理 2017: OatVegFitting
2017年数据处理完成，保存至: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx
将处理以下工作表: GrassVWC
处理 2018: GrassVWC
2018年数据处理完成，保存至: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx


In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.gridspec as gridspec
import joblib
import os
from pathlib import Path
import warnings
from datetime import datetime
from sklearn.metrics import mean_squared_error, r2_score
from scipy.interpolate import make_interp_spline  # 导入样条插值函数
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义
BANDS = ['Ku', 'X', 'C']
BAND_COLORS = {
    'Ku': '#1f77b4',  # 蓝色
    'X': '#ff7f0e',   # 橙色
    'C': '#2ca02c'    # 绿色
}
POLS = ['H', 'V', 'HV']
POL_LINESTYLES = {
    'H': '-',     # 实线
    'V': '--',    # 虚线
    'HV': ':'     # 点线
}
POL_MARKERS = {
    'H': '+',  # 加号
    'V': '^',  # 三角形
    'HV': 's'  # 正方形
}
POL_LABELS = {
    'H': 'H-Pol',
    'V': 'V-Pol',
    'HV': 'H&V-Pol'  # 修改这里
}

# 植被类型映射
VEGETATION_TYPES = {
    'CornVegMeasured': 'Corn (2017)',
    'OatVegMeasured': 'Oat (2017)',
    'GrassVWC': 'Grass (2018)'
}

# 实测与拟合数据映射
FITTING_MAPPING = {
    'CornVegMeasured': 'CornVegFitting',
    'OatVegMeasured': 'OatVegFitting',
    'GrassVWC': 'GrassVWC'  # 2018年没有拟合数据
}

# 实测VWC列名映射
ACTUAL_COL_MAPPING = {
    'CornVegMeasured': 'total_VWC(kg/m2)',
    'OatVegMeasured': 'total_VWC(kg/m2)',
    'GrassVWC': 'vegetation water content(kg/m2)'
}

# 实测数据样式
ACTUAL_STYLE = {
    'color': 'black',
    'marker': 'o',
    'markersize': 8,
    'markerfacecolor': 'none',
    'markeredgewidth': 1.5,
    'label': 'Measured'
}

def load_data(file_path):
    """加载Excel文件中的所有工作表"""
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    # 获取所有工作表名称
    xl = pd.ExcelFile(file_path)
    sheet_names = xl.sheet_names
    
    for sheet in sheet_names:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def predict_vwc_for_sheet(df, band, pol):
    """
    使用机器学习模型预测VWC，确保特征名称匹配
    """
    # 加载模型
    model_path = f"models/RFR_{band}_{pol}pol_Type1.pkl"
    if not os.path.exists(model_path):
        print(f"警告: 模型文件不存在: {model_path}")
        return pd.Series(np.nan, index=df.index)
    
    try:
        model = joblib.load(model_path)
        print(f"加载模型: {model_path}")
        
        # 获取模型期望的特征名称
        if hasattr(model, 'feature_names_in_'):
            expected_features = list(model.feature_names_in_)
            print(f"  模型期望特征: {expected_features}")
        else:
            print("  警告: 模型没有feature_names_in_属性")
            expected_features = []
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)
    
    # 1. 优先使用地面实测数据替换卫星数据
    if 'SM' in df.columns:
        sm_mask = df['SM'].notna() & (df['SM'] > 0)
        if sm_mask.any():
            df.loc[sm_mask, 'SM_Satellite'] = df.loc[sm_mask, 'SM']
            print(f"  使用实测SM替换了 {sm_mask.sum()} 行数据")
    
    if 'LAI' in df.columns:
        lai_mask = df['LAI'].notna() & (df['LAI'] > 0)
        if lai_mask.any():
            df.loc[lai_mask, 'LAI_Satellite'] = df.loc[lai_mask, 'LAI']
            print(f"  使用实测LAI替换了 {lai_mask.sum()} 行数据")
    
    # 2. 根据波段和极化组合确定特征映射
    feature_mapping = {}
    
    # Ku波段
    if band == 'Ku':
        if pol == 'H':
            feature_mapping = {
                'ku_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'ku_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'ku_vod_H': 'VOD-Hpol',
                'ku_vod_V': 'VOD-Vpol',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # X波段
    elif band == 'X':
        if pol == 'H':
            feature_mapping = {
                'x_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'x_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'x_vod_H': 'VOD-Hpol',
                'x_vod_V': 'VOD-Vpol',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # C波段
    elif band == 'C':
        if pol == 'H':
            feature_mapping = {
                'c_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'c_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'c_vod_H': 'VOD-Hpol',
                'c_vod_V': 'VOD-Vpol',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # 时间序列插值
    if 'Date' in df.columns and not df.empty:
        # 确保按日期排序
        df = df.sort_values('Date')
        
        # 确定需要插值的特征列
        interpolate_cols = list(feature_mapping.keys())
        valid_cols = [col for col in interpolate_cols if col in df.columns]
        
        # 设置时间索引
        date_index = pd.DatetimeIndex(df['Date'])
        df_temp = df.set_index('Date')
        
        # 生成完整的时间序列范围
        full_range = pd.date_range(start=date_index.min(), end=date_index.max(), freq='D')
        df_full = df_temp.reindex(full_range)
        
        # 对特征列进行线性插值
        for col in valid_cols:
            df_full[col] = df_full[col].interpolate(method='time', limit_direction='both')
            print(f"  已完成{col}的时间序列插值")
        
        # 重置索引
        df = df_full.reset_index().rename(columns={'index': 'Date'})
    else:
        print("  无日期列或数据为空，跳过插值")
    
    # 3. 检查是否所有映射后的特征都存在
    missing_features = []
    for data_col in feature_mapping.keys():
        if data_col not in df.columns:
            missing_features.append(data_col)
    
    if missing_features:
        print(f"  缺少特征: {', '.join(missing_features)}")
        return pd.Series(np.nan, index=df.index)
    
    # 4. 准备特征数据
    X = pd.DataFrame()
    for data_col, model_feature in feature_mapping.items():
        X[model_feature] = df[data_col]
    
    # 5. 应用特征归一化
    # VOD特征归一化（除以2）
    vod_features = ['VOD', 'VOD-Hpol', 'VOD-Vpol']
    for vod_feature in vod_features:
        if vod_feature in X.columns:
            X[vod_feature] = X[vod_feature].clip(0, 2) / 2.0
    
    # LAI特征归一化（除以6）
    if 'LAI' in X.columns:
        X['LAI'] = X['LAI'].clip(0, 6) / 6.0
    
    # PFT特征归一化（除以100）
    pft_features = [
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    for pft_feature in pft_features:
        if pft_feature in X.columns:
            X[pft_feature] = X[pft_feature] / 100.0
    
    # 6. 移除缺失值
    initial_count = len(X)
    X = X.dropna()
    removed_count = initial_count - len(X)
    if removed_count > 0:
        print(f"  移除了 {removed_count} 行包含缺失值的数据")
    
    if X.empty:
        print("  无有效数据可用于预测")
        return pd.Series(np.nan, index=df.index)
    
    # 7. 确保特征顺序与模型期望一致
    if hasattr(model, 'feature_names_in_'):
        X = X[expected_features]
    
    # 8. 预测VWC
    try:
        y_pred = model.predict(X)
        
        # 创建完整长度的预测序列
        full_pred = pd.Series(np.nan, index=df.index)
        full_pred.loc[X.index] = y_pred
        
        return full_pred
    except Exception as e:
        print(f"  预测失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)

def create_combined_plots(data_dict_2017, data_dict_2018):
    """创建组合时间序列图并保存预测结果"""
    print("创建组合时间序列图...")
    
    # 创建输出目录
    output_dir = Path("prediction_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # 创建图形 - 增加底部空间用于多行图例
    fig = plt.figure(figsize=(15, 18))
    gs = gridspec.GridSpec(4, 1, figure=fig, height_ratios=[1, 1, 1, 0.4], hspace=0.3)
    
    # 设置全局标题
    fig.suptitle('Vegetation Water Content Time Series', fontsize=20, fontweight='bold', y=0.95)
    
    # 植被类型列表
    vegetation_types = [
        ('CornVegMeasured', data_dict_2017),  # 玉米
        ('OatVegMeasured', data_dict_2017),   # 燕麦
        ('GrassVWC', data_dict_2018)           # 草
    ]
    
    # 实测VWC列名映射
    ACTUAL_COL_MAPPING = {
        'CornVegMeasured': 'total_VWC(kg/m2)',
        'OatVegMeasured': 'total_VWC(kg/m2)',
        'GrassVWC': 'vegetation water content(kg/m2)'
    }
    
    # 存储所有评估指标
    all_metrics = {}
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 根据图片定义极化标记样式
    POL_MARKERS = {
        'H': '+',   # 加号
        'V': '^',   # 三角形
        'HV': 's'   # 正方形
    }
    
    # 波段颜色
    BAND_COLORS = {
        'Ku': 'blue',
        'X': 'green',
        'C': 'red'
    }
    
    # 极化线型
    POL_LINESTYLES = {
        'H': '-',
        'V': '--',
        'HV': '-.'
    }
    
    # 极化名称映射 - 修改1：使用新的标题格式
    POL_NAMES = {
        'H': 'H-Pol',
        'V': 'V-Pol',
        'HV': 'H&V-Pol'
    }
    
    # 波段名称映射 - 修改2：使用新的标题格式
    BAND_NAMES = {
        'Ku': 'Ku-Band',
        'X': 'X-Band',
        'C': 'C-Band'
    }
    
    # 植被类型显示名称
    VEGETATION_TYPES = {
        'CornVegMeasured': 'Corn',
        'OatVegMeasured': 'Oat',
        'GrassVWC': 'Grass'
    }
    
    # 遍历所有植被类型
    for idx, (veg_type, data_dict) in enumerate(vegetation_types):
        ax = fig.add_subplot(gs[idx])
        
        # 获取当前植被类型的实测列名
        actual_col = ACTUAL_COL_MAPPING[veg_type]
        
        # 初始化Y轴范围
        y_min = float('inf')
        y_max = float('-inf')
        
        # 获取实测数据
        if veg_type in data_dict:
            df_measured = data_dict[veg_type].copy()
            
            # 确保日期列存在
            if 'Date' not in df_measured.columns:
                print(f"警告: {veg_type} 中没有 'Date' 列")
                continue
            
            # 按日期排序
            df_measured = df_measured.sort_values('Date')
            
            # 更新Y轴范围（实测值）
            if actual_col in df_measured.columns:
                measured_values = df_measured[actual_col].dropna()
                if not measured_values.empty:
                    y_min = min(y_min, measured_values.min())
                    y_max = max(y_max, measured_values.max())
            
            # 获取拟合数据用于预测
            fitting_sheet = FITTING_MAPPING.get(veg_type, veg_type)
            if fitting_sheet in data_dict:
                df_fitting = data_dict[fitting_sheet].copy()
                
                # 确保日期列存在
                if 'Date' not in df_fitting.columns:
                    print(f"警告: {fitting_sheet} 中没有 'Date' 列")
                    continue
                
                # 按日期排序
                df_fitting = df_fitting.sort_values('Date')
            else:
                # 2018年没有单独的拟合数据
                df_fitting = df_measured.copy()
            
            # 存储评估指标
            metrics = []
            
            # 为每个波段和极化组合预测VWC
            for band in BAND_COLORS.keys():
                for pol in POL_LINESTYLES.keys():
                    # 生成列名
                    col_name = f"Predicted_VWC_{band}_{pol}"
                    
                    # 如果列不存在，使用模型预测
                    if col_name not in df_fitting.columns:
                        print(f"为 {fitting_sheet} 预测 {band}-{pol} VWC...")
                        df_fitting[col_name] = predict_vwc_for_sheet(df_fitting, band, pol)
                    
                    # 只在有有效预测值的点进行绘制和评估
                    if col_name in df_fitting.columns:
                        # 更新Y轴范围（预测值）
                        pred_values = df_fitting[col_name].dropna()
                        if not pred_values.empty:
                            y_min = min(y_min, pred_values.min())
                            y_max = max(y_max, pred_values.max())
                        
                        # 获取有效预测数据点
                        valid_mask = df_fitting[col_name].notna()
                        valid_dates = df_fitting['Date'][valid_mask]
                        valid_values = df_fitting[col_name][valid_mask]
                        
                        # 如果数据点足够多，使用样条插值生成平滑曲线
                        if len(valid_dates) > 3:
                            try:
                                # 将日期转换为数值（从最小日期开始的天数）
                                date_numeric = (valid_dates - valid_dates.min()).dt.days
                                
                                # 创建样条插值对象
                                spline = make_interp_spline(date_numeric, valid_values, k=3)
                                
                                # 生成更密集的时间点
                                dense_dates = np.linspace(date_numeric.min(), date_numeric.max(), 300)
                                dense_values = spline(dense_dates)
                                
                                # 将数值日期转换回实际日期
                                dense_dates = valid_dates.min() + pd.to_timedelta(dense_dates, unit='D')
                                
                                # 绘制平滑曲线
                                ax.plot(dense_dates, dense_values,
                                        color=BAND_COLORS[band],
                                        linestyle=POL_LINESTYLES[pol],
                                        linewidth=1.5)
                            except Exception as e:
                                print(f"样条插值失败: {str(e)}")
                                # 如果插值失败，使用原始数据点绘制折线
                                ax.plot(valid_dates, valid_values,
                                        color=BAND_COLORS[band],
                                        linestyle=POL_LINESTYLES[pol],
                                        linewidth=1.5)
                        else:
                            # 数据点太少，直接绘制折线
                            ax.plot(valid_dates, valid_values,
                                    color=BAND_COLORS[band],
                                    linestyle=POL_LINESTYLES[pol],
                                    linewidth=1.5)
                        
                        # 找出同时有实测值和预测值的点
                        common_data = pd.merge(
                            df_measured[['Date', actual_col]], 
                            df_fitting[['Date', col_name]], 
                            on='Date', 
                            how='inner'
                        ).dropna(subset=[actual_col, col_name])
                        
                        if not common_data.empty:
                            # 更新Y轴范围（共同数据）
                            common_min = min(common_data[actual_col].min(), common_data[col_name].min())
                            common_max = max(common_data[actual_col].max(), common_data[col_name].max())
                            y_min = min(y_min, common_min)
                            y_max = max(y_max, common_max)
                            
                            # 在实测日期位置绘制实测值点（空心圆）
                            ax.plot(common_data['Date'], common_data[actual_col],
                                    linestyle='',  # 无线条
                                    color='black',
                                    marker='o',
                                    markersize=8,
                                    markerfacecolor='none',  # 透明填充（空心）
                                    markeredgewidth=1.5)
                            
                            # 在实测日期位置绘制预测点（空心标记）
                            ax.plot(common_data['Date'], common_data[col_name],
                                    linestyle='',  # 无线条
                                    color=BAND_COLORS[band],
                                    marker=POL_MARKERS[pol],
                                    markersize=10,
                                    markerfacecolor='none',  # 透明填充（空心）
                                    markeredgewidth=1.5)
                            
                            # 计算评估指标
                            rmse = np.sqrt(mean_squared_error(common_data[actual_col], common_data[col_name]))
                            r2 = r2_score(common_data[actual_col], common_data[col_name])
                            
                            # 添加到指标列表
                            metrics.append({
                                'band': band,
                                'pol': pol,
                                'rmse': rmse,
                                'r2': r2
                            })
                            
                            # 保存预测结果
                            model_key = f"{veg_type}_{band}_{pol}"
                            all_predictions[model_key] = {
                                'dates': common_data['Date'].tolist(),
                                'measured': common_data[actual_col].tolist(),
                                'predicted': common_data[col_name].tolist(),
                                'rmse': rmse,
                                'r2': r2
                            }
            
            # 设置子图标题 - 修改3：使用新的标题格式
            ax.set_title(VEGETATION_TYPES.get(veg_type, veg_type), 
                         fontsize=16, fontweight='bold')
            
            # 设置坐标轴标签
            if idx == 2:  # 最后一行
                ax.set_xlabel('Date', fontsize=12, fontweight='bold')
            ax.set_ylabel('VWC (kg/m²)', fontsize=12, fontweight='bold')
            
            # 设置X轴格式
            ax.xaxis.set_major_locator(mdates.DayLocator(interval=10))
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
            
            # 动态设置Y轴范围
            if y_min != float('inf') and y_max != float('-inf'):
                # 添加10%的边距
                y_range = y_max - y_min
                padding = y_range * 0.1
                
                # 确保最小值不小于0
                y_min = max(0, y_min - padding)
                y_max = y_max + padding
                
                ax.set_ylim(y_min, y_max)
            else:
                # 默认范围
                ax.set_ylim(0, 10)
            
            # 存储指标
            all_metrics[veg_type] = metrics
    
    # ==================================
    # 创建图例（精确匹配要求）
    # ==================================
    
    # 创建图例区域的轴
    ax_legend = fig.add_subplot(gs[3])
    ax_legend.axis('off')  # 隐藏坐标轴
    
    # 定义图例行内容（标题+项目） - 修改4：使用新的标题格式
    legend_rows = [
        # 第一行：Ku波段
        [f"{BAND_NAMES['Ku']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],
        
        # 第二行：X波段
        [f"{BAND_NAMES['X']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],
        
        # 第三行：C波段
        [f"{BAND_NAMES['C']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],

        # 第四行：实测点
        [f"Insitu VWC"]
    ]
    
    # 创建代理艺术家
    proxies = {}
    
    # Insitu VWC代理（空心圆）
    proxies['insitu'] = plt.Line2D([], [], 
                     linestyle='', 
                     marker='o',
                     markersize=10,
                     markerfacecolor='none',
                     markeredgecolor='black',
                     markeredgewidth=1.5,
                     label='Insitu VWC')
    
    # 波段-极化组合代理 - 修改5：使用新的标题格式
    for band in ['Ku', 'X', 'C']:
        color = BAND_COLORS[band]
        for pol in ['H', 'V', 'HV']:
            proxies[f"{band}-{pol}"] = plt.Line2D([], [],
                color=color,
                linestyle=POL_LINESTYLES[pol],
                linewidth=2,
                marker=POL_MARKERS[pol],
                markersize=10,
                markerfacecolor='none',
                markeredgecolor=color,
                markeredgewidth=1.5,
                label=f"{BAND_NAMES[band]},{POL_NAMES[pol]}")
    
    # 为每行创建图例
    y_positions = [0.85, 0.60, 0.35, 0.10]  # 三行垂直位置
    
    for row_idx, row_items in enumerate(legend_rows):
        handles = []
        labels = []
        
        for item in row_items:
            # 处理Insitu项
            if item == "Insitu VWC":
                handles.append(proxies['insitu'])
                labels.append(item)
            # 处理波段-极化项
            else:
                # 解析新的标签格式
                band_part, pol_part = item.split(',')
                band = band_part.split('-')[0]  # 提取波段名称
                
                handles.append(proxies[f"{band}-{pol}"])
                labels.append(item)  # 使用完整的标签文本
        
        # 计算当前行文本宽度（均匀分布）
        n_items = len(handles)
        x_positions = np.linspace(0.05, 0.95, n_items)
        
        # 绘制当前行的图例项
        for i, (handle, label) in enumerate(zip(handles, labels)):
            ax_legend.plot([], [])  # 空白绘图以创建图例项
            
            # 创建图例句柄
            leg = ax_legend.legend([handle], [label], 
                                  loc='lower center',
                                  bbox_to_anchor=(x_positions[i], y_positions[row_idx]),
                                  frameon=False,
                                  handlelength=2,
                                  fontsize=10,
                                  handletextpad=0.8)
            
            # 添加到轴（否则会被覆盖）
            ax_legend.add_artist(leg)
    
    # 在子图中显示评估指标
    for idx, (veg_type, metrics) in enumerate(all_metrics.items()):
        if idx < 3:  # 确保索引有效（排除图例轴）
            ax = fig.axes[idx]
            
            # 创建指标文本
            if metrics:
                # 使用多列格式显示所有指标
                metric_text = "Evaluation Metrics:\n"
                
                # 按波段分组指标
                band_metrics = {}
                for metric in metrics:
                    band = metric['band']
                    if band not in band_metrics:
                        band_metrics[band] = []
                    band_metrics[band].append(metric)
                
                # 为每个波段创建一行文本
                for band in ['Ku', 'X', 'C']:
                    if band in band_metrics:
                        band_text = f"{BAND_NAMES[band]}: "
                        pol_texts = []
                        for metric in band_metrics[band]:
                            pol_texts.append(f"{POL_NAMES[metric['pol']]}(RMSE={metric['rmse']:.3f})")
                        band_text += ", ".join(pol_texts)
                        metric_text += band_text + "\n"
                
                # 添加文本框
                ax.text(0.02, 0.95, metric_text, 
                        transform=ax.transAxes,
                        fontsize=9,
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # 调整布局
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    
    # 保存图像
    figures_dir = Path("figures")
    figures_dir.mkdir(parents=True, exist_ok=True)
    fig_path = figures_dir / "Combined_VWC_Time_Series.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"组合时间序列图已保存至: {fig_path}")
    plt.close()
    
    # 保存所有预测结果到CSV文件
    for model_key, data in all_predictions.items():
        veg_type, band, pol = model_key.split('_')
        df = pd.DataFrame({
            'Date': data['dates'],
            'Measured': data['measured'],
            'Predicted': data['predicted']
        })
        csv_path = output_dir / f"{veg_type}_{band}_{pol}_predictions.csv"
        df.to_csv(csv_path, index=False)
        print(f"保存预测结果至: {csv_path}")
    
    # 保存评估指标
    metrics_path = output_dir / "model_metrics.csv"
    metrics_data = []
    for veg_type, metrics in all_metrics.items():
        for metric in metrics:
            metrics_data.append({
                'Vegetation': veg_type,
                'Band': metric['band'],
                'Polarization': metric['pol'],
                'RMSE': metric['rmse'],
                'R2': metric['r2']
            })
    
    metrics_df = pd.DataFrame(metrics_data)
    metrics_df.to_csv(metrics_path, index=False)
    print(f"保存模型评估指标至: {metrics_path}")

def main():
    # 2017年数据文件
    file_2017 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx"
    data_2017 = load_data(file_2017)
    
    # 2018年数据文件
    file_2018 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx"
    data_2018 = load_data(file_2018)
    
    # 创建组合时间序列图
    create_combined_plots(data_2017, data_2018)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx
  - CornVegMeasured: 8行
  - CornVegFitting: 64行
  - OatVegMeasured: 7行
  - OatVegFitting: 64行
加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx
  - GrassVWC: 13行
创建组合时间序列图...
为 CornVegFitting 预测 Ku-H VWC...
加载模型: models/RFR_Ku_Hpol_Type1.pkl
  模型期望特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  使用实测LAI替换了 63 行数据
  已完成ku_vod_H的时间序列插值
  已完成LAI_Satellite的时间序列插值
  已完成SM_Satellite的时间序列插值
  已完成grassman的时间序列插值
  已完成grassnat的时间序列插值
  已完成shrubbd的时间序列插值
  已完成shrubbe的时间序列插值
  已完成shrubnd的时间序列插值
  已完成shrubne的时间序列插值
  已完成treebd的时间序列插值
  已完成treebe的时间序列插值
  已完成treend的时间序列插值
  已完成treene的时间序列插值
为 CornVegFitting 预测 Ku-V VWC...
加载模型: models/RFR_Ku_Vpol_Type1.pkl
  模型期望特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  使用实

In [9]:
# 3*3 散点图结果
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from pathlib import Path
import warnings
from sklearn.metrics import mean_squared_error
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义
BANDS = ['Ku', 'X', 'C']
BAND_COLORS = {
    'Ku': '#1f77b4',  # 蓝色
    'X': '#ff7f0e',   # 橙色
    'C': '#2ca02c'    # 绿色
}
POLS = ['H', 'V', 'HV']

# 植被类型标记样式 - 玉米标记改为空心方形（'s'）
VEG_MARKERS = {
    'CornVegMeasured': {'marker': 's', 'size': 80, 'label': 'Corn (2017)'},  # 改为方形
    'OatVegMeasured': {'marker': '^', 'size': 80, 'label': 'Oat (2017)'},
    'GrassVWC': {'marker': 'o', 'size': 80, 'label': 'Grass (2018)'}
}

def load_prediction_data(prediction_dir):
    """从CSV文件加载预测结果"""
    print(f"加载预测结果: {prediction_dir}")
    all_predictions = {}
    
    # 遍历所有CSV文件
    for csv_file in prediction_dir.glob("*_predictions.csv"):
        # 解析文件名获取模型信息
        filename = csv_file.stem
        parts = filename.split('_')
        
        if len(parts) >= 4:  # 格式: {植被类型}_{波段}_{极化}_predictions
            veg_type = parts[0]
            band = parts[1]
            pol = parts[2]
            model_key = f"{band}_{pol}"
            
            # 加载数据
            df = pd.read_csv(csv_file)
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            # 存储数据
            if model_key not in all_predictions:
                all_predictions[model_key] = {}
            
            all_predictions[model_key][veg_type] = df
    
    return all_predictions

def get_model_title(band, pol):
    """根据波段和极化返回自定义标题"""
    band_names = {
        'Ku': 'Ku-Band',
        'X': 'X-Band',
        'C': 'C-Band'
    }
    pol_names = {
        'H': 'H-Pol',
        'V': 'V-Pol',
        'HV': 'H&V-Pol'  # 修改这里
    }
    return f"{band_names.get(band, band)},{pol_names.get(pol, pol)}"

def create_scatter_plots_from_predictions(prediction_dir):
    """从预测结果文件创建9个模型的真值与预测值散点图（3x3网格）"""
    # 加载预测结果
    all_predictions = load_prediction_data(prediction_dir)
    
    if not all_predictions:
        print("警告: 没有找到预测结果文件")
        return
    
    # 创建3x3网格图
    fig = plt.figure(figsize=(15, 15))
    fig.suptitle('', fontsize=20, y=0.95)
    gs = gridspec.GridSpec(3, 3, wspace=0.3, hspace=0.3)
    
    # 收集所有散点的最小和最大值（用于统一坐标轴）
    all_actual_min, all_actual_max = np.inf, -np.inf
    all_pred_min, all_pred_max = np.inf, -np.inf
    
    # 收集所有评估指标
    all_metrics = {}

    # 处理每个模型（波段和极化组合）
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            model_key = f"{band}_{pol}"
            ax = plt.subplot(gs[i, j])
            print(f"处理模型: {model_key}")
            
            # 检查该模型是否有预测数据
            if model_key not in all_predictions:
                print(f"警告: {model_key} 模型没有预测数据")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 收集该模型的所有植被类型的数据
            all_actual = []
            all_predicted = []
            all_veg_types = []
            
            # 存储各植被类型的数据点
            veg_data = {
                'CornVegMeasured': {'actual': [], 'predicted': []},
                'OatVegMeasured': {'actual': [], 'predicted': []},
                'GrassVWC': {'actual': [], 'predicted': []}
            }
            
            # 处理玉米数据
            veg_type = 'CornVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 玉米数据点: {len(df)}")
            
            # 处理燕麦数据
            veg_type = 'OatVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 燕麦数据点: {len(df)}")
            
            # 处理草数据
            veg_type = 'GrassVWC'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 草数据点: {len(df)}")
            
            # 如果没有数据点，跳过
            if len(all_actual) == 0:
                print(f"警告: {model_key} 模型没有有效数据点")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 转换为numpy数组
            all_actual = np.array(all_actual)
            all_predicted = np.array(all_predicted)
            
            # 更新全局最小/最大值
            all_actual_min = min(all_actual_min, np.min(all_actual))
            all_actual_max = max(all_actual_max, np.max(all_actual))
            all_pred_min = min(all_pred_min, np.min(all_predicted))
            all_pred_max = max(all_pred_max, np.max(all_predicted))
            
            # 计算各植被类型的RMSE
            rmse_corn = None
            rmse_oat = None
            rmse_grass = None
            
            if veg_data['CornVegMeasured']['actual']:
                actual_corn = np.array(veg_data['CornVegMeasured']['actual'])
                predicted_corn = np.array(veg_data['CornVegMeasured']['predicted'])
                rmse_corn = np.sqrt(mean_squared_error(actual_corn, predicted_corn))
            
            if veg_data['OatVegMeasured']['actual']:
                actual_oat = np.array(veg_data['OatVegMeasured']['actual'])
                predicted_oat = np.array(veg_data['OatVegMeasured']['predicted'])
                rmse_oat = np.sqrt(mean_squared_error(actual_oat, predicted_oat))
            
            if veg_data['GrassVWC']['actual']:
                actual_grass = np.array(veg_data['GrassVWC']['actual'])
                predicted_grass = np.array(veg_data['GrassVWC']['predicted'])
                rmse_grass = np.sqrt(mean_squared_error(actual_grass, predicted_grass))
            
            # 计算整体RMSE
            rmse_total = np.sqrt(mean_squared_error(all_actual, all_predicted))
            
            # 存储评估指标
            all_metrics[model_key] = {
                'RMSE_Corn': rmse_corn,
                'RMSE_Oat': rmse_oat,
                'RMSE_Grass': rmse_grass,
                'RMSE_Total': rmse_total
            }
            
            # 绘制散点图 - 按植被类型区分标记
            # 先绘制草和燕麦，最后绘制玉米（确保玉米在最上层）
            for veg_type in ['GrassVWC', 'OatVegMeasured', 'CornVegMeasured']:
                if veg_data[veg_type]['actual']:
                    actual_values = np.array(veg_data[veg_type]['actual'])
                    predicted_values = np.array(veg_data[veg_type]['predicted'])
                    
                    marker_style = VEG_MARKERS[veg_type]
                    
                    # 为玉米标记使用更大的尺寸和线宽
                    if veg_type == 'CornVegMeasured':
                        size = 100  # 增加大小
                        edgewidth = 1.5  # 更粗的线宽
                        alpha = 0.9  # 更高的不透明度
                    else:
                        size = marker_style['size']
                        edgewidth = 1.0
                        alpha = 0.8
                    
                    # 所有标记使用相同的波段颜色
                    ax.scatter(actual_values, predicted_values, 
                              marker=marker_style['marker'], 
                              s=size,
                              alpha=alpha,  # 调整透明度
                              facecolor='none', 
                              edgecolor=BAND_COLORS[band],  # 使用波段颜色
                              linewidths=edgewidth,
                              label=marker_style['label'])
            
            # 添加1:1参考线
            ax.plot([0, 4], [0, 4], 'k--', linewidth=1, label='1:1 Line')
            
            # 设置标题和坐标轴标签
            ax.set_title(get_model_title(band, pol), fontsize=14)
            if j == 0:  # 第一列添加y轴标签
                ax.set_ylabel('RF VWC (kg/m²)', fontsize=12)
            if i == 2:  # 最后一行添加x轴标签
                ax.set_xlabel('In Situ VWC (kg/m²)', fontsize=12)
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
            
            # 显示评估指标
            metric_text = ""
            if rmse_corn is not None:
                metric_text += f"Corn RMSE = {rmse_corn:.3f}\n"
            if rmse_oat is not None:
                metric_text += f"Oat RMSE = {rmse_oat:.3f}\n"
            if rmse_grass is not None:
                metric_text += f"Grass RMSE = {rmse_grass:.3f}\n"
            metric_text += f"Total RMSE = {rmse_total:.3f}"
            
            ax.text(0.05, 0.95, metric_text, transform=ax.transAxes, 
                   fontsize=9, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 设置所有子图的坐标轴范围一致
    max_val = 4
    min_val = 0
    for ax in fig.get_axes():
        ax.set_xlim(min_val, max_val)
        ax.set_ylim(min_val, max_val)
    
    # 添加图例
    # 创建代理艺术家用于图例
    handles = []
    labels = []
    
    # 添加植被类型标记
    for veg_type, style in VEG_MARKERS.items():
        # 为玉米标记使用特殊大小
        if veg_type == 'CornVegMeasured':
            markersize = 10  # 图例中保持相同大小
        else:
            markersize = 8
            
        handles.append(
            plt.Line2D([], [], marker=style['marker'], linestyle='None', 
                       markersize=markersize, alpha=0.7, markerfacecolor='none', 
                       markeredgecolor='gray', label=style['label'])
        )
    
    # 添加1:1参考线
    handles.append(
        plt.Line2D([], [], color='k', linestyle='--', linewidth=1, label='1:1 Line')
    )
    
    plt.tight_layout(rect=[0, 0.01, 1, 0.95])  # 调整底部空间
 
    # 添加图例到整个图形
    fig.legend(handles=handles, loc='lower center', 
               bbox_to_anchor=(0.5, 0.05), ncol=4, fontsize=10, 
               title="")
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = output_dir / "Scatter_Predictions_From_Saved_Data.png"
    plt.savefig(fig_path, dpi=1000, bbox_inches='tight', pad_inches=0.1)
    print(f"散点图已保存至: {fig_path}")
    plt.close()
    
    # 打印所有模型的评估指标
    print("\n模型评估指标:")
    for model_name, metrics in all_metrics.items():
        print(f"{model_name}:")
        if metrics['RMSE_Corn'] is not None:
            print(f"  Corn RMSE = {metrics['RMSE_Corn']:.4f}")
        if metrics['RMSE_Oat'] is not None:
            print(f"  Oat RMSE = {metrics['RMSE_Oat']:.4f}")
        if metrics['RMSE_Grass'] is not None:
            print(f"  Grass RMSE = {metrics['RMSE_Grass']:.4f}")
        print(f"  Total RMSE = {metrics['RMSE_Total']:.4f}")

def main():
    # 设置预测结果目录
    prediction_dir = Path("prediction_results")
    
    # 创建散点图
    create_scatter_plots_from_predictions(prediction_dir)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载预测结果: prediction_results
处理模型: Ku_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: Ku_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: Ku_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
散点图已保存至: figures\Scatter_Predictions_From_Saved_Data.png

模型评估指标:
Ku_H:
  Corn RMSE = 0.4584
  Oat RMSE = 0.2044
  Grass RMSE = 0.3865
  Total RMSE = 0.3739
Ku_V:
  Corn RMSE = 0.3945
  Oat RMSE = 0.2822
  Grass RMSE = 0.4523
  Total RMSE = 0.3992
Ku_HV:
  Corn RMSE = 0.6910
  Oat RMSE = 0.3684
  Grass RMSE = 0.5800
  Total RMSE = 0.5715
X_H:
  Corn RMSE = 0.6264
  Oat RMSE = 0.2532
  Grass RMSE = 0.4367
  Total RMSE = 0.4655
X_V:
  Corn RMSE = 0.4515
  Oat RMSE = 0.2022
  Grass RMSE = 0.5206
  Total RMSE = 0.4408
X_HV:


In [4]:
# 3*2 散点图结果
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from pathlib import Path
import warnings
from sklearn.metrics import mean_squared_error
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义 - 修改1：只包含H和V极化
BANDS = ['Ku', 'X', 'C']
POLS = ['H', 'V']  # 只包含H和V极化，排除HV

BAND_COLORS = {
    'Ku': '#1f77b4',  # 蓝色
    'X': '#ff7f0e',   # 橙色
    'C': '#2ca02c'    # 绿色
}

# 植被类型标记样式 - 玉米标记改为空心方形（'s'）
VEG_MARKERS = {
    'CornVegMeasured': {'marker': 's', 'size': 80, 'label': 'Corn (2017)'},  # 改为方形
    'OatVegMeasured': {'marker': '^', 'size': 80, 'label': 'Oat (2017)'},
    'GrassVWC': {'marker': 'o', 'size': 80, 'label': 'Grass (2018)'}
}

def load_prediction_data(prediction_dir):
    """从CSV文件加载预测结果"""
    print(f"加载预测结果: {prediction_dir}")
    all_predictions = {}
    
    # 遍历所有CSV文件
    for csv_file in prediction_dir.glob("*_predictions.csv"):
        # 解析文件名获取模型信息
        filename = csv_file.stem
        parts = filename.split('_')
        
        if len(parts) >= 4:  # 格式: {植被类型}_{波段}_{极化}_predictions
            veg_type = parts[0]
            band = parts[1]
            pol = parts[2]
            model_key = f"{band}_{pol}"
            
            # 加载数据
            df = pd.read_csv(csv_file)
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            # 存储数据
            if model_key not in all_predictions:
                all_predictions[model_key] = {}
            
            all_predictions[model_key][veg_type] = df
    
    return all_predictions

def get_model_title(band, pol):
    """根据波段和极化返回自定义标题"""
    band_names = {
        'Ku': 'Ku-Band',
        'X': 'X-Band',
        'C': 'C-Band'
    }
    pol_names = {
        'H': 'H-Pol',
        'V': 'V-Pol'
    }
    return f"{band_names.get(band, band)},{pol_names.get(pol, pol)}"

def create_scatter_plots_from_predictions(prediction_dir):
    """从预测结果文件创建6个模型的真值与预测值散点图（2x3网格）"""
    # 加载预测结果
    all_predictions = load_prediction_data(prediction_dir)
    
    if not all_predictions:
        print("警告: 没有找到预测结果文件")
        return
    
    # 修改2：创建2x3网格图
    fig = plt.figure(figsize=(15, 10))  # 调整高度以适应2行
    fig.suptitle('', fontsize=20, y=0.95)
    gs = gridspec.GridSpec(2, 3, wspace=0.25, hspace=0.3)  # 调整间距
    
    # 收集所有散点的最小和最大值（用于统一坐标轴）
    all_actual_min, all_actual_max = np.inf, -np.inf
    all_pred_min, all_pred_max = np.inf, -np.inf
    
    # 收集所有评估指标
    all_metrics = {}

    # 修改3：调整循环顺序 - 外层为极化，内层为波段
    for i, pol in enumerate(POLS):  # 行索引 - 极化方式
        for j, band in enumerate(BANDS):  # 列索引 - 波段
            model_key = f"{band}_{pol}"
            ax = plt.subplot(gs[i, j])
            print(f"处理模型: {model_key}")
            
            # 检查该模型是否有预测数据
            if model_key not in all_predictions:
                print(f"警告: {model_key} 模型没有预测数据")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 收集该模型的所有植被类型的数据
            all_actual = []
            all_predicted = []
            all_veg_types = []
            
            # 存储各植被类型的数据点
            veg_data = {
                'CornVegMeasured': {'actual': [], 'predicted': []},
                'OatVegMeasured': {'actual': [], 'predicted': []},
                'GrassVWC': {'actual': [], 'predicted': []}
            }
            
            # 处理玉米数据
            veg_type = 'CornVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 玉米数据点: {len(df)}")
            
            # 处理燕麦数据
            veg_type = 'OatVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 燕麦数据点: {len(df)}")
            
            # 处理草数据
            veg_type = 'GrassVWC'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 草数据点: {len(df)}")
            
            # 如果没有数据点，跳过
            if len(all_actual) == 0:
                print(f"警告: {model_key} 模型没有有效数据点")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 转换为numpy数组
            all_actual = np.array(all_actual)
            all_predicted = np.array(all_predicted)
            
            # 更新全局最小/最大值
            all_actual_min = min(all_actual_min, np.min(all_actual))
            all_actual_max = max(all_actual_max, np.max(all_actual))
            all_pred_min = min(all_pred_min, np.min(all_predicted))
            all_pred_max = max(all_pred_max, np.max(all_predicted))
            
            # 计算各植被类型的RMSE
            rmse_corn = None
            rmse_oat = None
            rmse_grass = None
            
            if veg_data['CornVegMeasured']['actual']:
                actual_corn = np.array(veg_data['CornVegMeasured']['actual'])
                predicted_corn = np.array(veg_data['CornVegMeasured']['predicted'])
                rmse_corn = np.sqrt(mean_squared_error(actual_corn, predicted_corn))
            
            if veg_data['OatVegMeasured']['actual']:
                actual_oat = np.array(veg_data['OatVegMeasured']['actual'])
                predicted_oat = np.array(veg_data['OatVegMeasured']['predicted'])
                rmse_oat = np.sqrt(mean_squared_error(actual_oat, predicted_oat))
            
            if veg_data['GrassVWC']['actual']:
                actual_grass = np.array(veg_data['GrassVWC']['actual'])
                predicted_grass = np.array(veg_data['GrassVWC']['predicted'])
                rmse_grass = np.sqrt(mean_squared_error(actual_grass, predicted_grass))
            
            # 计算整体RMSE
            rmse_total = np.sqrt(mean_squared_error(all_actual, all_predicted))
            
            # 存储评估指标
            all_metrics[model_key] = {
                'RMSE_Corn': rmse_corn,
                'RMSE_Oat': rmse_oat,
                'RMSE_Grass': rmse_grass,
                'RMSE_Total': rmse_total
            }
            
            # 绘制散点图 - 按植被类型区分标记
            # 先绘制草和燕麦，最后绘制玉米（确保玉米在最上层）
            for veg_type in ['GrassVWC', 'OatVegMeasured', 'CornVegMeasured']:
                if veg_data[veg_type]['actual']:
                    actual_values = np.array(veg_data[veg_type]['actual'])
                    predicted_values = np.array(veg_data[veg_type]['predicted'])
                    
                    marker_style = VEG_MARKERS[veg_type]
                    
                    # 为玉米标记使用更大的尺寸和线宽
                    if veg_type == 'CornVegMeasured':
                        size = 100  # 增加大小
                        edgewidth = 1.5  # 更粗的线宽
                        alpha = 0.9  # 更高的不透明度
                    else:
                        size = marker_style['size']
                        edgewidth = 1.0
                        alpha = 0.8
                    
                    # 所有标记使用相同的波段颜色
                    ax.scatter(actual_values, predicted_values, 
                              marker=marker_style['marker'], 
                              s=size,
                              alpha=alpha,  # 调整透明度
                              facecolor='none', 
                              edgecolor=BAND_COLORS[band],  # 使用波段颜色
                              linewidths=edgewidth,
                              label=marker_style['label'])
            
            # 添加1:1参考线
            ax.plot([0, 4], [0, 4], 'k--', linewidth=1, label='1:1 Line')
            
            # 修改4：调整坐标轴标签位置
            ax.set_title(get_model_title(band, pol), fontsize=14)
            if j == 0:  # 第一列添加y轴标签
                ax.set_ylabel('RF VWC (kg/m²)', fontsize=12)
            if i == 1:  # 第二行（V极化行）添加x轴标签
                ax.set_xlabel('In Situ VWC (kg/m²)', fontsize=12)
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
            
            # 显示评估指标
            metric_text = ""
            if rmse_corn is not None:
                metric_text += f"Corn RMSE = {rmse_corn:.3f}\n"
            if rmse_oat is not None:
                metric_text += f"Oat RMSE = {rmse_oat:.3f}\n"
            if rmse_grass is not None:
                metric_text += f"Grass RMSE = {rmse_grass:.3f}\n"
            metric_text += f"Total RMSE = {rmse_total:.3f}"
            
            ax.text(0.05, 0.95, metric_text, transform=ax.transAxes, 
                   fontsize=9, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 设置所有子图的坐标轴范围一致
    max_val = 4
    min_val = 0
    for ax in fig.get_axes():
        ax.set_xlim(min_val, max_val)
        ax.set_ylim(min_val, max_val)
    
    # 添加图例
    # 创建代理艺术家用于图例
    handles = []
    labels = []
    
    # 添加植被类型标记
    for veg_type, style in VEG_MARKERS.items():
        # 为玉米标记使用特殊大小
        if veg_type == 'CornVegMeasured':
            markersize = 10  # 图例中保持相同大小
        else:
            markersize = 8
            
        handles.append(
            plt.Line2D([], [], marker=style['marker'], linestyle='None', 
                       markersize=markersize, alpha=0.7, markerfacecolor='none', 
                       markeredgecolor='gray', label=style['label'])
        )
    
    # 添加1:1参考线
    handles.append(
        plt.Line2D([], [], color='k', linestyle='--', linewidth=1, label='1:1 Line')
    )
    
    # 修改5：优化布局调整
    plt.tight_layout(rect=[0, 0.01, 1, 0.95])  # 调整底部空间
 
    # 修改6：重新定位图例位置（提高位置）
    fig.legend(handles=handles, loc='lower center', 
               bbox_to_anchor=(0.5, 0.01), ncol=4, fontsize=16, 
               title="")
    
    # 修改7：更改输出文件名以反映新布局
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = output_dir / "Scatter_Predictions_HV_Only.png"
    
    # 修改8：提高输出质量
    plt.savefig(fig_path, dpi=600, bbox_inches='tight', pad_inches=0.05)
    print(f"散点图已保存至: {fig_path}")
    plt.close()
    
    # 打印所有模型的评估指标
    print("\n模型评估指标:")
    for model_name, metrics in all_metrics.items():
        print(f"{model_name}:")
        if metrics['RMSE_Corn'] is not None:
            print(f"  Corn RMSE = {metrics['RMSE_Corn']:.4f}")
        if metrics['RMSE_Oat'] is not None:
            print(f"  Oat RMSE = {metrics['RMSE_Oat']:.4f}")
        if metrics['RMSE_Grass'] is not None:
            print(f"  Grass RMSE = {metrics['RMSE_Grass']:.4f}")
        print(f"  Total RMSE = {metrics['RMSE_Total']:.4f}")

def main():
    # 设置预测结果目录
    prediction_dir = Path("prediction_results")
    
    # 创建散点图
    create_scatter_plots_from_predictions(prediction_dir)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载预测结果: prediction_results
处理模型: Ku_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: Ku_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
散点图已保存至: figures\Scatter_Predictions_HV_Only.png

模型评估指标:
Ku_H:
  Corn RMSE = 0.4584
  Oat RMSE = 0.2044
  Grass RMSE = 0.3865
  Total RMSE = 0.3739
X_H:
  Corn RMSE = 0.6264
  Oat RMSE = 0.2532
  Grass RMSE = 0.4367
  Total RMSE = 0.4655
C_H:
  Corn RMSE = 0.3580
  Oat RMSE = 0.1685
  Grass RMSE = 0.4001
  Total RMSE = 0.3436
Ku_V:
  Corn RMSE = 0.3945
  Oat RMSE = 0.2822
  Grass RMSE = 0.4523
  Total RMSE = 0.3992
X_V:
  Corn RMSE = 0.4515
  Oat RMSE = 0.2022
  Grass RMSE = 0.5206
  Total RMSE = 0.4408
C_V:
  Corn RMSE = 0.6191
  Oat RMSE = 0.3555
  Grass RMSE = 0.4970
  Total RMSE = 0.5058

处理完成!


In [None]:
# 使用加权的代码，设置weight为1


# SMEX02+CLASIC07+SMAPVEX08+SMAPVEX16——对每一部分重新处理，如果有说明采样植被类型，则按照植被类型计算像元VWC加权平均；如果没有，则①去除并处理为像元后计算权重；②直接处理站点数据，自定义PFT的值

（SMEX08更名为SMAPVEX08）

## 1.处理为站点区域的验证，未标记类型的数据不采纳

In [7]:
# SMEX02：无植被类型信息，唯一标注的就是Class，10为Grassland-Grassnat；12为Cropland-Grassman；E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V_ML.xlsx
# CLASIC07：Crop列：Corn, Cotton,Cut WW: Harvested Winter Wheat, Pasture, WW:Winter Wheat；E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx
# SMAPVEX08：Crop列：SB: Soybean、Corn；E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_ML.xlsx
# SMAPVEX16：CROP列：Alfalfa、Black Bean、Canola、Corn；Oat、Soybean、Wheat（没有所谓的树木类型，全是农作物）；E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv

# 直接读取各部分数据，进行填充，不进行日内均值，保存原先的经纬度——直接合并各ML后缀数据为E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Point_ML.xlsx，根据采样植被修改PFT的值
import pandas as pd
import os
from pathlib import Path

# 定义PFT列名
PFT_COLUMNS = [
    'PFT_water', 'PFT_bare', 'PFT_snowice', 'PFT_built', 
    'PFT_grassnat', 'PFT_grassman', 
    'PFT_shrubbd', 'PFT_shrubbe', 'PFT_shrubnd', 'PFT_shrubne',
    'PFT_treebd', 'PFT_treebe', 'PFT_treend', 'PFT_treene'
]

# 定义文件路径
input_files = {
    'SMEX02': r'E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V_ML.xlsx',
    'CLASIC07': r'E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx',
    'SMAPVEX08': r'E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_ML.xlsx',
    'SMAPVEX16': r'E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv'
}

# 输出文件路径
output_dir = r'E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16'
output_file = os.path.join(output_dir, 'InsituData_Point_ML.xlsx')

# 确保输出目录存在
Path(output_dir).mkdir(parents=True, exist_ok=True)

# 创建Excel写入器
with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
    for sheet_name, file_path in input_files.items():
        print(f"处理 {sheet_name} 数据集...")
        print(f"  文件路径: {file_path}")
        
        # 检查文件是否存在
        if not os.path.exists(file_path):
            print(f"  警告: 文件不存在，跳过")
            continue
            
        try:
            # 根据文件扩展名读取数据
            suffix = Path(file_path).suffix.lower()
            if suffix == '.xlsx':
                df = pd.read_excel(file_path)
            elif suffix == '.csv':
                df = pd.read_csv(file_path)
            else:
                print(f"  警告: 不支持的文件格式 {suffix}，跳过")
                continue
                
            # 添加PFT列（如果不存在）
            for col in PFT_COLUMNS:
                if col not in df.columns:
                    df[col] = 0.0
                    
            # 根据数据集类型设置PFT值
            if sheet_name == 'SMEX02':
                # SMEX02: 根据Class列设置PFT值
                if 'Class' in df.columns:
                    # 设置所有PFT列为0
                    for col in PFT_COLUMNS:
                        df[col] = 0.0
                    
                    # 设置grassnat
                    grassnat_mask = df['Class'] == 10
                    df.loc[grassnat_mask, 'PFT_grassnat'] = 1.0
                    print(f"  设置 {grassnat_mask.sum()} 行为 PFT_grassnat=1")
                    
                    # 设置grassman
                    grassman_mask = df['Class'] == 12
                    df.loc[grassman_mask, 'PFT_grassman'] = 1.0
                    print(f"  设置 {grassman_mask.sum()} 行为 PFT_grassman=1")
                else:
                    print("  警告: SMEX02数据集中缺少'Class'列，无法设置PFT")
            else:
                # 其他数据集: 设置PFT_grassman=1，其他=0
                for col in PFT_COLUMNS:
                    df[col] = 0.0
                df['PFT_grassman'] = 1.0
                print(f"  设置所有行 PFT_grassman=1")
                
            # 保存到Excel
            df.to_excel(writer, sheet_name=sheet_name, index=False)
            print(f"  成功保存 {len(df)} 行数据")
            
        except Exception as e:
            print(f"  处理 {sheet_name} 时出错: {str(e)}")
            # 创建空工作表
            pd.DataFrame().to_excel(writer, sheet_name=sheet_name, index=False)
            print(f"  创建空工作表")

print(f"\n处理完成! 结果已保存至: {output_file}")


处理 SMEX02 数据集...
  文件路径: E:\data\VWC\test-VWC\NSIDC_0666\SMEX02\processed_SMEX02V_ML.xlsx
  设置 8 行为 PFT_grassnat=1
  设置 96 行为 PFT_grassman=1
  成功保存 104 行数据
处理 CLASIC07 数据集...
  文件路径: E:\data\VWC\test-VWC\Insitu CLASIC07\CL07V_SUM_VEG_CLASIC_ML.xlsx
  设置所有行 PFT_grassman=1
  成功保存 22 行数据
处理 SMAPVEX08 数据集...
  文件路径: E:\data\VWC\test-VWC\Insitu SMEX08\processed_SV08V_ML.xlsx
  设置所有行 PFT_grassman=1
  成功保存 10 行数据
处理 SMAPVEX16 数据集...
  文件路径: E:\data\VWC\test-VWC\Insitu SMAPVEX16 Manitoba\Processed_Results\SV16M_V_CropBiomass_Vers4_with_coords_ML.csv
  设置所有行 PFT_grassman=1
  成功保存 1400 行数据

处理完成! 结果已保存至: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Point_ML.xlsx


In [11]:
# 散点图（4个数据画在一块，写出n，按照波段-极化组合绘制为3 * 3）
# 点形状及颜色：
# SMEX02：*；CLASIC07：^；SMAPVEX08：+；SMAPVEX16：o

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import matplotlib.font_manager as fm
import joblib
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import mean_squared_error

# 常量定义
BANDS = ['Ku', 'X', 'C']
POLS = ['H', 'V', 'HV']
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']
VWC_COLUMNS = {
    'SMEX02': 'VWC-Field',
    'CLASIC07': 'VWC (kg/m²)',
    'SMAPVEX08': 'VWC',
    'SMAPVEX16': 'PLANT_WATER_CONTENT_AREA'
}

# 标记和颜色设置
MARKER_STYLES = {
    'SMEX02': {'marker': '*', 'color': '#F8766D'},
    'CLASIC07': {'marker': '^', 'facecolor': 'none', 'edgecolor': '#00BFC4'},
    'SMAPVEX08': {'marker': '+', 'color': '#C77CFF'},
    'SMAPVEX16': {'marker': 'o', 'facecolor': 'none', 'edgecolor': '#7CAE00'}
}

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

def load_and_preprocess_data(file_path):
    """
    加载并预处理Excel文件中的所有sheet
    
    参数:
    file_path (str): Excel文件路径
    
    返回:
    dict: 包含预处理后数据的字典，键为sheet名称
    """
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    for sheet in SHEET_NAMES:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 替换SM_Satellite和LAI_Satellite（如果存在地面实测数据）
            if 'SM' in df.columns:
                mask = df['SM'].notna()
                df.loc[mask, 'SM_Satellite'] = df.loc[mask, 'SM']
                print(f"    替换了 {mask.sum()} 行SM_Satellite数据")
            
            if 'LAI' in df.columns:
                mask = df['LAI'].notna()
                df.loc[mask, 'LAI_Satellite'] = df.loc[mask, 'LAI']
                print(f"    替换了 {mask.sum()} 行LAI_Satellite数据")
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def get_features_for_model(band, pol):
    """
    根据波段和极化类型获取特征列表（使用模型训练时的名称）
    
    参数:
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    list: 特征列名列表
    """
    # 使用模型训练时的特征名称
    features = [
        'LAI',  # 注意：训练时使用"LAI"而不是"LAI_Satellite"
        'SM',   # 注意：训练时使用"SM"而不是"SM_Satellite"
        'Grass_man', 
        'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    # 添加VOD特征 - 根据模型类型
    if pol == 'H' or pol == 'V':
        # 单极化模型使用"VOD"
        features.append('VOD')
    elif pol == 'HV':
        # 双极化模型使用"VOD-Hpol"和"VOD-Vpol"
        features.extend(['VOD-Hpol', 'VOD-Vpol'])
    
    return features

def predict_vwc(data_dict, band, pol):
    """
    使用指定模型预测VWC，包括特征归一化
    
    参数:
    data_dict (dict): 包含所有sheet数据的字典
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    dict: 包含每个sheet预测结果的字典
    """
    # 加载模型 - 使用新模型命名规则
    model_path = f"models/RFR_{band}_{pol}-pol_Weighted_Type1.pkl"
    print(f"加载模型: {model_path}")
    
    if not os.path.exists(model_path):
        print(f"  模型文件不存在: {model_path}")
        return {}
    
    try:
        model = joblib.load(model_path)
        # 打印模型训练时的特征名称（如果可用）
        if hasattr(model, 'feature_names_in_'):
            print(f"  模型训练特征: {list(model.feature_names_in_)}")
    except Exception as e:
        print(f"  加载模型失败: {str(e)}")
        return {}
    
    # 获取特征列表
    features = get_features_for_model(band, pol)
    
    # 存储预测结果
    predictions = {}
    
    for sheet, df in data_dict.items():
        if df.empty:
            continue
        
        # 创建特征映射（将数据列名映射到模型期望的特征名）
        feature_mapping = {}
        for feature in features:
            # 特殊处理VOD特征
            if feature == 'VOD':
                # 单极化模型
                if pol == 'H':
                    if band == 'Ku':
                        feature_mapping['ku_vod_H'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_H'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_H'] = 'VOD'
                elif pol == 'V':
                    if band == 'Ku':
                        feature_mapping['ku_vod_V'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_V'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_V'] = 'VOD'
            elif feature == 'VOD-Hpol':
                # 双极化模型中的H极化
                if band == 'Ku':
                    feature_mapping['ku_vod_H'] = 'VOD-Hpol'
                elif band == 'X':
                    feature_mapping['x_vod_H'] = 'VOD-Hpol'
                elif band == 'C':
                    feature_mapping['c_vod_H'] = 'VOD-Hpol'
            elif feature == 'VOD-Vpol':
                # 双极化模型中的V极化
                if band == 'Ku':
                    feature_mapping['ku_vod_V'] = 'VOD-Vpol'
                elif band == 'X':
                    feature_mapping['x_vod_V'] = 'VOD-Vpol'
                elif band == 'C':
                    feature_mapping['c_vod_V'] = 'VOD-Vpol'
            else:
                # 其他特征映射 - 使用实际数据中的列名
                if feature == 'LAI':
                    feature_mapping['LAI_Satellite'] = 'LAI'
                elif feature == 'SM':
                    feature_mapping['SM_Satellite'] = 'SM'
                elif feature == 'Grass_man':
                    feature_mapping['PFT_grassman'] = 'Grass_man'
                elif feature == 'Grass_nat':
                    feature_mapping['PFT_grassnat'] = 'Grass_nat'
                elif feature == 'Shrub_bd':
                    feature_mapping['PFT_shrubbd'] = 'Shrub_bd'
                elif feature == 'Shrub_be':
                    feature_mapping['PFT_shrubbe'] = 'Shrub_be'
                elif feature == 'Shrub_nd':
                    feature_mapping['PFT_shrubnd'] = 'Shrub_nd'
                elif feature == 'Shrub_ne':
                    feature_mapping['PFT_shrubne'] = 'Shrub_ne'
                elif feature == 'Tree_bd':
                    feature_mapping['PFT_treebd'] = 'Tree_bd'
                elif feature == 'Tree_be':
                    feature_mapping['PFT_treebe'] = 'Tree_be'
                elif feature == 'Tree_nd':
                    feature_mapping['PFT_treend'] = 'Tree_nd'
                elif feature == 'Tree_ne':
                    feature_mapping['PFT_treene'] = 'Tree_ne'
        
        # 检查是否包含所有必要特征
        missing_features = []
        for data_feature in feature_mapping.keys():
            if data_feature not in df.columns:
                missing_features.append(data_feature)
        
        if missing_features:
            print(f"  {sheet} 缺少特征: {', '.join(missing_features)}")
            continue
        
        # 准备数据（使用重命名的特征）
        X = df[list(feature_mapping.keys())].copy()
        X.columns = [feature_mapping[col] for col in X.columns]
        
        # 确保特征顺序与模型期望一致
        if hasattr(model, 'feature_names_in_'):
            X = X[list(model.feature_names_in_)]
        
        # ========== 添加归一化处理 ==========
        print(f"  {sheet} 应用归一化处理...")
        
        # 1. VOD特征归一化（除以2）
        vod_features = ['VOD', 'VOD-Hpol', 'VOD-Vpol']
        for vod_feature in vod_features:
            if vod_feature in X.columns:
                X[vod_feature] = X[vod_feature].clip(0, 2) / 2.0
                print(f"    归一化 {vod_feature}: 除以2")
        
        # 2. LAI特征归一化（除以6）
        if 'LAI' in X.columns:
            X['LAI'] = X['LAI'].clip(0, 6) / 6.0
            print(f"    归一化 LAI: 除以6")
        
        # 3. PFT特征归一化 - 跳过，因为已经归一化
        # 根据要求，PFT特征不需要再次归一化
        print(f"    跳过PFT特征归一化（已归一化）")
        # =================================
        
        # 移除缺失值
        initial_count = len(X)
        X = X.dropna()
        removed_count = initial_count - len(X)
        if removed_count > 0:
            print(f"  {sheet} 移除了 {removed_count} 行包含缺失值的数据")
        
        if X.empty:
            print(f"  {sheet} 无有效数据可用于预测")
            continue
        
        # 预测VWC
        y_pred = model.predict(X)
        predictions[sheet] = {
            'actual': df.loc[X.index, VWC_COLUMNS[sheet]],
            'predicted': y_pred,
            'source': sheet,
            'lat': df.loc[X.index, 'Latitude'],
            'lon': df.loc[X.index, 'Longitude'],
            'date': df.loc[X.index, 'Date']
        }
        print(f"  {sheet} 预测完成: {len(y_pred)} 个样本")
    
    return predictions

def calculate_rmse(actual, predicted):
    """
    计算RMSE
    
    参数:
    actual (array-like): 实际值
    predicted (array-like): 预测值
    
    返回:
    float: RMSE值
    """
    return np.sqrt(np.mean((actual - predicted)**2))

def create_scatter_plots(all_predictions):
    """
    创建3x3散点子图
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    print("创建散点图...")
    
    # 创建图形
    fig = plt.figure(figsize=(18, 18))
    gs = gridspec.GridSpec(3, 3, figure=fig)
    
    # 设置全局标题
    fig.suptitle('', fontsize=24, fontweight='bold', y=0.95)
    
    # 遍历所有波段和极化组合
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            ax = fig.add_subplot(gs[i, j])
            
            # 获取当前组合的预测结果
            predictions = all_predictions.get((band, pol), {})
            
            # 收集所有数据点
            all_actual = []
            all_predicted = []
            
            # 绘制每个sheet的数据点
            for sheet in SHEET_NAMES:
                if sheet in predictions:
                    actual = predictions[sheet]['actual']
                    predicted = predictions[sheet]['predicted']
                    
                    # 添加到总集合
                    all_actual.extend(actual)
                    all_predicted.extend(predicted)
                    
                    # 绘制当前sheet的点
                    if sheet in ['CLASIC07', 'SMAPVEX16']:
                        # 对CLASIC07、SMAPVEX16特殊处理：空心
                        ax.scatter(
                            actual, predicted,
                            marker=MARKER_STYLES[sheet]['marker'],
                            facecolor=MARKER_STYLES[sheet]['facecolor'],  # 内部无填充
                            edgecolor=MARKER_STYLES[sheet]['edgecolor'],  # 使用边缘颜色
                            s=50,
                            alpha=0.7,
                            linewidths=1.0,  # 确保边框可见
                            label=sheet
                        )
                    else:
                        # 其他数据集保持原样
                        ax.scatter(
                            actual, predicted,
                            marker=MARKER_STYLES[sheet]['marker'],
                            color=MARKER_STYLES[sheet].get('color', MARKER_STYLES[sheet].get('edgecolor', None)),
                            s=50,
                            alpha=0.7,
                            label=sheet
                        )
            
            # 如果没有数据，跳过
            if not all_actual:
                ax.text(0.5, 0.5, '无数据', 
                        horizontalalignment='center', 
                        verticalalignment='center', 
                        transform=ax.transAxes,
                        fontsize=16)
                ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
                continue
            
            # 计算整体RMSE
            rmse = calculate_rmse(np.array(all_actual), np.array(all_predicted))
            
            # 添加1:1参考线
            max_val = max(max(all_actual), max(all_predicted)) * 1.05
            ax.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7)
            
            # 设置坐标轴范围
            ax.set_xlim(0, max_val)
            ax.set_ylim(0, max_val)
            
            # 设置坐标轴标签
            if i == 2:  # 最后一行
                ax.set_xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
            if j == 0:  # 第一列
                ax.set_ylabel('Predicted VWC (kg/m²)', fontsize=14, fontweight='bold')
            
            # 添加标题和RMSE
            ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
            ax.text(0.05, 0.95, f"RMSE: {rmse:.3f} kg/m²", 
                    transform=ax.transAxes,
                    fontsize=16,
                    fontweight='bold',
                    verticalalignment='top')
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
    
    # 添加图例
    handles, labels = [], []
    for sheet in SHEET_NAMES:
        style = MARKER_STYLES[sheet]
        
        if sheet in ['CLASIC07', 'SMAPVEX16']:
            # 为CLASIC07、SMAPVEX16创建空心图例
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w',
                                     markerfacecolor=style['facecolor'],  # 内部白色
                                     markeredgecolor=style['edgecolor'],  # 边缘颜色
                                     markersize=10,
                                     markeredgewidth=1.0))  # 边框宽度
        else:
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w', 
                                     markerfacecolor=style.get('color', style.get('edgecolor')),
                                     markeredgecolor=style.get('color', style.get('edgecolor')), 
                                     markersize=10))
        labels.append(sheet)
    
    fig.legend(handles, labels, 
               loc='lower center', 
               ncol=4, 
               fontsize=12,
               frameon=True,
               fancybox=True,
               shadow=True,
               bbox_to_anchor=(0.5, 0.02))
    
    # 调整布局
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    
    # 保存图像 - 使用新路径
    fig_path = "figures/AllSMAPInsituData_PointVWC_Scatter.png"
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"散点图已保存至: {fig_path}")
    plt.close()

def save_prediction_details(all_predictions):
    """
    将预测结果保存到Excel文件中
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_file = output_dir / "details_Point.xlsx"  # 修改输出文件名
    
    # 创建Excel写入器
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        # 遍历所有波段和极化组合
        for (band, pol), predictions in all_predictions.items():
            if not predictions:
                continue
                
            # 创建当前组合的数据框
            all_data = []
            
            # 收集所有sheet的数据
            for sheet, data in predictions.items():
                # 创建当前sheet的数据框
                sheet_df = pd.DataFrame({
                    'Date': data['date'],
                    'Latitude': data['lat'],
                    'Longitude': data['lon'],
                    'Actual_VWC': data['actual'],
                    'Predicted_VWC': data['predicted'],
                    'Source': data['source']
                })
                
                # 添加波段和极化信息
                sheet_df['Band'] = band
                sheet_df['Polarization'] = pol
                
                all_data.append(sheet_df)
            
            # 合并所有数据
            if all_data:
                combined_df = pd.concat(all_data, ignore_index=True)
                
                # 保存到Excel
                sheet_name = f"{band}_{pol}"
                combined_df.to_excel(writer, sheet_name=sheet_name, index=False)
                print(f"保存预测结果到: {sheet_name} ({len(combined_df)}行)")
    
    print(f"所有预测结果已保存至: {output_file}")

def main():
    # 输入文件路径 - 使用新路径
    input_file = r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Point_ML.xlsx"
    
    # 加载并预处理数据
    data_dict = load_and_preprocess_data(input_file)
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            print(f"\n处理波段-极化组合: {band}-{pol}")
            predictions = predict_vwc(data_dict, band, pol)
            all_predictions[(band, pol)] = predictions
    
    # 创建散点图
    create_scatter_plots(all_predictions)
    
    # 保存预测结果到Excel
    save_prediction_details(all_predictions)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Point_ML.xlsx
  - SMEX02: 104行
    替换了 42 行SM_Satellite数据
  - CLASIC07: 22行
    替换了 17 行SM_Satellite数据
  - SMAPVEX08: 10行
    替换了 10 行SM_Satellite数据
    替换了 10 行LAI_Satellite数据
  - SMAPVEX16: 1400行
    替换了 1375 行SM_Satellite数据

处理波段-极化组合: Ku-H
加载模型: models/RFR_Ku_H-pol_Weighted_Type1.pkl
  模型训练特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  SMEX02 应用归一化处理...
    归一化 VOD: 除以2
    归一化 LAI: 除以6
    跳过PFT特征归一化（已归一化）
  SMEX02 移除了 83 行包含缺失值的数据
  SMEX02 预测完成: 21 个样本
  CLASIC07 应用归一化处理...
    归一化 VOD: 除以2
    归一化 LAI: 除以6
    跳过PFT特征归一化（已归一化）
  CLASIC07 移除了 5 行包含缺失值的数据
  CLASIC07 预测完成: 17 个样本
  SMAPVEX08 应用归一化处理...
    归一化 VOD: 除以2
    归一化 LAI: 除以6
    跳过PFT特征归一化（已归一化）
  SMAPVEX08 预测完成: 10 个样本
  SMAPVEX16 应用归一化处理...
    归一化 VOD: 除以2
    归一化 LAI: 除以6
    跳过PFT特征归一化（已归一化）
  SMAPVEX16 移除了 25 行包含缺失值的数据
  SMAPVEX16 预测完成: 1375 个样本

处理

In [9]:
# 2.仍然处理为像元区域验证,对比PFT，计算像元VWC，并且列出权重，即有效测量值占比，然后投入值预估
# 因为数据都是农作物或草，没有灌木、树木，所以这里直接读取E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx的各个Sheet,将PFT_grassnat和PFT_grassman这两列相加获得新列validCoverage
# 再将VWC修改，等于原先的VWC×validCoverage。使用新训练的模型进行预测
# 绘制3*3的散点图
# 不使用Weight

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import matplotlib.font_manager as fm
import joblib
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import mean_squared_error

# 常量定义
BANDS = ['Ku', 'X', 'C']
POLS = ['H', 'V', 'HV']
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']
VWC_COLUMNS = {
    'SMEX02': 'VWC-Field',
    'CLASIC07': 'VWC (kg/m²)',
    'SMAPVEX08': 'VWC',
    'SMAPVEX16': 'PLANT_WATER_CONTENT_AREA'
}

# 标记和颜色设置
MARKER_STYLES = {
    'SMEX02': {'marker': 'x', 'color': '#F8766D'},
    'CLASIC07': {'marker': '^', 'facecolor': 'none', 'edgecolor': '#00BFC4'},
    'SMAPVEX08': {'marker': '+', 'color': '#C77CFF'},
    'SMAPVEX16': {'marker': 'o', 'facecolor': 'none', 'edgecolor': '#7CAE00'}
}

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

def load_and_preprocess_data(file_path):
    """
    加载并预处理Excel文件中的所有sheet，计算validCoverage权重
    
    参数:
    file_path (str): Excel文件路径
    
    返回:
    dict: 包含预处理后数据的字典，键为sheet名称
    """
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    for sheet in SHEET_NAMES:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 计算validCoverage权重 (grassman + grassnat)/100
            if 'grassman' in df.columns and 'grassnat' in df.columns:
                df['validCoverage'] = (df['grassman'] + df['grassnat']) / 100.0
                print(f"    计算validCoverage权重")
            else:
                print(f"    警告: {sheet} 缺少grassman或grassnat列，无法计算validCoverage")
                df['validCoverage'] = 1.0  # 默认权重为1
                
            # 计算实际VWC = VWC值 * validCoverage
            if VWC_COLUMNS[sheet] in df.columns:
                df['Actual_VWC'] = df[VWC_COLUMNS[sheet]] * df['validCoverage']
                print(f"    计算实际VWC: Actual_VWC = {VWC_COLUMNS[sheet]} * validCoverage")
            else:
                print(f"    错误: {sheet} 缺少{VWC_COLUMNS[sheet]}列")
                df['Actual_VWC'] = 0.0
                
            # 替换SM_Satellite和LAI_Satellite（如果存在地面实测数据）
            if 'SM' in df.columns:
                mask = df['SM'].notna()
                df.loc[mask, 'SM_Satellite'] = df.loc[mask, 'SM']
                print(f"    替换了 {mask.sum()} 行SM_Satellite数据")
            
            if 'LAI' in df.columns:
                mask = df['LAI'].notna()
                df.loc[mask, 'LAI_Satellite'] = df.loc[mask, 'LAI']
                print(f"    替换了 {mask.sum()} 行LAI_Satellite数据")
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def get_features_for_model(band, pol):
    """
    根据波段和极化类型获取特征列表（使用模型训练时的名称）
    
    参数:
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    list: 特征列名列表
    """
    # 使用模型训练时的特征名称
    features = [
        'LAI',  # 注意：训练时使用"LAI"而不是"LAI_Satellite"
        'SM',   # 注意：训练时使用"SM"而不是"SM_Satellite"
        'Grass_man', 
        'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    # 添加VOD特征 - 根据模型类型
    if pol == 'H' or pol == 'V':
        # 单极化模型使用"VOD"
        features.append('VOD')
    elif pol == 'HV':
        # 双极化模型使用"VOD-Hpol"和"VOD-Vpol"
        features.extend(['VOD-Hpol', 'VOD-Vpol'])
    
    return features

def predict_vwc(data_dict, band, pol):
    """
    使用指定模型预测VWC，包括特征归一化和样本权重处理
    
    参数:
    data_dict (dict): 包含所有sheet数据的字典
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    dict: 包含每个sheet预测结果的字典
    """
    # 加载模型
    model_path = f"models/RFR_{band}_{pol}-pol_Type1.pkl"  # 不使用带权重的模型
    print(f"加载模型: {model_path}")
    
    if not os.path.exists(model_path):
        print(f"  模型文件不存在: {model_path}")
        return {}
    
    try:
        model = joblib.load(model_path)
        # 打印模型训练时的特征名称（如果可用）
        if hasattr(model, 'feature_names_in_'):
            print(f"  模型训练特征: {list(model.feature_names_in_)}")
    except Exception as e:
        print(f"  加载模型失败: {str(e)}")
        return {}
    
    # 获取特征列表
    features = get_features_for_model(band, pol)
    
    # 存储预测结果
    predictions = {}
    
    for sheet, df in data_dict.items():
        if df.empty:
            continue
        
        # 创建特征映射（将数据列名映射到模型期望的特征名）
        feature_mapping = {}
        for feature in features:
            # 特殊处理VOD特征
            if feature == 'VOD':
                # 单极化模型
                if pol == 'H':
                    if band == 'Ku':
                        feature_mapping['ku_vod_H'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_H'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_H'] = 'VOD'
                elif pol == 'V':
                    if band == 'Ku':
                        feature_mapping['ku_vod_V'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_V'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_V'] = 'VOD'
            elif feature == 'VOD-Hpol':
                # 双极化模型中的H极化
                if band == 'Ku':
                    feature_mapping['ku_vod_H'] = 'VOD-Hpol'
                elif band == 'X':
                    feature_mapping['x_vod_H'] = 'VOD-Hpol'
                elif band == 'C':
                    feature_mapping['c_vod_H'] = 'VOD-Hpol'
            elif feature == 'VOD-Vpol':
                # 双极化模型中的V极化
                if band == 'Ku':
                    feature_mapping['ku_vod_V'] = 'VOD-Vpol'
                elif band == 'X':
                    feature_mapping['x_vod_V'] = 'VOD-Vpol'
                elif band == 'C':
                    feature_mapping['c_vod_V'] = 'VOD-Vpol'
            else:
                # 其他特征映射
                if feature == 'LAI':
                    feature_mapping['LAI_Satellite'] = 'LAI'
                elif feature == 'SM':
                    feature_mapping['SM_Satellite'] = 'SM'
                elif feature == 'Grass_man':
                    feature_mapping['grassman'] = 'Grass_man'
                elif feature == 'Grass_nat':
                    feature_mapping['grassnat'] = 'Grass_nat'
                elif feature == 'Shrub_bd':
                    feature_mapping['shrubbd'] = 'Shrub_bd'
                elif feature == 'Shrub_be':
                    feature_mapping['shrubbe'] = 'Shrub_be'
                elif feature == 'Shrub_nd':
                    feature_mapping['shrubnd'] = 'Shrub_nd'
                elif feature == 'Shrub_ne':
                    feature_mapping['shrubne'] = 'Shrub_ne'
                elif feature == 'Tree_bd':
                    feature_mapping['treebd'] = 'Tree_bd'
                elif feature == 'Tree_be':
                    feature_mapping['treebe'] = 'Tree_be'
                elif feature == 'Tree_nd':
                    feature_mapping['treend'] = 'Tree_nd'
                elif feature == 'Tree_ne':
                    feature_mapping['treene'] = 'Tree_ne'
        
        # 检查是否包含所有必要特征
        missing_features = []
        for data_feature in feature_mapping.keys():
            if data_feature not in df.columns:
                missing_features.append(data_feature)
        
        if missing_features:
            print(f"  {sheet} 缺少特征: {', '.join(missing_features)}")
            continue
        
        # 准备数据（使用重命名的特征）
        X = df[list(feature_mapping.keys())].copy()
        X.columns = [feature_mapping[col] for col in X.columns]
        
        # 确保特征顺序与模型期望一致
        if hasattr(model, 'feature_names_in_'):
            X = X[list(model.feature_names_in_)]
        
        # ========== 添加归一化处理 ==========
        print(f"  {sheet} 应用归一化处理...")
        
        # 1. VOD特征归一化（除以2）
        vod_features = ['VOD', 'VOD-Hpol', 'VOD-Vpol']
        for vod_feature in vod_features:
            if vod_feature in X.columns:
                X[vod_feature] = X[vod_feature].clip(0, 2) / 2.0
                print(f"    归一化 {vod_feature}: 除以2")
        
        # 2. LAI特征归一化（除以6）
        if 'LAI' in X.columns:
            X['LAI'] = X['LAI'].clip(0, 6) / 6.0
            print(f"    归一化 LAI: 除以6")
        
        # 3. PFT特征归一化（除以100）
        pft_features = [
            'Grass_man', 'Grass_nat',
            'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
            'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
        ]
        
        for pft_feature in pft_features:
            if pft_feature in X.columns:
                X[pft_feature] = X[pft_feature] / 100.0
                print(f"    归一化 {pft_feature}: 除以100")
        # =================================
        
        # # 获取样本权重
        # sample_weights = df.loc[X.index, 'validCoverage']
        
        # 移除缺失值
        initial_count = len(X)
        X = X.dropna()
        # sample_weights = sample_weights.loc[X.index]  # 相应调整权重
        removed_count = initial_count - len(X)
        if removed_count > 0:
            print(f"  {sheet} 移除了 {removed_count} 行包含缺失值的数据")
        
        if X.empty:
            print(f"  {sheet} 无有效数据可用于预测")
            continue
        
        # 预测VWC
        y_pred = model.predict(X)
        predictions[sheet] = {
            'actual': df.loc[X.index, 'Actual_VWC'],  # 使用计算的实际VWC
            'predicted': y_pred,
            # 'weight': sample_weights.values,  # 保存权重
            'source': sheet,
            'row': df.loc[X.index, 'row'],
            'col': df.loc[X.index, 'col'],
            'lat': df.loc[X.index, 'Latitude'],
            'lon': df.loc[X.index, 'Longitude'],
            'date': df.loc[X.index, 'Date']
        }
        print(f"  {sheet} 预测完成: {len(y_pred)} 个样本")
    
    return predictions

def calculate_rmse(actual, predicted):
    """
    计算RMSE
    
    参数:
    actual (array-like): 实际值
    predicted (array-like): 预测值
    
    返回:
    float: RMSE值
    """
    return np.sqrt(np.mean((actual - predicted)**2))
    
def create_scatter_plots(all_predictions):
    """
    创建3x3散点子图，使用权重调整点的大小
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    print("创建散点图...")
    
    # 创建图形
    fig = plt.figure(figsize=(18, 18))
    gs = gridspec.GridSpec(3, 3, figure=fig)
    
    # 设置全局标题
    fig.suptitle('', fontsize=24, fontweight='bold', y=0.95)
    
    # 遍历所有波段和极化组合
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            ax = fig.add_subplot(gs[i, j])
            
            # 获取当前组合的预测结果
            predictions = all_predictions.get((band, pol), {})
            
            # 收集所有数据点
            all_actual = []
            all_predicted = []
            
            # 绘制每个sheet的数据点
            for sheet in SHEET_NAMES:
                if sheet in predictions:
                    actual = predictions[sheet]['actual']
                    predicted = predictions[sheet]['predicted']
                    
                    # 添加到总集合
                    all_actual.extend(actual)
                    all_predicted.extend(predicted)
                    
                    # 绘制当前sheet的点，使用权重调整点的大小
                    style = MARKER_STYLES[sheet]
                    
                    # # 计算点的大小（权重 * 50 + 10，确保最小尺寸）
                    # sizes = np.array(weights) * 50 + 10
                    
                    if 'facecolor' in style and 'edgecolor' in style:
                        # 对需要空心处理的点
                        ax.scatter(
                            actual, predicted,
                            marker=style['marker'],
                            facecolor=style['facecolor'],
                            edgecolor=style['edgecolor'],
                            s=50,
                            alpha=0.7,
                            linewidths=1.0,
                            label=sheet
                        )
                    else:
                        # 其他点
                        ax.scatter(
                            actual, predicted,
                            marker=style['marker'],
                            color=style['color'],
                            s=50,
                            alpha=0.7,
                            label=sheet
                        )
            
            # 如果没有数据，跳过
            if not all_actual:
                ax.text(0.5, 0.5, 'No Data', 
                        horizontalalignment='center', 
                        verticalalignment='center', 
                        transform=ax.transAxes,
                        fontsize=16)
                ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
                continue
            
            # 计算RMSE
            rmse = calculate_rmse(np.array(all_actual), np.array(all_predicted))
            
            # 添加1:1参考线
            max_val = max(max(all_actual), max(all_predicted)) * 1.05
            ax.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7)
            
            # 设置坐标轴范围
            ax.set_xlim(0, max_val)
            ax.set_ylim(0, max_val)
            
            # 设置坐标轴标签
            if i == 2:  # 最后一行
                ax.set_xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
            if j == 0:  # 第一列
                ax.set_ylabel('Predicted VWC (kg/m²)', fontsize=14, fontweight='bold')
            
            # 添加标题和RMSE
            ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
            ax.text(0.05, 0.95, f"RMSE: {rmse:.3f} kg/m²", 
                    transform=ax.transAxes,
                    fontsize=16,
                    fontweight='bold',
                    verticalalignment='top')
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
    
    # 添加图例
    handles, labels = [], []
    for sheet in SHEET_NAMES:
        style = MARKER_STYLES[sheet]
        
        if 'facecolor' in style and 'edgecolor' in style:
            # 为空心点创建图例
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w',
                                     markerfacecolor=style['facecolor'],  # 内部白色
                                     markeredgecolor=style['edgecolor'],  # 边缘颜色
                                     markersize=10,
                                     markeredgewidth=1.0))  # 边框宽度
        else:
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w', 
                                     markerfacecolor=style['color'],
                                     markeredgecolor=style['color'], 
                                     markersize=10))
        labels.append(sheet)
    
    fig.legend(handles, labels, 
               loc='lower center', 
               ncol=4, 
               fontsize=12,
               frameon=True,
               fancybox=True,
               shadow=True,
               bbox_to_anchor=(0.5, 0.02),
               title="Dataset")
    
    # 调整布局
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    
    # 保存图像
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = output_dir / "customizePFT_AllSMAPInsituData_VWC_Scatter.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"散点图已保存至: {fig_path}")
    plt.close()

def save_prediction_details(all_predictions):
    """
    将预测结果保存到Excel文件中，包含权重信息
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_file = output_dir / "customizePFT_predictions_details.xlsx"
    
    # 检查是否有数据可保存
    has_data = False
    for (band, pol), predictions in all_predictions.items():
        if predictions and any(predictions.values()):
            has_data = True
            break
    
    if not has_data:
        print("警告: 没有预测数据可保存")
        return
    
    # 创建Excel写入器
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        # 遍历所有波段和极化组合
        for (band, pol), predictions in all_predictions.items():
            if not predictions or not any(predictions.values()):
                print(f"跳过空数据集: {band}-{pol}")
                continue
                
            # 创建当前组合的数据框
            all_data = []
            
            # 收集所有sheet的数据
            for sheet, data in predictions.items():
                if not data or len(data.get('actual', [])) == 0:
                    print(f"  跳过空sheet: {sheet}")
                    continue
                
                # 创建当前sheet的数据框
                sheet_df = pd.DataFrame({
                    'Date': data['date'],
                    'Row': data['row'],
                    'Col': data['col'],
                    'Latitude': data['lat'],
                    'Longitude': data['lon'],
                    'Actual_VWC': data['actual'],
                    'Predicted_VWC': data['predicted'],
                    # 'Weight': data['weight'],
                    'Source': data['source']
                })
                
                # 添加波段和极化信息
                sheet_df['Band'] = band
                sheet_df['Polarization'] = pol
                
                all_data.append(sheet_df)
            
            # 合并所有数据
            if all_data:
                combined_df = pd.concat(all_data, ignore_index=True)
                
                # 保存到Excel
                sheet_name = f"{band}_{pol}"
                combined_df.to_excel(writer, sheet_name=sheet_name, index=False)
                print(f"保存预测结果到: {sheet_name} ({len(combined_df)}行)")
    
    print(f"所有预测结果已保存至: {output_file}")

def main():
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx"
    
    # 加载并预处理数据，计算validCoverage和实际VWC
    data_dict = load_and_preprocess_data(input_file)
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            print(f"\n处理波段-极化组合: {band}-{pol}")
            predictions = predict_vwc(data_dict, band, pol)
            all_predictions[(band, pol)] = predictions
    
    # 创建散点图
    create_scatter_plots(all_predictions)
    
    # 保存预测结果到Excel
    save_prediction_details(all_predictions)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx
  - SMEX02: 16行
    计算validCoverage权重
    计算实际VWC: Actual_VWC = VWC-Field * validCoverage
    替换了 14 行SM_Satellite数据
  - CLASIC07: 18行
    计算validCoverage权重
    计算实际VWC: Actual_VWC = VWC (kg/m²) * validCoverage
  - SMAPVEX08: 6行
    计算validCoverage权重
    计算实际VWC: Actual_VWC = VWC * validCoverage
    替换了 6 行LAI_Satellite数据
  - SMAPVEX16: 115行
    计算validCoverage权重
    计算实际VWC: Actual_VWC = PLANT_WATER_CONTENT_AREA * validCoverage

处理波段-极化组合: Ku-H
加载模型: models/RFR_Ku_H-pol_Type1.pkl
  模型训练特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  SMEX02 应用归一化处理...
    归一化 VOD: 除以2
    归一化 LAI: 除以6
    归一化 Grass_man: 除以100
    归一化 Grass_nat: 除以100
    归一化 Shrub_bd: 除以100
    归一化 Shrub_be: 除以100
    归一化 Shrub_nd: 除以100
    归一化 Shrub_ne: 除以100
    归一化 Tree_bd: 除以100
    归一化 Tree_be: 除以100
    归一化 Tree_nd: 除以100
    归一化 Tre

# 补充：训练数据的VOD-LFMC、VOD-AGB散点图

In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import re
import os
import numpy as np
import seaborn as sns
from scipy.interpolate import UnivariateSpline
from sklearn.metrics import r2_score

# 文件路径
file_path = r'E:\Matlab\EX2025\AuxiliaryData\VWC_ML_Data.xlsx'

# 验证文件是否存在
if not os.path.exists(file_path):
    raise FileNotFoundError(f"文件未找到: {file_path}")

# 创建保存目录
save_dir = r'E:\文章\HUITU\Fig'
os.makedirs(save_dir, exist_ok=True)  # 确保目录存在

# 获取所有sheet名称
all_sheets = pd.ExcelFile(file_path).sheet_names

# 正则表达式匹配目标sheet命名格式
pattern = r'VOD_(Ku|X|C)_(H|V)pol_Asc_Cleaned_Type1'
target_sheets = [s for s in all_sheets if re.match(pattern, s)]

# 没有匹配sheet时的处理
if not target_sheets:
    raise ValueError("未找到符合命名规则的sheet")

# 波段和极化方式映射
band_map = {
    'Ku': 'Ku',
    'X': 'X',
    'C': 'C'
}
pol_map = {
    'H': 'Horizontal',
    'V': 'Vertical'
}

# 设置全局绘图风格
sns.set(style="whitegrid", font_scale=1.2)
plt.rcParams['font.family'] = 'Times New Roman'

# 遍历所有匹配的sheet
for sheet_name in target_sheets:
    # 提取波段和极化方式
    parts = sheet_name.split('_')
    band = parts[1]  # Ku/X/C
    pol = parts[2][0]  # H/V
    
    # 确定目标列名 (例如: VOD_Ku_Hpol_Asc)
    target_col = f'VOD_{band}_{pol}pol_Asc'
    
    # 获取波段和极化的友好名称
    band_name = band_map.get(band, band)
    pol_name = pol_map.get(pol, pol)
    
    try:
        # 读取sheet数据
        df = pd.read_excel(file_path, sheet_name=sheet_name)
        
        # 检查所需列是否存在
        required_cols = [target_col, 'LFMC', 'AGB']
        missing_cols = [col for col in required_cols if col not in df.columns]
        
        if missing_cols:
            print(f"Sheet '{sheet_name}'缺少列: {missing_cols}")
            continue
            
        # 创建带两个子图的图像
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        fig.suptitle(f'{band_name}-Band, {pol_name} Polarization', 
                    fontsize=20, fontweight='bold', y=1.02)
        
        # ============= 左图: VOD vs LFMC =============
        ax1 = axes[0]
        x1 = df['LFMC'].values
        y1 = df[target_col].values
        
        # 删除NaN值
        valid_idx1 = ~np.isnan(x1) & ~np.isnan(y1)
        x1_clean = x1[valid_idx1]
        y1_clean = y1[valid_idx1]
        
        # 绘制散点图
        sns.scatterplot(x=x1_clean, y=y1_clean, ax=ax1, alpha=0.7, edgecolor='w', s=60)
        
        if len(x1_clean) > 3:  # 确保有足够的数据点
            # 按x值排序确保升序排列
            sort_idx = np.argsort(x1_clean)
            x1_sorted = x1_clean[sort_idx]
            y1_sorted = y1_clean[sort_idx]
            
            # 使用UnivariateSpline替代smoothing_spline
            # 它更灵活，不易出错
            try:
                spline1 = UnivariateSpline(x1_sorted, y1_sorted, s=len(x1_sorted)*3)
                x1_smooth = np.linspace(min(x1_sorted), max(x1_sorted), 300)
                y1_smooth = spline1(x1_smooth)
                
                # 计算R²
                y1_pred = spline1(x1_sorted)
                r2_1 = r2_score(y1_sorted, y1_pred)
                
                # 绘制拟合线
                ax1.plot(x1_smooth, y1_smooth, 'r-', lw=3, alpha=0.8)
                
                # 添加R²值
                ax1.text(0.05, 0.95, f'$R^2$ = {r2_1:.2f}', 
                         transform=ax1.transAxes, 
                         fontsize=14,
                         verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
            except Exception as e:
                print(f"LFMC拟合失败: {str(e)}")
        
        # 设置轴标签和网格
        ax1.set_xlabel('LFMC (%)', fontsize=14)
        ax1.set_ylabel('VOD', fontsize=14)
        ax1.grid(True, linestyle='--', alpha=0.7)
        
        # ============= 右图: VOD vs AGB =============
        ax2 = axes[1]
        x2 = df['AGB'].values
        y2 = df[target_col].values
        
        # 删除NaN值
        valid_idx2 = ~np.isnan(x2) & ~np.isnan(y2)
        x2_clean = x2[valid_idx2]
        y2_clean = y2[valid_idx2]
        
        # 绘制散点图
        sns.scatterplot(x=x2_clean, y=y2_clean, ax=ax2, alpha=0.7, edgecolor='w', s=60)
        
        if len(x2_clean) > 3:  # 确保有足够的数据点
            # 按x值排序确保升序排列
            sort_idx = np.argsort(x2_clean)
            x2_sorted = x2_clean[sort_idx]
            y2_sorted = y2_clean[sort_idx]
            
            try:
                spline2 = UnivariateSpline(x2_sorted, y2_sorted, s=len(x2_sorted)*3)
                x2_smooth = np.linspace(min(x2_sorted), max(x2_sorted), 300)
                y2_smooth = spline2(x2_smooth)
                
                # 计算R²
                y2_pred = spline2(x2_sorted)
                r2_2 = r2_score(y2_sorted, y2_pred)
                
                # 绘制拟合线
                ax2.plot(x2_smooth, y2_smooth, 'r-', lw=3, alpha=0.8)
                
                # 添加R²值
                ax2.text(0.05, 0.95, f'$R^2$ = {r2_2:.2f}', 
                         transform=ax2.transAxes, 
                         fontsize=14,
                         verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
            except Exception as e:
                print(f"AGB拟合失败: {str(e)}")
        
        # 设置轴标签和网格
        ax2.set_xlabel('AGB (Mg/ha)', fontsize=14)
        ax2.set_ylabel('')  # 共用同一个y轴标签
        ax2.grid(True, linestyle='--', alpha=0.7)
        
        # 自动调整布局
        plt.tight_layout()
        
        # 保存图像
        save_path = os.path.join(save_dir, f'VOD_{band}_{pol}_Correlation.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)  # 关闭图形释放内存
        print(f"成功保存图像: {save_path}")
        
    except Exception as e:
        print(f"处理sheet '{sheet_name}'时出错: {str(e)}")
        plt.close()  # 确保出错时关闭图形

print("所有处理完成！")

成功保存图像: E:\文章\HUITU\Fig\VOD_Ku_H_Correlation.png
成功保存图像: E:\文章\HUITU\Fig\VOD_Ku_V_Correlation.png
成功保存图像: E:\文章\HUITU\Fig\VOD_X_H_Correlation.png
成功保存图像: E:\文章\HUITU\Fig\VOD_X_V_Correlation.png
成功保存图像: E:\文章\HUITU\Fig\VOD_C_H_Correlation.png
成功保存图像: E:\文章\HUITU\Fig\VOD_C_V_Correlation.png
所有处理完成！


In [None]:
# 尝试：单独分析LFMC和VOD，为什么相关性这么低，不利于后续的训练，尝试去除整体趋势异常值。

In [None]:
# 去除LFMC的异常值（3个标准差）

In [6]:
import pandas as pd
import matplotlib.pyplot as plt
import re
import os
import numpy as np
import seaborn as sns
from scipy.interpolate import UnivariateSpline
from sklearn.metrics import r2_score

# 文件路径
file_path = r'E:\Matlab\EX2025\AuxiliaryData\VWC_ML_Data.xlsx'

# 验证文件是否存在
if not os.path.exists(file_path):
    raise FileNotFoundError(f"文件未找到: {file_path}")

# 创建保存目录
save_dir = r'E:\文章\HUITU\Fig'
os.makedirs(save_dir, exist_ok=True)  # 确保目录存在

# 获取所有sheet名称
all_sheets = pd.ExcelFile(file_path).sheet_names

# 正则表达式匹配目标sheet命名格式
pattern = r'VOD_(Ku|X|C)_(H|V)pol_Asc_Cleaned_Type1'
target_sheets = [s for s in all_sheets if re.match(pattern, s)]

# 没有匹配sheet时的处理
if not target_sheets:
    raise ValueError("未找到符合命名规则的sheet")

# 波段和极化方式映射
band_map = {
    'Ku': 'Ku',
    'X': 'X',
    'C': 'C'
}
pol_map = {
    'H': 'Horizontal',
    'V': 'Vertical'
}

# 设置全局绘图风格
sns.set(style="whitegrid", font_scale=1.2)
plt.rcParams['font.family'] = 'Times New Roman'

# 遍历所有匹配的sheet
for sheet_name in target_sheets:
    # 提取波段和极化方式
    parts = sheet_name.split('_')
    band = parts[1]  # Ku/X/C
    pol = parts[2][0]  # H/V
    
    # 确定目标列名 (例如: VOD_Ku_Hpol_Asc)
    target_col = f'VOD_{band}_{pol}pol_Asc'
    
    # 获取波段和极化的友好名称
    band_name = band_map.get(band, band)
    pol_name = pol_map.get(pol, pol)
    
    try:
        # 读取sheet数据
        df = pd.read_excel(file_path, sheet_name=sheet_name)
        
        # 检查所需列是否存在
        required_cols = [target_col, 'LFMC', 'AGB']
        missing_cols = [col for col in required_cols if col not in df.columns]
        
        if missing_cols:
            print(f"Sheet '{sheet_name}'缺少列: {missing_cols}")
            continue
        
        # ======== 新增：LFMC异常值处理 ========
        if 'LFMC' in df.columns:
            # 计算均值和标准差（忽略NaN）
            lfmc_mean = df['LFMC'].mean()
            lfmc_std = df['LFMC'].std()
            
            # 计算异常阈值（仅上限）
            upper_limit = lfmc_mean + 3 * lfmc_std
            
            # 打印诊断信息
            print(f"[{sheet_name}] LFMC处理 - 均值: {lfmc_mean:.2f}, 标准差: {lfmc_std:.2f}, 上限: {upper_limit:.2f}")
            print(f"处理前数据点: {len(df)}")
            
            # 创建异常值掩码
            is_outlier = df['LFMC'] > upper_limit
            
            # 打印异常值数量
            print(f"发现异常值: {is_outlier.sum()}个")
            
            # 将异常值替换为NaN
            df.loc[is_outlier, 'LFMC'] = np.nan
            
            print(f"处理后有效点: {len(df) - is_outlier.sum()}")
        
        # ======== 异常值处理结束 ========
        
        # 创建带两个子图的图像
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        fig.suptitle(f'{band_name}-Band, {pol_name} Polarization', 
                    fontsize=20, fontweight='bold', y=1.02)
        
        # ============= 左图: VOD vs LFMC =============
        ax1 = axes[0]
        x1 = df['LFMC'].values
        y1 = df[target_col].values
        
        # 删除NaN值
        valid_idx1 = ~np.isnan(x1) & ~np.isnan(y1)
        x1_clean = x1[valid_idx1]
        y1_clean = y1[valid_idx1]
        
        # 绘制散点图
        sns.scatterplot(x=x1_clean, y=y1_clean, ax=ax1, alpha=0.7, edgecolor='w', s=60)
        
        if len(x1_clean) > 3:  # 确保有足够的数据点
            # 按x值排序确保升序排列
            sort_idx = np.argsort(x1_clean)
            x1_sorted = x1_clean[sort_idx]
            y1_sorted = y1_clean[sort_idx]
            
            # 使用UnivariateSpline替代smoothing_spline
            # 它更灵活，不易出错
            try:
                spline1 = UnivariateSpline(x1_sorted, y1_sorted, s=len(x1_sorted)*3)
                x1_smooth = np.linspace(min(x1_sorted), max(x1_sorted), 300)
                y1_smooth = spline1(x1_smooth)
                
                # 计算R²
                y1_pred = spline1(x1_sorted)
                r2_1 = r2_score(y1_sorted, y1_pred)
                
                # 绘制拟合线
                ax1.plot(x1_smooth, y1_smooth, 'r-', lw=3, alpha=0.8)
                
                # 添加R²值
                ax1.text(0.05, 0.95, f'$R^2$ = {r2_1:.2f}', 
                         transform=ax1.transAxes, 
                         fontsize=14,
                         verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
            except Exception as e:
                print(f"LFMC拟合失败: {str(e)}")
        
        # 设置轴标签和网格
        ax1.set_xlabel('LFMC (%)', fontsize=14)
        ax1.set_ylabel('VOD', fontsize=14)
        ax1.grid(True, linestyle='--', alpha=0.7)
        
        # ============= 右图: VOD vs AGB =============
        ax2 = axes[1]
        x2 = df['AGB'].values
        y2 = df[target_col].values
        
        # 删除NaN值
        valid_idx2 = ~np.isnan(x2) & ~np.isnan(y2)
        x2_clean = x2[valid_idx2]
        y2_clean = y2[valid_idx2]
        
        # 绘制散点图
        sns.scatterplot(x=x2_clean, y=y2_clean, ax=ax2, alpha=0.7, edgecolor='w', s=60)
        
        if len(x2_clean) > 3:  # 确保有足够的数据点
            # 按x值排序确保升序排列
            sort_idx = np.argsort(x2_clean)
            x2_sorted = x2_clean[sort_idx]
            y2_sorted = y2_clean[sort_idx]
            
            try:
                spline2 = UnivariateSpline(x2_sorted, y2_sorted, s=len(x2_sorted)*3)
                x2_smooth = np.linspace(min(x2_sorted), max(x2_sorted), 300)
                y2_smooth = spline2(x2_smooth)
                
                # 计算R²
                y2_pred = spline2(x2_sorted)
                r2_2 = r2_score(y2_sorted, y2_pred)
                
                # 绘制拟合线
                ax2.plot(x2_smooth, y2_smooth, 'r-', lw=3, alpha=0.8)
                
                # 添加R²值
                ax2.text(0.05, 0.95, f'$R^2$ = {r2_2:.2f}', 
                         transform=ax2.transAxes, 
                         fontsize=14,
                         verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
            except Exception as e:
                print(f"AGB拟合失败: {str(e)}")
        
        # 设置轴标签和网格
        ax2.set_xlabel('AGB (Mg/ha)', fontsize=14)
        ax2.set_ylabel('')  # 共用同一个y轴标签
        ax2.grid(True, linestyle='--', alpha=0.7)
        
        # 自动调整布局
        plt.tight_layout()
        
        # 保存图像
        save_path = os.path.join(save_dir, f'VOD_{band}_{pol}_Correlation_Cleaned.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)  # 关闭图形释放内存
        print(f"成功保存图像: {save_path}")
        
    except Exception as e:
        print(f"处理sheet '{sheet_name}'时出错: {str(e)}")
        plt.close()  # 确保出错时关闭图形

print("所有处理完成！")

[VOD_Ku_Hpol_Asc_Cleaned_Type1] LFMC处理 - 均值: 113.42, 标准差: 36.62, 上限: 223.30
处理前数据点: 26692
发现异常值: 357个
处理后有效点: 26335
成功保存图像: E:\文章\HUITU\Fig\VOD_Ku_H_Correlation_Cleaned.png
[VOD_Ku_Vpol_Asc_Cleaned_Type1] LFMC处理 - 均值: 114.64, 标准差: 36.95, 上限: 225.49
处理前数据点: 18884
发现异常值: 250个
处理后有效点: 18634
成功保存图像: E:\文章\HUITU\Fig\VOD_Ku_V_Correlation_Cleaned.png
[VOD_X_Hpol_Asc_Cleaned_Type1] LFMC处理 - 均值: 113.58, 标准差: 36.65, 上限: 223.51
处理前数据点: 27187
发现异常值: 358个
处理后有效点: 26829
成功保存图像: E:\文章\HUITU\Fig\VOD_X_H_Correlation_Cleaned.png
[VOD_X_Vpol_Asc_Cleaned_Type1] LFMC处理 - 均值: 113.83, 标准差: 36.58, 上限: 223.58
处理前数据点: 23909
发现异常值: 319个
处理后有效点: 23590
成功保存图像: E:\文章\HUITU\Fig\VOD_X_V_Correlation_Cleaned.png
[VOD_C_Hpol_Asc_Cleaned_Type1] LFMC处理 - 均值: 113.55, 标准差: 36.64, 上限: 223.47
处理前数据点: 27114
发现异常值: 365个
处理后有效点: 26749
成功保存图像: E:\文章\HUITU\Fig\VOD_C_H_Correlation_Cleaned.png
[VOD_C_Vpol_Asc_Cleaned_Type1] LFMC处理 - 均值: 113.11, 标准差: 36.55, 上限: 222.76
处理前数据点: 22223
发现异常值: 302个
处理后有效点: 21921
成功保存图像: E:\文章\HUITU\Fig\VO

In [None]:
# 我感觉这个确实是做不到……

# .使用高度清洗的模型进行估算(_purify)

2017-2018

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.gridspec as gridspec
import joblib
import os
from pathlib import Path
import warnings
from datetime import datetime
from sklearn.metrics import mean_squared_error, r2_score
from scipy.interpolate import make_interp_spline  # 导入样条插值函数
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义
BANDS = ['Ku', 'X', 'C']
BAND_COLORS = {
    'Ku': '#1f77b4',  # 蓝色
    'X': '#ff7f0e',   # 橙色
    'C': '#2ca02c'    # 绿色
}
POLS = ['H', 'V', 'HV']
POL_LINESTYLES = {
    'H': '-',     # 实线
    'V': '--',    # 虚线
    'HV': ':'     # 点线
}
POL_MARKERS = {
    'H': '+',  # 加号
    'V': '^',  # 三角形
    'HV': 's'  # 正方形
}
POL_LABELS = {
    'H': 'H-Pol',
    'V': 'V-Pol',
    'HV': 'H&V-Pol'  # 修改这里
}

# 植被类型映射
VEGETATION_TYPES = {
    'CornVegMeasured': 'Corn (2017)',
    'OatVegMeasured': 'Oat (2017)',
    'GrassVWC': 'Grass (2018)'
}

# 实测与拟合数据映射
FITTING_MAPPING = {
    'CornVegMeasured': 'CornVegFitting',
    'OatVegMeasured': 'OatVegFitting',
    'GrassVWC': 'GrassVWC'  # 2018年没有拟合数据
}

# 实测VWC列名映射
ACTUAL_COL_MAPPING = {
    'CornVegMeasured': 'total_VWC(kg/m2)',
    'OatVegMeasured': 'total_VWC(kg/m2)',
    'GrassVWC': 'vegetation water content(kg/m2)'
}

# 实测数据样式
ACTUAL_STYLE = {
    'color': 'black',
    'marker': 'o',
    'markersize': 8,
    'markerfacecolor': 'none',
    'markeredgewidth': 1.5,
    'label': 'Measured'
}

def load_data(file_path):
    """加载Excel文件中的所有工作表"""
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    # 获取所有工作表名称
    xl = pd.ExcelFile(file_path)
    sheet_names = xl.sheet_names
    
    for sheet in sheet_names:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def predict_vwc_for_sheet(df, band, pol):
    """
    使用机器学习模型预测VWC，确保特征名称匹配
    """
    # 加载模型
    model_path = f"models/RFR_{band}_{pol}pol_Purify_Type1.pkl"
    if not os.path.exists(model_path):
        print(f"警告: 模型文件不存在: {model_path}")
        return pd.Series(np.nan, index=df.index)
    
    try:
        model = joblib.load(model_path)
        print(f"加载模型: {model_path}")
        
        # 获取模型期望的特征名称
        if hasattr(model, 'feature_names_in_'):
            expected_features = list(model.feature_names_in_)
            print(f"  模型期望特征: {expected_features}")
        else:
            print("  警告: 模型没有feature_names_in_属性")
            expected_features = []
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)
    
    # 1. 优先使用地面实测数据替换卫星数据
    if 'SM' in df.columns:
        sm_mask = df['SM'].notna() & (df['SM'] > 0)
        if sm_mask.any():
            df.loc[sm_mask, 'SM_Satellite'] = df.loc[sm_mask, 'SM']
            print(f"  使用实测SM替换了 {sm_mask.sum()} 行数据")
    
    if 'LAI' in df.columns:
        lai_mask = df['LAI'].notna() & (df['LAI'] > 0)
        if lai_mask.any():
            df.loc[lai_mask, 'LAI_Satellite'] = df.loc[lai_mask, 'LAI']
            print(f"  使用实测LAI替换了 {lai_mask.sum()} 行数据")
    
    # 2. 根据波段和极化组合确定特征映射
    feature_mapping = {}
    
    # Ku波段
    if band == 'Ku':
        if pol == 'H':
            feature_mapping = {
                'ku_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'ku_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'ku_vod_H': 'VODHpol',
                'ku_vod_V': 'VODVpol',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # X波段
    elif band == 'X':
        if pol == 'H':
            feature_mapping = {
                'x_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'x_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'x_vod_H': 'VODHpol',
                'x_vod_V': 'VODVpol',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # C波段
    elif band == 'C':
        if pol == 'H':
            feature_mapping = {
                'c_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'c_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'c_vod_H': 'VODHpol',
                'c_vod_V': 'VODVpol',
                'LAI_Satellite': 'LAI',
                'SM_Satellite': 'SM',
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # 时间序列插值
    if 'Date' in df.columns and not df.empty:
        # 确保按日期排序
        df = df.sort_values('Date')
        
        # 确定需要插值的特征列
        interpolate_cols = list(feature_mapping.keys())
        valid_cols = [col for col in interpolate_cols if col in df.columns]
        
        # 设置时间索引
        date_index = pd.DatetimeIndex(df['Date'])
        df_temp = df.set_index('Date')
        
        # 生成完整的时间序列范围
        full_range = pd.date_range(start=date_index.min(), end=date_index.max(), freq='D')
        df_full = df_temp.reindex(full_range)
        
        # 对特征列进行线性插值
        for col in valid_cols:
            df_full[col] = df_full[col].interpolate(method='time', limit_direction='both')
            print(f"  已完成{col}的时间序列插值")
        
        # 重置索引
        df = df_full.reset_index().rename(columns={'index': 'Date'})
    else:
        print("  无日期列或数据为空，跳过插值")
    
    # 3. 检查是否所有映射后的特征都存在
    missing_features = []
    for data_col in feature_mapping.keys():
        if data_col not in df.columns:
            missing_features.append(data_col)
    
    if missing_features:
        print(f"  缺少特征: {', '.join(missing_features)}")
        return pd.Series(np.nan, index=df.index)
    
    # 4. 准备特征数据
    X = pd.DataFrame()
    for data_col, model_feature in feature_mapping.items():
        X[model_feature] = df[data_col]
    
    # 5. 应用特征归一化
    # VOD特征归一化（除以2）
    vod_features = ['VOD', 'VODHpol', 'VODVpol']
    for vod_feature in vod_features:
        if vod_feature in X.columns:
            X[vod_feature] = X[vod_feature].clip(0, 2) / 2.0
    
    # LAI特征归一化（除以6）
    if 'LAI' in X.columns:
        X['LAI'] = X['LAI'].clip(0, 6) / 6.0
    
    # PFT特征归一化（除以100）
    pft_features = [
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    for pft_feature in pft_features:
        if pft_feature in X.columns:
            X[pft_feature] = X[pft_feature] / 100.0
    
    # 6. 移除缺失值
    initial_count = len(X)
    X = X.dropna()
    removed_count = initial_count - len(X)
    if removed_count > 0:
        print(f"  移除了 {removed_count} 行包含缺失值的数据")
    
    if X.empty:
        print("  无有效数据可用于预测")
        return pd.Series(np.nan, index=df.index)
    
    # 7. 确保特征顺序与模型期望一致
    if hasattr(model, 'feature_names_in_'):
        X = X[expected_features]
    
    # 8. 预测VWC
    try:
        y_pred = model.predict(X)
        
        # 创建完整长度的预测序列
        full_pred = pd.Series(np.nan, index=df.index)
        full_pred.loc[X.index] = y_pred
        
        return full_pred
    except Exception as e:
        print(f"  预测失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)

def create_combined_plots(data_dict_2017, data_dict_2018):
    """创建组合时间序列图并保存预测结果"""
    print("创建组合时间序列图...")
    
    # 创建输出目录
    output_dir = Path("prediction_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # 创建图形 - 增加底部空间用于多行图例
    fig = plt.figure(figsize=(15, 18))
    gs = gridspec.GridSpec(4, 1, figure=fig, height_ratios=[1, 1, 1, 0.4], hspace=0.3)
    
    # 设置全局标题
    fig.suptitle('Vegetation Water Content Time Series', fontsize=20, fontweight='bold', y=0.95)
    
    # 植被类型列表
    vegetation_types = [
        ('CornVegMeasured', data_dict_2017),  # 玉米
        ('OatVegMeasured', data_dict_2017),   # 燕麦
        ('GrassVWC', data_dict_2018)           # 草
    ]
    
    # 实测VWC列名映射
    ACTUAL_COL_MAPPING = {
        'CornVegMeasured': 'total_VWC(kg/m2)',
        'OatVegMeasured': 'total_VWC(kg/m2)',
        'GrassVWC': 'vegetation water content(kg/m2)'
    }
    
    # 存储所有评估指标
    all_metrics = {}
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 根据图片定义极化标记样式
    POL_MARKERS = {
        'H': '+',   # 加号
        'V': '^',   # 三角形
        'HV': 's'   # 正方形
    }
    
    # 波段颜色
    BAND_COLORS = {
        'Ku': 'blue',
        'X': 'green',
        'C': 'red'
    }
    
    # 极化线型
    POL_LINESTYLES = {
        'H': '-',
        'V': '--',
        'HV': '-.'
    }
    
    # 极化名称映射 - 修改1：使用新的标题格式
    POL_NAMES = {
        'H': 'H-Pol',
        'V': 'V-Pol',
        'HV': 'H&V-Pol'
    }
    
    # 波段名称映射 - 修改2：使用新的标题格式
    BAND_NAMES = {
        'Ku': 'Ku-Band',
        'X': 'X-Band',
        'C': 'C-Band'
    }
    
    # 植被类型显示名称
    VEGETATION_TYPES = {
        'CornVegMeasured': 'Corn',
        'OatVegMeasured': 'Oat',
        'GrassVWC': 'Grass'
    }
    
    # 遍历所有植被类型
    for idx, (veg_type, data_dict) in enumerate(vegetation_types):
        ax = fig.add_subplot(gs[idx])
        
        # 获取当前植被类型的实测列名
        actual_col = ACTUAL_COL_MAPPING[veg_type]
        
        # 初始化Y轴范围
        y_min = float('inf')
        y_max = float('-inf')
        
        # 获取实测数据
        if veg_type in data_dict:
            df_measured = data_dict[veg_type].copy()
            
            # 确保日期列存在
            if 'Date' not in df_measured.columns:
                print(f"警告: {veg_type} 中没有 'Date' 列")
                continue
            
            # 按日期排序
            df_measured = df_measured.sort_values('Date')
            
            # 更新Y轴范围（实测值）
            if actual_col in df_measured.columns:
                measured_values = df_measured[actual_col].dropna()
                if not measured_values.empty:
                    y_min = min(y_min, measured_values.min())
                    y_max = max(y_max, measured_values.max())
            
            # 获取拟合数据用于预测
            fitting_sheet = FITTING_MAPPING.get(veg_type, veg_type)
            if fitting_sheet in data_dict:
                df_fitting = data_dict[fitting_sheet].copy()
                
                # 确保日期列存在
                if 'Date' not in df_fitting.columns:
                    print(f"警告: {fitting_sheet} 中没有 'Date' 列")
                    continue
                
                # 按日期排序
                df_fitting = df_fitting.sort_values('Date')
            else:
                # 2018年没有单独的拟合数据
                df_fitting = df_measured.copy()
            
            # 存储评估指标
            metrics = []
            
            # 为每个波段和极化组合预测VWC
            for band in BAND_COLORS.keys():
                for pol in POL_LINESTYLES.keys():
                    # 生成列名
                    col_name = f"Predicted_VWC_{band}_{pol}"
                    
                    # 如果列不存在，使用模型预测
                    if col_name not in df_fitting.columns:
                        print(f"为 {fitting_sheet} 预测 {band}-{pol} VWC...")
                        df_fitting[col_name] = predict_vwc_for_sheet(df_fitting, band, pol)
                    
                    # 只在有有效预测值的点进行绘制和评估
                    if col_name in df_fitting.columns:
                        # 更新Y轴范围（预测值）
                        pred_values = df_fitting[col_name].dropna()
                        if not pred_values.empty:
                            y_min = min(y_min, pred_values.min())
                            y_max = max(y_max, pred_values.max())
                        
                        # 获取有效预测数据点
                        valid_mask = df_fitting[col_name].notna()
                        valid_dates = df_fitting['Date'][valid_mask]
                        valid_values = df_fitting[col_name][valid_mask]
                        
                        # 如果数据点足够多，使用样条插值生成平滑曲线
                        if len(valid_dates) > 3:
                            try:
                                # 将日期转换为数值（从最小日期开始的天数）
                                date_numeric = (valid_dates - valid_dates.min()).dt.days
                                
                                # 创建样条插值对象
                                spline = make_interp_spline(date_numeric, valid_values, k=3)
                                
                                # 生成更密集的时间点
                                dense_dates = np.linspace(date_numeric.min(), date_numeric.max(), 300)
                                dense_values = spline(dense_dates)
                                
                                # 将数值日期转换回实际日期
                                dense_dates = valid_dates.min() + pd.to_timedelta(dense_dates, unit='D')
                                
                                # 绘制平滑曲线
                                ax.plot(dense_dates, dense_values,
                                        color=BAND_COLORS[band],
                                        linestyle=POL_LINESTYLES[pol],
                                        linewidth=1.5)
                            except Exception as e:
                                print(f"样条插值失败: {str(e)}")
                                # 如果插值失败，使用原始数据点绘制折线
                                ax.plot(valid_dates, valid_values,
                                        color=BAND_COLORS[band],
                                        linestyle=POL_LINESTYLES[pol],
                                        linewidth=1.5)
                        else:
                            # 数据点太少，直接绘制折线
                            ax.plot(valid_dates, valid_values,
                                    color=BAND_COLORS[band],
                                    linestyle=POL_LINESTYLES[pol],
                                    linewidth=1.5)
                        
                        # 找出同时有实测值和预测值的点
                        common_data = pd.merge(
                            df_measured[['Date', actual_col]], 
                            df_fitting[['Date', col_name]], 
                            on='Date', 
                            how='inner'
                        ).dropna(subset=[actual_col, col_name])
                        
                        if not common_data.empty:
                            # 更新Y轴范围（共同数据）
                            common_min = min(common_data[actual_col].min(), common_data[col_name].min())
                            common_max = max(common_data[actual_col].max(), common_data[col_name].max())
                            y_min = min(y_min, common_min)
                            y_max = max(y_max, common_max)
                            
                            # 在实测日期位置绘制实测值点（空心圆）
                            ax.plot(common_data['Date'], common_data[actual_col],
                                    linestyle='',  # 无线条
                                    color='black',
                                    marker='o',
                                    markersize=8,
                                    markerfacecolor='none',  # 透明填充（空心）
                                    markeredgewidth=1.5)
                            
                            # 在实测日期位置绘制预测点（空心标记）
                            ax.plot(common_data['Date'], common_data[col_name],
                                    linestyle='',  # 无线条
                                    color=BAND_COLORS[band],
                                    marker=POL_MARKERS[pol],
                                    markersize=10,
                                    markerfacecolor='none',  # 透明填充（空心）
                                    markeredgewidth=1.5)
                            
                            # 计算评估指标
                            rmse = np.sqrt(mean_squared_error(common_data[actual_col], common_data[col_name]))
                            r2 = r2_score(common_data[actual_col], common_data[col_name])
                            
                            # 添加到指标列表
                            metrics.append({
                                'band': band,
                                'pol': pol,
                                'rmse': rmse,
                                'r2': r2
                            })
                            
                            # 保存预测结果
                            model_key = f"{veg_type}_{band}_{pol}"
                            all_predictions[model_key] = {
                                'dates': common_data['Date'].tolist(),
                                'measured': common_data[actual_col].tolist(),
                                'predicted': common_data[col_name].tolist(),
                                'rmse': rmse,
                                'r2': r2
                            }
            
            # 设置子图标题 - 修改3：使用新的标题格式
            ax.set_title(VEGETATION_TYPES.get(veg_type, veg_type), 
                         fontsize=16, fontweight='bold')
            
            # 设置坐标轴标签
            if idx == 2:  # 最后一行
                ax.set_xlabel('Date', fontsize=12, fontweight='bold')
            ax.set_ylabel('VWC (kg/m²)', fontsize=12, fontweight='bold')
            
            # 设置X轴格式
            ax.xaxis.set_major_locator(mdates.DayLocator(interval=10))
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
            
            # 动态设置Y轴范围
            if y_min != float('inf') and y_max != float('-inf'):
                # 添加10%的边距
                y_range = y_max - y_min
                padding = y_range * 0.1
                
                # 确保最小值不小于0
                y_min = max(0, y_min - padding)
                y_max = y_max + padding
                
                ax.set_ylim(y_min, y_max)
            else:
                # 默认范围
                ax.set_ylim(0, 10)
            
            # 存储指标
            all_metrics[veg_type] = metrics
    
    # ==================================
    # 创建图例（精确匹配要求）
    # ==================================
    
    # 创建图例区域的轴
    ax_legend = fig.add_subplot(gs[3])
    ax_legend.axis('off')  # 隐藏坐标轴
    
    # 定义图例行内容（标题+项目） - 修改4：使用新的标题格式
    legend_rows = [
        # 第一行：Ku波段
        [f"{BAND_NAMES['Ku']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],
        
        # 第二行：X波段
        [f"{BAND_NAMES['X']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],
        
        # 第三行：C波段
        [f"{BAND_NAMES['C']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],

        # 第四行：实测点
        [f"Insitu VWC"]
    ]
    
    # 创建代理艺术家
    proxies = {}
    
    # Insitu VWC代理（空心圆）
    proxies['insitu'] = plt.Line2D([], [], 
                     linestyle='', 
                     marker='o',
                     markersize=10,
                     markerfacecolor='none',
                     markeredgecolor='black',
                     markeredgewidth=1.5,
                     label='Insitu VWC')
    
    # 波段-极化组合代理 - 修改5：使用新的标题格式
    for band in ['Ku', 'X', 'C']:
        color = BAND_COLORS[band]
        for pol in ['H', 'V', 'HV']:
            proxies[f"{band}-{pol}"] = plt.Line2D([], [],
                color=color,
                linestyle=POL_LINESTYLES[pol],
                linewidth=2,
                marker=POL_MARKERS[pol],
                markersize=10,
                markerfacecolor='none',
                markeredgecolor=color,
                markeredgewidth=1.5,
                label=f"{BAND_NAMES[band]},{POL_NAMES[pol]}")
    
    # 为每行创建图例
    y_positions = [0.85, 0.60, 0.35, 0.10]  # 三行垂直位置
    
    for row_idx, row_items in enumerate(legend_rows):
        handles = []
        labels = []
        
        for item in row_items:
            # 处理Insitu项
            if item == "Insitu VWC":
                handles.append(proxies['insitu'])
                labels.append(item)
            # 处理波段-极化项
            else:
                # 解析新的标签格式
                band_part, pol_part = item.split(',')
                band = band_part.split('-')[0]  # 提取波段名称
                
                handles.append(proxies[f"{band}-{pol}"])
                labels.append(item)  # 使用完整的标签文本
        
        # 计算当前行文本宽度（均匀分布）
        n_items = len(handles)
        x_positions = np.linspace(0.05, 0.95, n_items)
        
        # 绘制当前行的图例项
        for i, (handle, label) in enumerate(zip(handles, labels)):
            ax_legend.plot([], [])  # 空白绘图以创建图例项
            
            # 创建图例句柄
            leg = ax_legend.legend([handle], [label], 
                                  loc='lower center',
                                  bbox_to_anchor=(x_positions[i], y_positions[row_idx]),
                                  frameon=False,
                                  handlelength=2,
                                  fontsize=10,
                                  handletextpad=0.8)
            
            # 添加到轴（否则会被覆盖）
            ax_legend.add_artist(leg)
    
    # 在子图中显示评估指标
    for idx, (veg_type, metrics) in enumerate(all_metrics.items()):
        if idx < 3:  # 确保索引有效（排除图例轴）
            ax = fig.axes[idx]
            
            # 创建指标文本
            if metrics:
                # 使用多列格式显示所有指标
                metric_text = "Evaluation Metrics:\n"
                
                # 按波段分组指标
                band_metrics = {}
                for metric in metrics:
                    band = metric['band']
                    if band not in band_metrics:
                        band_metrics[band] = []
                    band_metrics[band].append(metric)
                
                # 为每个波段创建一行文本
                for band in ['Ku', 'X', 'C']:
                    if band in band_metrics:
                        band_text = f"{BAND_NAMES[band]}: "
                        pol_texts = []
                        for metric in band_metrics[band]:
                            pol_texts.append(f"{POL_NAMES[metric['pol']]}(RMSE={metric['rmse']:.3f})")
                        band_text += ", ".join(pol_texts)
                        metric_text += band_text + "\n"
                
                # 添加文本框
                ax.text(0.02, 0.95, metric_text, 
                        transform=ax.transAxes,
                        fontsize=9,
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # 调整布局
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    
    # 保存图像
    figures_dir = Path("figures")
    figures_dir.mkdir(parents=True, exist_ok=True)
    fig_path = figures_dir / "Combined_VWC_Time_Series_Purify.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"组合时间序列图已保存至: {fig_path}")
    plt.close()
    
    # 保存所有预测结果到CSV文件
    for model_key, data in all_predictions.items():
        veg_type, band, pol = model_key.split('_')
        df = pd.DataFrame({
            'Date': data['dates'],
            'Measured': data['measured'],
            'Predicted': data['predicted']
        })
        csv_path = output_dir / f"{veg_type}_{band}_{pol}_predictions_Purify.csv"
        df.to_csv(csv_path, index=False)
        print(f"保存预测结果至: {csv_path}")
    
    # 保存评估指标
    metrics_path = output_dir / "model_metrics_Purify.csv"
    metrics_data = []
    for veg_type, metrics in all_metrics.items():
        for metric in metrics:
            metrics_data.append({
                'Vegetation': veg_type,
                'Band': metric['band'],
                'Polarization': metric['pol'],
                'RMSE': metric['rmse'],
                'R2': metric['r2']
            })
    
    metrics_df = pd.DataFrame(metrics_data)
    metrics_df.to_csv(metrics_path, index=False)
    print(f"保存模型评估指标至: {metrics_path}")

def main():
    # 2017年数据文件
    file_2017 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx"
    data_2017 = load_data(file_2017)
    
    # 2018年数据文件
    file_2018 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx"
    data_2018 = load_data(file_2018)
    
    # 创建组合时间序列图
    create_combined_plots(data_2017, data_2018)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx
  - CornVegMeasured: 8行
  - CornVegFitting: 64行
  - OatVegMeasured: 7行
  - OatVegFitting: 64行
加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx
  - GrassVWC: 13行
创建组合时间序列图...
为 CornVegFitting 预测 Ku-H VWC...
加载模型: models/RFR_Ku_Hpol_Purify_Type1.pkl
  模型期望特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  使用实测LAI替换了 63 行数据
  已完成ku_vod_H的时间序列插值
  已完成LAI_Satellite的时间序列插值
  已完成SM_Satellite的时间序列插值
  已完成grassman的时间序列插值
  已完成grassnat的时间序列插值
  已完成shrubbd的时间序列插值
  已完成shrubbe的时间序列插值
  已完成shrubnd的时间序列插值
  已完成shrubne的时间序列插值
  已完成treebd的时间序列插值
  已完成treebe的时间序列插值
  已完成treend的时间序列插值
  已完成treene的时间序列插值
为 CornVegFitting 预测 Ku-V VWC...
加载模型: models/RFR_Ku_Vpol_Purify_Type1.pkl
  模型期望特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'T

In [3]:
# 3*3 散点图结果
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from pathlib import Path
import warnings
from sklearn.metrics import mean_squared_error
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义
BANDS = ['Ku', 'X', 'C']
BAND_COLORS = {
    'Ku': '#1f77b4',  # 蓝色
    'X': '#ff7f0e',   # 橙色
    'C': '#2ca02c'    # 绿色
}
POLS = ['H', 'V', 'HV']

# 植被类型标记样式 - 玉米标记改为空心方形（'s'）
VEG_MARKERS = {
    'CornVegMeasured': {'marker': 's', 'size': 80, 'label': 'Corn (2017)'},  # 改为方形
    'OatVegMeasured': {'marker': '^', 'size': 80, 'label': 'Oat (2017)'},
    'GrassVWC': {'marker': 'o', 'size': 80, 'label': 'Grass (2018)'}
}

def load_prediction_data(prediction_dir):
    """从CSV文件加载预测结果"""
    print(f"加载预测结果: {prediction_dir}")
    all_predictions = {}
    
    # 遍历所有CSV文件
    for csv_file in prediction_dir.glob("*_predictions_Purify.csv"):
        # 解析文件名获取模型信息
        filename = csv_file.stem
        parts = filename.split('_')
        
        if len(parts) >= 4:  # 格式: {植被类型}_{波段}_{极化}_predictions
            veg_type = parts[0]
            band = parts[1]
            pol = parts[2]
            model_key = f"{band}_{pol}"
            
            # 加载数据
            df = pd.read_csv(csv_file)
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            # 存储数据
            if model_key not in all_predictions:
                all_predictions[model_key] = {}
            
            all_predictions[model_key][veg_type] = df
    
    return all_predictions

def get_model_title(band, pol):
    """根据波段和极化返回自定义标题"""
    band_names = {
        'Ku': 'Ku-Band',
        'X': 'X-Band',
        'C': 'C-Band'
    }
    pol_names = {
        'H': 'H-Pol',
        'V': 'V-Pol',
        'HV': 'H&V-Pol'  # 修改这里
    }
    return f"{band_names.get(band, band)},{pol_names.get(pol, pol)}"

def create_scatter_plots_from_predictions(prediction_dir):
    """从预测结果文件创建9个模型的真值与预测值散点图（3x3网格）"""
    # 加载预测结果
    all_predictions = load_prediction_data(prediction_dir)
    
    if not all_predictions:
        print("警告: 没有找到预测结果文件")
        return
    
    # 创建3x3网格图
    fig = plt.figure(figsize=(15, 15))
    fig.suptitle('', fontsize=20, y=0.95)
    gs = gridspec.GridSpec(3, 3, wspace=0.3, hspace=0.3)
    
    # 收集所有散点的最小和最大值（用于统一坐标轴）
    all_actual_min, all_actual_max = np.inf, -np.inf
    all_pred_min, all_pred_max = np.inf, -np.inf
    
    # 收集所有评估指标
    all_metrics = {}

    # 处理每个模型（波段和极化组合）
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            model_key = f"{band}_{pol}"
            ax = plt.subplot(gs[i, j])
            print(f"处理模型: {model_key}")
            
            # 检查该模型是否有预测数据
            if model_key not in all_predictions:
                print(f"警告: {model_key} 模型没有预测数据")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 收集该模型的所有植被类型的数据
            all_actual = []
            all_predicted = []
            all_veg_types = []
            
            # 存储各植被类型的数据点
            veg_data = {
                'CornVegMeasured': {'actual': [], 'predicted': []},
                'OatVegMeasured': {'actual': [], 'predicted': []},
                'GrassVWC': {'actual': [], 'predicted': []}
            }
            
            # 处理玉米数据
            veg_type = 'CornVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 玉米数据点: {len(df)}")
            
            # 处理燕麦数据
            veg_type = 'OatVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 燕麦数据点: {len(df)}")
            
            # 处理草数据
            veg_type = 'GrassVWC'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 草数据点: {len(df)}")
            
            # 如果没有数据点，跳过
            if len(all_actual) == 0:
                print(f"警告: {model_key} 模型没有有效数据点")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 转换为numpy数组
            all_actual = np.array(all_actual)
            all_predicted = np.array(all_predicted)
            
            # 更新全局最小/最大值
            all_actual_min = min(all_actual_min, np.min(all_actual))
            all_actual_max = max(all_actual_max, np.max(all_actual))
            all_pred_min = min(all_pred_min, np.min(all_predicted))
            all_pred_max = max(all_pred_max, np.max(all_predicted))
            
            # 计算各植被类型的RMSE
            rmse_corn = None
            rmse_oat = None
            rmse_grass = None
            
            if veg_data['CornVegMeasured']['actual']:
                actual_corn = np.array(veg_data['CornVegMeasured']['actual'])
                predicted_corn = np.array(veg_data['CornVegMeasured']['predicted'])
                rmse_corn = np.sqrt(mean_squared_error(actual_corn, predicted_corn))
            
            if veg_data['OatVegMeasured']['actual']:
                actual_oat = np.array(veg_data['OatVegMeasured']['actual'])
                predicted_oat = np.array(veg_data['OatVegMeasured']['predicted'])
                rmse_oat = np.sqrt(mean_squared_error(actual_oat, predicted_oat))
            
            if veg_data['GrassVWC']['actual']:
                actual_grass = np.array(veg_data['GrassVWC']['actual'])
                predicted_grass = np.array(veg_data['GrassVWC']['predicted'])
                rmse_grass = np.sqrt(mean_squared_error(actual_grass, predicted_grass))
            
            # 计算整体RMSE
            rmse_total = np.sqrt(mean_squared_error(all_actual, all_predicted))
            
            # 存储评估指标
            all_metrics[model_key] = {
                'RMSE_Corn': rmse_corn,
                'RMSE_Oat': rmse_oat,
                'RMSE_Grass': rmse_grass,
                'RMSE_Total': rmse_total
            }
            
            # 绘制散点图 - 按植被类型区分标记
            # 先绘制草和燕麦，最后绘制玉米（确保玉米在最上层）
            for veg_type in ['GrassVWC', 'OatVegMeasured', 'CornVegMeasured']:
                if veg_data[veg_type]['actual']:
                    actual_values = np.array(veg_data[veg_type]['actual'])
                    predicted_values = np.array(veg_data[veg_type]['predicted'])
                    
                    marker_style = VEG_MARKERS[veg_type]
                    
                    # 为玉米标记使用更大的尺寸和线宽
                    if veg_type == 'CornVegMeasured':
                        size = 100  # 增加大小
                        edgewidth = 1.5  # 更粗的线宽
                        alpha = 0.9  # 更高的不透明度
                    else:
                        size = marker_style['size']
                        edgewidth = 1.0
                        alpha = 0.8
                    
                    # 所有标记使用相同的波段颜色
                    ax.scatter(actual_values, predicted_values, 
                              marker=marker_style['marker'], 
                              s=size,
                              alpha=alpha,  # 调整透明度
                              facecolor='none', 
                              edgecolor=BAND_COLORS[band],  # 使用波段颜色
                              linewidths=edgewidth,
                              label=marker_style['label'])
            
            # 添加1:1参考线
            ax.plot([0, 4], [0, 4], 'k--', linewidth=1, label='1:1 Line')
            
            # 设置标题和坐标轴标签
            ax.set_title(get_model_title(band, pol), fontsize=14)
            if j == 0:  # 第一列添加y轴标签
                ax.set_ylabel('RF VWC (kg/m²)', fontsize=12)
            if i == 2:  # 最后一行添加x轴标签
                ax.set_xlabel('In Situ VWC (kg/m²)', fontsize=12)
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
            
            # 显示评估指标
            metric_text = ""
            if rmse_corn is not None:
                metric_text += f"Corn RMSE = {rmse_corn:.3f}\n"
            if rmse_oat is not None:
                metric_text += f"Oat RMSE = {rmse_oat:.3f}\n"
            if rmse_grass is not None:
                metric_text += f"Grass RMSE = {rmse_grass:.3f}\n"
            metric_text += f"Total RMSE = {rmse_total:.3f}"
            
            ax.text(0.05, 0.95, metric_text, transform=ax.transAxes, 
                   fontsize=9, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 设置所有子图的坐标轴范围一致
    max_val = 4
    min_val = 0
    for ax in fig.get_axes():
        ax.set_xlim(min_val, max_val)
        ax.set_ylim(min_val, max_val)
    
    # 添加图例
    # 创建代理艺术家用于图例
    handles = []
    labels = []
    
    # 添加植被类型标记
    for veg_type, style in VEG_MARKERS.items():
        # 为玉米标记使用特殊大小
        if veg_type == 'CornVegMeasured':
            markersize = 10  # 图例中保持相同大小
        else:
            markersize = 8
            
        handles.append(
            plt.Line2D([], [], marker=style['marker'], linestyle='None', 
                       markersize=markersize, alpha=0.7, markerfacecolor='none', 
                       markeredgecolor='gray', label=style['label'])
        )
    
    # 添加1:1参考线
    handles.append(
        plt.Line2D([], [], color='k', linestyle='--', linewidth=1, label='1:1 Line')
    )
    
    plt.tight_layout(rect=[0, 0.01, 1, 0.95])  # 调整底部空间
 
    # 添加图例到整个图形
    fig.legend(handles=handles, loc='lower center', 
               bbox_to_anchor=(0.5, 0.05), ncol=4, fontsize=10, 
               title="")
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = output_dir / "Scatter_Predictions_From_Saved_Data_Purify.png"
    plt.savefig(fig_path, dpi=1000, bbox_inches='tight', pad_inches=0.1)
    print(f"散点图已保存至: {fig_path}")
    plt.close()
    
    # 打印所有模型的评估指标
    print("\n模型评估指标:")
    for model_name, metrics in all_metrics.items():
        print(f"{model_name}:")
        if metrics['RMSE_Corn'] is not None:
            print(f"  Corn RMSE = {metrics['RMSE_Corn']:.4f}")
        if metrics['RMSE_Oat'] is not None:
            print(f"  Oat RMSE = {metrics['RMSE_Oat']:.4f}")
        if metrics['RMSE_Grass'] is not None:
            print(f"  Grass RMSE = {metrics['RMSE_Grass']:.4f}")
        print(f"  Total RMSE = {metrics['RMSE_Total']:.4f}")

def main():
    # 设置预测结果目录
    prediction_dir = Path("prediction_results")
    
    # 创建散点图
    create_scatter_plots_from_predictions(prediction_dir)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载预测结果: prediction_results
处理模型: Ku_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: Ku_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: Ku_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
散点图已保存至: figures\Scatter_Predictions_From_Saved_Data_Purify.png

模型评估指标:
Ku_H:
  Corn RMSE = 2.9664
  Oat RMSE = 2.8482
  Grass RMSE = 1.4324
  Total RMSE = 2.3441
Ku_V:
  Corn RMSE = 2.7527
  Oat RMSE = 2.0731
  Grass RMSE = 1.4749
  Total RMSE = 2.0614
Ku_HV:
  Corn RMSE = 2.8971
  Oat RMSE = 2.4244
  Grass RMSE = 1.0185
  Total RMSE = 2.0855
X_H:
  Corn RMSE = 2.6867
  Oat RMSE = 1.8765
  Grass RMSE = 0.9755
  Total RMSE = 1.8397
X_V:
  Corn RMSE = 1.4954
  Oat RMSE = 1.2271
  Grass RMSE = 1.2577
  Total RMSE = 1.3228

2002-2016

In [4]:
# 散点图（4个数据画在一块，写出n，按照波段-极化组合绘制为3*3）
# 点形状及颜色：
# SMEX02：*；CLASIC07：^；SMEX08：+；SMAPVEX16：o

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import matplotlib.font_manager as fm
import joblib
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import mean_squared_error

# 常量定义
BANDS = ['Ku', 'X', 'C']
POLS = ['H', 'V', 'HV']
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']
VWC_COLUMNS = {
    'SMEX02': 'VWC-Field',
    'CLASIC07': 'VWC (kg/m²)',
    'SMAPVEX08': 'VWC',
    'SMAPVEX16': 'PLANT_WATER_CONTENT_AREA'
}

# 标记和颜色设置
MARKER_STYLES = {
    'SMEX02': {'marker': 'x', 'color': '#F8766D'},
    'CLASIC07': {'marker': '^', 'facecolor': 'none', 'edgecolor': '#00BFC4'},
    'SMAPVEX08': {'marker': '+', 'color': '#C77CFF'},
    'SMAPVEX16': {'marker': 'o', 'facecolor': 'none', 'edgecolor': '#7CAE00'}
}

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

def load_and_preprocess_data(file_path):
    """
    加载并预处理Excel文件中的所有sheet
    
    参数:
    file_path (str): Excel文件路径
    
    返回:
    dict: 包含预处理后数据的字典，键为sheet名称
    """
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    for sheet in SHEET_NAMES:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 替换SM_Satellite和LAI_Satellite（如果存在地面实测数据）
            if 'SM' in df.columns:
                mask = df['SM'].notna()
                df.loc[mask, 'SM_Satellite'] = df.loc[mask, 'SM']
                print(f"    替换了 {mask.sum()} 行SM_Satellite数据")
            
            if 'LAI' in df.columns:
                mask = df['LAI'].notna()
                df.loc[mask, 'LAI_Satellite'] = df.loc[mask, 'LAI']
                print(f"    替换了 {mask.sum()} 行LAI_Satellite数据")
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def get_features_for_model(band, pol):
    """
    根据波段和极化类型获取特征列表（使用模型训练时的名称）
    
    参数:
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    list: 特征列名列表
    """
    # 使用模型训练时的特征名称
    features = [
        'LAI',  # 注意：训练时使用"LAI"而不是"LAI_Satellite"
        'SM',   # 注意：训练时使用"SM"而不是"SM_Satellite"
        'Grass_man', 
        'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    # 添加VOD特征 - 根据模型类型
    if pol == 'H' or pol == 'V':
        # 单极化模型使用"VOD"
        features.append('VOD')
    elif pol == 'HV':
        # 双极化模型使用"VOD-Hpol"和"VOD-Vpol"
        features.extend(['VOD-Hpol', 'VOD-Vpol'])
    
    return features

def predict_vwc(data_dict, band, pol):
    """
    使用指定模型预测VWC，包括特征归一化
    
    参数:
    data_dict (dict): 包含所有sheet数据的字典
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    dict: 包含每个sheet预测结果的字典
    """
    # 加载模型
    model_path = f"models/RFR_{band}_{pol}pol_Type1.pkl"
    print(f"加载模型: {model_path}")
    
    if not os.path.exists(model_path):
        print(f"  模型文件不存在: {model_path}")
        return {}
    
    try:
        model = joblib.load(model_path)
        # 打印模型训练时的特征名称（如果可用）
        if hasattr(model, 'feature_names_in_'):
            print(f"  模型训练特征: {list(model.feature_names_in_)}")
    except Exception as e:
        print(f"  加载模型失败: {str(e)}")
        return {}
    
    # 获取特征列表
    features = get_features_for_model(band, pol)
    
    # 存储预测结果
    predictions = {}
    
    for sheet, df in data_dict.items():
        if df.empty:
            continue
        
        # 创建特征映射（将数据列名映射到模型期望的特征名）
        feature_mapping = {}
        for feature in features:
            # 特殊处理VOD特征
            if feature == 'VOD':
                # 单极化模型
                if pol == 'H':
                    if band == 'Ku':
                        feature_mapping['ku_vod_H'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_H'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_H'] = 'VOD'
                elif pol == 'V':
                    if band == 'Ku':
                        feature_mapping['ku_vod_V'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_V'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_V'] = 'VOD'
            elif feature == 'VOD-Hpol':
                # 双极化模型中的H极化
                if band == 'Ku':
                    feature_mapping['ku_vod_H'] = 'VOD-Hpol'
                elif band == 'X':
                    feature_mapping['x_vod_H'] = 'VOD-Hpol'
                elif band == 'C':
                    feature_mapping['c_vod_H'] = 'VOD-Hpol'
            elif feature == 'VOD-Vpol':
                # 双极化模型中的V极化
                if band == 'Ku':
                    feature_mapping['ku_vod_V'] = 'VOD-Vpol'
                elif band == 'X':
                    feature_mapping['x_vod_V'] = 'VOD-Vpol'
                elif band == 'C':
                    feature_mapping['c_vod_V'] = 'VOD-Vpol'
            else:
                # 其他特征映射
                if feature == 'LAI':
                    feature_mapping['LAI_Satellite'] = 'LAI'
                elif feature == 'SM':
                    feature_mapping['SM_Satellite'] = 'SM'
                elif feature == 'Grass_man':
                    feature_mapping['grassman'] = 'Grass_man'
                elif feature == 'Grass_nat':
                    feature_mapping['grassnat'] = 'Grass_nat'
                elif feature == 'Shrub_bd':
                    feature_mapping['shrubbd'] = 'Shrub_bd'
                elif feature == 'Shrub_be':
                    feature_mapping['shrubbe'] = 'Shrub_be'
                elif feature == 'Shrub_nd':
                    feature_mapping['shrubnd'] = 'Shrub_nd'
                elif feature == 'Shrub_ne':
                    feature_mapping['shrubne'] = 'Shrub_ne'
                elif feature == 'Tree_bd':
                    feature_mapping['treebd'] = 'Tree_bd'
                elif feature == 'Tree_be':
                    feature_mapping['treebe'] = 'Tree_be'
                elif feature == 'Tree_nd':
                    feature_mapping['treend'] = 'Tree_nd'
                elif feature == 'Tree_ne':
                    feature_mapping['treene'] = 'Tree_ne'
        
        # 检查是否包含所有必要特征
        missing_features = []
        for data_feature in feature_mapping.keys():
            if data_feature not in df.columns:
                missing_features.append(data_feature)
        
        if missing_features:
            print(f"  {sheet} 缺少特征: {', '.join(missing_features)}")
            continue
        
        # 准备数据（使用重命名的特征）
        X = df[list(feature_mapping.keys())].copy()
        X.columns = [feature_mapping[col] for col in X.columns]
        
        # 确保特征顺序与模型期望一致
        if hasattr(model, 'feature_names_in_'):
            X = X[list(model.feature_names_in_)]
        
        # ========== 添加归一化处理 ==========
        print(f"  {sheet} 应用归一化处理...")
        
        # 1. VOD特征归一化（除以2）
        vod_features = ['VOD', 'VOD-Hpol', 'VOD-Vpol']
        for vod_feature in vod_features:
            if vod_feature in X.columns:
                X[vod_feature] = X[vod_feature].clip(0, 2) / 2.0
                print(f"    归一化 {vod_feature}: 除以2")
        
        # 2. LAI特征归一化（除以6）
        if 'LAI' in X.columns:
            X['LAI'] = X['LAI'].clip(0, 6) / 6.0
            print(f"    归一化 LAI: 除以6")
        
        # 3. PFT特征归一化（除以100）
        pft_features = [
            'Grass_man', 'Grass_nat',
            'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
            'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
        ]
        
        for pft_feature in pft_features:
            if pft_feature in X.columns:
                X[pft_feature] = X[pft_feature] / 100.0
                print(f"    归一化 {pft_feature}: 除以100")
        # =================================
        
        # 移除缺失值
        initial_count = len(X)
        X = X.dropna()
        removed_count = initial_count - len(X)
        if removed_count > 0:
            print(f"  {sheet} 移除了 {removed_count} 行包含缺失值的数据")
        
        if X.empty:
            print(f"  {sheet} 无有效数据可用于预测")
            continue
        
        # 预测VWC
        y_pred = model.predict(X)
        predictions[sheet] = {
            'actual': df.loc[X.index, VWC_COLUMNS[sheet]],
            'predicted': y_pred,
            'source': sheet,
            'row': df.loc[X.index, 'row'],
            'col': df.loc[X.index, 'col'],
            'lat': df.loc[X.index, 'Latitude'],
            'lon': df.loc[X.index, 'Longitude'],
            'date': df.loc[X.index, 'Date']
        }
        print(f"  {sheet} 预测完成: {len(y_pred)} 个样本")
    
    return predictions

def calculate_rmse(actual, predicted):
    """
    计算RMSE
    
    参数:
    actual (array-like): 实际值
    predicted (array-like): 预测值
    
    返回:
    float: RMSE值
    """
    return np.sqrt(np.mean((actual - predicted)**2))

def create_scatter_plots(all_predictions):
    """
    创建3x3散点子图
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    print("创建散点图...")
    
    # 创建图形
    fig = plt.figure(figsize=(18, 18))
    gs = gridspec.GridSpec(3, 3, figure=fig)
    
    # 设置全局标题
    fig.suptitle('', fontsize=24, fontweight='bold', y=0.95)
    
    # 遍历所有波段和极化组合
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            ax = fig.add_subplot(gs[i, j])
            
            # 获取当前组合的预测结果
            predictions = all_predictions.get((band, pol), {})
            
            # 收集所有数据点
            all_actual = []
            all_predicted = []
            
            # 绘制每个sheet的数据点
            for sheet in SHEET_NAMES:
                if sheet in predictions:
                    actual = predictions[sheet]['actual']
                    predicted = predictions[sheet]['predicted']
                    
                    # 添加到总集合
                    all_actual.extend(actual)
                    all_predicted.extend(predicted)
                    
                    # 绘制当前sheet的点
                    if sheet in ['CLASIC07', 'SMAPVEX16']:
                        # 对CLASIC07、SMAPVEX16特殊处理：空心
                        ax.scatter(
                            actual, predicted,
                            marker=MARKER_STYLES[sheet]['marker'],
                            facecolor=MARKER_STYLES[sheet]['facecolor'],  # 内部无填充
                            edgecolor=MARKER_STYLES[sheet]['edgecolor'],  # 使用边缘颜色
                            s=50,
                            alpha=0.7,
                            linewidths=1.0,  # 确保边框可见
                            label=sheet
                        )
                    else:
                        # 其他数据集保持原样
                        ax.scatter(
                            actual, predicted,
                            marker=MARKER_STYLES[sheet]['marker'],
                            color=MARKER_STYLES[sheet].get('color', MARKER_STYLES[sheet].get('edgecolor', None)),
                            s=50,
                            alpha=0.7,
                            label=sheet
                        )
            
            # 如果没有数据，跳过
            if not all_actual:
                ax.text(0.5, 0.5, '无数据', 
                        horizontalalignment='center', 
                        verticalalignment='center', 
                        transform=ax.transAxes,
                        fontsize=16)
                ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
                continue
            
            # 计算整体RMSE
            rmse = calculate_rmse(np.array(all_actual), np.array(all_predicted))
            
            # 添加1:1参考线
            max_val = max(max(all_actual), max(all_predicted)) * 1.05
            ax.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7)
            
            # 设置坐标轴范围
            ax.set_xlim(0, max_val)
            ax.set_ylim(0, max_val)
            
            # 设置坐标轴标签
            if i == 2:  # 最后一行
                ax.set_xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
            if j == 0:  # 第一列
                ax.set_ylabel('Predicted VWC (kg/m²)', fontsize=14, fontweight='bold')
            
            # 添加标题和RMSE
            ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
            ax.text(0.05, 0.95, f"RMSE: {rmse:.3f} kg/m²", 
                    transform=ax.transAxes,
                    fontsize=16,
                    fontweight='bold',
                    verticalalignment='top')
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
    
    # 添加图例
    handles, labels = [], []
    for sheet in SHEET_NAMES:
        style = MARKER_STYLES[sheet]
        
        if sheet in ['CLASIC07', 'SMAPVEX16']:
            # 为CLASIC07、SMAPVEX16创建空心图例
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w',
                                     markerfacecolor=style['facecolor'],  # 内部白色
                                     markeredgecolor=style['edgecolor'],  # 边缘颜色
                                     markersize=10,
                                     markeredgewidth=1.0))  # 边框宽度
        else:
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w', 
                                     markerfacecolor=style.get('color', style.get('edgecolor')),
                                     markeredgecolor=style.get('color', style.get('edgecolor')), 
                                     markersize=10))
        labels.append(sheet)
    
    fig.legend(handles, labels, 
               loc='lower center', 
               ncol=4, 
               fontsize=12,
               frameon=True,
               fancybox=True,
               shadow=True,
               bbox_to_anchor=(0.5, 0.02))
    
    # 调整布局
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    
    # 保存图像
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = "figures/AllSMAPInsituData_VWC_Scatter_purify.png"  # 更新为指定的文件名
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"散点图已保存至: {fig_path}")
    plt.close()

def save_prediction_details(all_predictions):
    """
    将预测结果保存到Excel文件中
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_file = output_dir / "details_purify.xlsx"
    
    # 创建Excel写入器
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        # 遍历所有波段和极化组合
        for (band, pol), predictions in all_predictions.items():
            if not predictions:
                continue
                
            # 创建当前组合的数据框
            all_data = []
            
            # 收集所有sheet的数据
            for sheet, data in predictions.items():
                # 创建当前sheet的数据框
                sheet_df = pd.DataFrame({
                    'Date': data['date'],
                    'Row': data['row'],
                    'Col': data['col'],
                    'Latitude': data['lat'],
                    'Longitude': data['lon'],
                    'Actual_VWC': data['actual'],
                    'Predicted_VWC': data['predicted'],
                    'Source': data['source']
                })
                
                # 添加波段和极化信息
                sheet_df['Band'] = band
                sheet_df['Polarization'] = pol
                
                all_data.append(sheet_df)
            
            # 合并所有数据
            if all_data:
                combined_df = pd.concat(all_data, ignore_index=True)
                
                # 保存到Excel
                sheet_name = f"{band}_{pol}"
                combined_df.to_excel(writer, sheet_name=sheet_name, index=False)
                print(f"保存预测结果到: {sheet_name} ({len(combined_df)}行)")
    
    print(f"所有预测结果已保存至: {output_file}")

def main():
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx"
    
    # 加载并预处理数据
    data_dict = load_and_preprocess_data(input_file)
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            print(f"\n处理波段-极化组合: {band}-{pol}")
            predictions = predict_vwc(data_dict, band, pol)
            all_predictions[(band, pol)] = predictions
    
    # 创建散点图
    create_scatter_plots(all_predictions)
    
    # 保存预测结果到Excel
    save_prediction_details(all_predictions)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx
  - SMEX02: 16行
    替换了 14 行SM_Satellite数据
  - CLASIC07: 18行
  - SMAPVEX08: 6行
    替换了 6 行LAI_Satellite数据
  - SMAPVEX16: 115行

处理波段-极化组合: Ku-H
加载模型: models/RFR_Ku_Hpol_Type1.pkl
  模型训练特征: ['VOD', 'LAI', 'SM', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  SMEX02 应用归一化处理...
    归一化 VOD: 除以2
    归一化 LAI: 除以6
    归一化 Grass_man: 除以100
    归一化 Grass_nat: 除以100
    归一化 Shrub_bd: 除以100
    归一化 Shrub_be: 除以100
    归一化 Shrub_nd: 除以100
    归一化 Shrub_ne: 除以100
    归一化 Tree_bd: 除以100
    归一化 Tree_be: 除以100
    归一化 Tree_nd: 除以100
    归一化 Tree_ne: 除以100
  SMEX02 移除了 7 行包含缺失值的数据
  SMEX02 预测完成: 9 个样本
  CLASIC07 应用归一化处理...
    归一化 VOD: 除以2
    归一化 LAI: 除以6
    归一化 Grass_man: 除以100
    归一化 Grass_nat: 除以100
    归一化 Shrub_bd: 除以100
    归一化 Shrub_be: 除以100
    归一化 Shrub_nd: 除以100
    归一化 Shrub_ne: 除以100
    归一化 Tree_bd: 除以100
    归一化 Tree_be: 除以100
  

# .VWC-SHB2024

In [15]:
!conda env list


# conda environments:
#
base                   D:\Anaconda
                       D:\Miniconda
d2l                    D:\ProgramData\anaconda3\envs\d2l
gdal_env               D:\ProgramData\anaconda3\envs\gdal_env
geo_env                D:\ProgramData\anaconda3\envs\geo_env
project              * D:\ProgramData\anaconda3\envs\project

