# 模型生成图像

## 单日数据生成及掩膜（这里把主程序的overwrite改为true即可，以方便直接复写）

In [1]:
# 单日数据的生成（支持Hveg变量）
import os
import sys
import h5py
import joblib
from osgeo import gdal
gdal.UseExceptions()
from osgeo import osr
import numpy as np
from tqdm import tqdm
from datetime import datetime, timedelta
import logging
import concurrent.futures

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_generation.log')
    ]
)
logger = logging.getLogger()

# ============================== 更新配置参数 ==============================
# 定义所有要处理的波段和极化组合
MODEL_CONFIGS = [
    # 单极化模型
    {'band': 'Ku', 'pol': 'H', 'name': 'KuH'},
    {'band': 'Ku', 'pol': 'V', 'name': 'KuV'},
    {'band': 'X', 'pol': 'H', 'name': 'XH'},
    {'band': 'X', 'pol': 'V', 'name': 'XV'},
    {'band': 'C', 'pol': 'H', 'name': 'CH'},
    {'band': 'C', 'pol': 'V', 'name': 'CV'},
    
    # 双极化模型
    {'band': 'Ku', 'pol': 'HV', 'name': 'KuHV'},
    {'band': 'X', 'pol': 'HV', 'name': 'XHV'},
    {'band': 'C', 'pol': 'HV', 'name': 'CHV'},
]

# 时间范围限定 (只生成2015年)
START_DATE = datetime(2015, 1, 1)
END_DATE = datetime(2015, 12, 31)

# 路径配置
VOD_BASE_PATH = r'E:\data\VOD\mat\kuxcVOD\ASC'
LAI_BASE_PATH = r'E:\data\GLASS LAI\mat\0.1Deg\Dataset'
PFT_BASE_PATH = r'E:\data\ESACCI PFT\Resample\Data'
Hveg_PATH = r'E:\data\CanopyHeight\CH.mat'  # Hveg数据路径
OUTPUT_PATH = r'E:\data\VWC\VWCMap\Daily'

# 确保输出目录存在
os.makedirs(OUTPUT_PATH, exist_ok=True)

# 加载所有模型（使用新模型名称）
MODELS = {}
for config in MODEL_CONFIGS:
    band = config['band']
    pol = config['pol']
    # 使用新模型名称（添加Hveg后缀）
    model_path = os.path.join(os.getcwd(), f'models\\RFR_{band}_{pol}-pol_Type1_Primary.pkl')
    try:
        if os.path.exists(model_path):
            MODELS[(band, pol)] = joblib.load(model_path)
            logger.info(f"成功加载 {band}波段{pol}极化模型（初始实验设置）")
        else:
            logger.warning(f"模型文件不存在 ({band},{pol}): {model_path}")
    except Exception as e:
        logger.error(f"加载模型失败 ({band},{pol}): {str(e)}")

# 定义PFT特征列表
PFT_FEATURES = [
    'Grass_man', 'Grass_nat', 
    'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
    'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
]

# 加载Hveg数据（全局变量）
Hveg_DATA = None
try:
    logger.info(f"加载Hveg数据: {Hveg_PATH}")
    with h5py.File(Hveg_PATH, 'r') as f:
        # 尝试不同可能的变量名
        for key in ['Hveg', 'CanopyHeight', 'CH']:
            if key in f:
                Hveg_DATA = np.array(f[key][:])
                logger.info(f"成功加载Hveg数据: {key}")
                break
        else:
            logger.warning("未找到Hveg变量，尝试使用第一个数据集")
            keys = list(f.keys())
            if keys:
                Hveg_DATA = np.array(f[keys[0]][:])
                logger.warning(f"使用默认数据集 '{keys[0]}' 作为Hveg数据")
    
    # 确保数据形状正确 (1800, 3600)
    if Hveg_DATA.shape == (3600, 1800):
        Hveg_DATA = Hveg_DATA.T
    elif Hveg_DATA.shape != (1800, 3600):
        logger.error(f"不支持的Hveg数据形状: {Hveg_DATA.shape}")
        Hveg_DATA = None
    
    # 归一化Hveg数据（除以40）
    if Hveg_DATA is not None:
        Hveg_DATA = Hveg_DATA
        logger.info("Hveg数据加载完成")
except Exception as e:
    logger.error(f"加载Hveg数据失败: {str(e)}")
    Hveg_DATA = None

# ============================== 缓存机制 ==============================
LAI_MONTH_CACHE = {}  # 缓存LAI数据（按年月）
PFT_YEAR_CACHE = {}    # 缓存PFT数据（按年）

# ============================== 辅助函数 ==============================
def create_multiband_geotiff(bands_data, band_names, output_path, nodata=-9999.0):
    """创建多波段地理参考的TIFF文件"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols = bands_data[0].shape
        
        # 验证分辨率
        expected_res = 0.1
        if cols != 3600 or rows != 1800:
            logger.warning(f"图像尺寸异常: {rows}x{cols} (期望1800x3600)")
        
        # 创建数据集
        out_ds = driver.Create(
            output_path, 
            cols, 
            rows, 
            len(bands_data), 
            gdal.GDT_Float32,
            options=['COMPRESS=LZW', 'BIGTIFF=YES']
        )
        
        # 设置地理变换 (确保精确的0.1度分辨率)
        geotransform = (-180.0, expected_res, 0.0, 90.0, 0.0, -expected_res)
        out_ds.SetGeoTransform(geotransform)
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加分辨率元数据
        out_ds.SetMetadataItem('PIXEL_SIZE_X', str(expected_res))
        out_ds.SetMetadataItem('PIXEL_SIZE_Y', str(expected_res))
        out_ds.SetMetadataItem('UNITS', 'degrees')
        
        # 写入每个波段数据
        for i, band_data in enumerate(bands_data):
            band = out_ds.GetRasterBand(i+1)
            band.WriteArray(band_data)
            band.SetNoDataValue(nodata)
            band.SetDescription(band_names[i])
            band.SetScale(0.1)  # 设置缩放因子
            band.SetOffset(0.0)  # 设置偏移值
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        
        # 验证输出文件
        ds = gdal.Open(output_path)
        if ds:
            actual_geotransform = ds.GetGeoTransform()
            ds = None
            
            # 验证分辨率
            res_x = actual_geotransform[1]
            res_y = -actual_geotransform[5]  # 取绝对值
            
            if abs(res_x - expected_res) > 1e-6 or abs(res_y - expected_res) > 1e-6:
                logger.error(f"分辨率不匹配! 期望: {expected_res}, 实际: X={res_x}, Y={res_y}")
                return False
        
        logger.info(f"成功创建多波段GeoTIFF: {output_path} (波段: {', '.join(band_names)})")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def get_vod_file(date):
    """获取VOD文件路径"""
    sensor = 'AMSRE' if date.year <= 2011 else 'AMSR2'
    date_str = date.strftime('%Y%m%d')
    
    # 查找VOD文件
    possible_files = [
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat',
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.mat',
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}_V1.nc4.mat',
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(VOD_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到VOD文件: {date_str}")
    return None

def get_lai_file(year, month):
    """获取LAI文件路径"""
    month_str = str(month).zfill(2)
    
    # 尝试文件名格式
    possible_files = [
        f'{year}-{month_str}-01.tif.mat',
        f'{year}-{month_str}-01.mat',
        f'LAI_{year}{month_str}.mat',
        f'{year}_{month_str}_LAI.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(LAI_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到LAI文件: {year}-{month_str}")
    return None

def get_pft_file(year):
    """获取PFT文件路径"""
    # 尝试不同文件名
    possible_files = [
        f'{year}.mat',
        f'PFT_{year}.mat',
        f'ESACCI_PFT_{year}.mat',
        f'pft_{year}.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(PFT_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到PFT文件: {year}")
    return None

def get_month_centers(date):
    """获取用于插值的前后月中日期"""
    current_mid = date.replace(day=15)
    
    if date.day <= 15:
        prev_mid = (current_mid - timedelta(days=30)).replace(day=15)
        next_mid = current_mid
    else:
        prev_mid = current_mid
        next_mid = (current_mid + timedelta(days=30)).replace(day=15)
    return prev_mid, next_mid

def load_lai_matrix(file_path):
    """加载LAI矩阵数据"""
    if not file_path or not os.path.exists(file_path): 
        return None
    
    try:
        with h5py.File(file_path, 'r') as f:
            # 尝试不同数据集名称
            for key in ['lai', 'Layer', 'data']:
                if key in f:
                    lai_data = np.array(f[key][:])
                    break
            else:
                # 使用第一个数据集
                keys = list(f.keys())
                if keys:
                    lai_data = np.array(f[keys[0]][:])
                    logger.warning(f"使用默认数据集 '{keys[0]}' 作为LAI数据: {file_path}")
                else:
                    logger.warning(f"未找到数据集: {file_path}")
                    return None
            
            # 转置为(1800,3600)
            if lai_data.shape == (3600, 1800):
                lai_data = lai_data.T
            elif lai_data.shape != (1800, 3600):
                logger.error(f"不支持的LAI数据形状: {lai_data.shape}")
                return None
            
            return lai_data
    except Exception as e:
        logger.error(f"加载LAI文件失败: {file_path} - {str(e)}")
        return None

def load_pft_matrix(file_path):
    """加载并旋转PFT矩阵数据"""
    if not file_path or not os.path.exists(file_path): 
        return None
    
    try:
        pft_data = {}
        with h5py.File(file_path, 'r') as f:
            # 定义植被类型映射
            pft_mapping = {
                'water': ['water', 'WATER'],
                'bare': ['bare', 'bareland', 'BARE'],
                'snowice': ['snowice', 'snow', 'ice', 'SNOWICE'],
                'built': ['built', 'urban', 'BUILT'],
                'grassnat': ['grassnat', 'grass_natural', 'GRASSNAT'],
                'grassman': ['grassman', 'grass_managed', 'GRASSMAN'],
                'shrubbd': ['shrubbd', 'shrub_bd', 'SHRUBBD'],
                'shrubbe': ['shrubbe', 'shrub_be', 'SHRUBBE'],
                'shrubnd': ['shrubnd', 'shrub_nd', 'SHRUBND'],
                'shrubne': ['shrubne', 'shrub_ne', 'SHRUBNE'],
                'treebd': ['treebd', 'tree_bd', 'TREEBD'],
                'treebe': ['treebe', 'tree_be', 'TREEBE'],
                'treend': ['treend', 'tree_nd', 'TREEND'],
                'treene': ['treene', 'tree_ne', 'TREENE']
            }
            
            # 查找匹配的数据集
            matched_datasets = {}
            for target, aliases in pft_mapping.items():
                for alias in aliases:
                    if alias in f:
                        data = np.array(f[alias][:])
                        # 旋转和转置
                        if data.shape == (3600, 7200):
                            # 从0.05度到0.1度的降采样
                            data = np.rot90(data, k=-1)
                            # 平均聚合到0.1度
                            data = (data[:, ::2] + data[:, 1::2]) / 2.0
                            pft_data[target] = data.T
                        elif data.shape == (3600, 1800):
                            pft_data[target] = np.rot90(data, k=-1)
                        elif data.shape == (1800, 3600):
                            pft_data[target] = data
                        else:
                            logger.warning(f"未知PFT形状: {data.shape} for {target}")
                            pft_data[target] = np.zeros((1800, 3600))
                        matched_datasets[alias] = True
                        break
                else:
                    logger.warning(f"未找到 {target} 的PFT数据")
                    pft_data[target] = np.zeros((1800, 3600))
            
            # 输出找到的数据集
            if matched_datasets:
                logger.info(f"从PFT文件中加载的数据集: {', '.join(matched_datasets.keys())}")
            else:
                logger.warning("未找到匹配的PFT数据集")
                
        return pft_data
    except Exception as e:
        logger.error(f"加载PFT文件失败: {file_path} - {str(e)}")
        return None

def create_land_mask(rows=1800, cols=3600):
    """创建简单的陆地掩膜"""
    # 创建纬度数组
    lats = np.linspace(90, -90, rows)
    
    # 创建陆地掩膜 (排除极地和海洋)
    land_mask = np.zeros((rows, cols), dtype=bool)
    for i in range(rows):
        if -60 <= lats[i] <= 80:  # 排除南极和北极
            land_mask[i, :] = True
    
    return land_mask

def prepare_shared_features(vod_data, lai_data, pft_data, valid_mask, 
                            rows=1800, cols=3600, pft_features=PFT_FEATURES):
    """准备共享特征矩阵（LAI, Hveg, PFT等）避免为每个模型重复计算"""
    # 获取有效索引
    valid_indices = np.where(valid_mask)
    num_valid = len(valid_indices[0])
    
    if num_valid == 0:
        logger.warning("无有效数据点")
        return None, None, None
    
    # 初始化共享特征矩阵 (num_valid, num_features)
    shared_features = np.zeros((num_valid, 0), dtype=np.float32)
    shared_feature_names = []
    
    # 1. LAI特征
    if lai_data is not None:
        lai_val = lai_data
        shared_features = np.column_stack((shared_features, lai_val[valid_indices]))
        shared_feature_names.append('LAI')
    else:
        logger.warning("LAI特征缺失，使用0填充")
        shared_features = np.column_stack((shared_features, np.zeros(num_valid)))
        shared_feature_names.append('LAI')
    
    # 2. Hveg特征
    if Hveg_DATA is not None:
        shared_features = np.column_stack((shared_features, Hveg_DATA[valid_indices]))
        shared_feature_names.append('Hveg')
    else:
        logger.warning("Hveg特征缺失，使用0填充")
        shared_features = np.column_stack((shared_features, np.zeros(num_valid)))
        shared_feature_names.append('Hveg')
    
    # 3. PFT特征
    if pft_data is not None:
        # PFT特征映射 (简化版)
        pft_mapping = {
            'Grass_man': 'grassman',
            'Grass_nat': 'grassnat',
            'Shrub_bd': 'shrubbd',
            'Shrub_be': 'shrubbe',
            'Shrub_nd': 'shrubnd',
            'Shrub_ne': 'shrubne',
            'Tree_bd': 'treebd',
            'Tree_be': 'treebe',
            'Tree_nd': 'treend',
            'Tree_ne': 'treene'
        }
        
        for model_feature in pft_features:
            pft_key = pft_mapping.get(model_feature)
            if pft_key and pft_key in pft_data:
                pft_val = pft_data[pft_key] / 100.0 # 处理成比例
                shared_features = np.column_stack((shared_features, pft_val[valid_indices]))
                shared_feature_names.append(model_feature)
            else:
                logger.warning(f"PFT特征缺失: {model_feature}，使用0填充")
                shared_features = np.column_stack((shared_features, np.zeros(num_valid)))
                shared_feature_names.append(model_feature)
    else:
        logger.warning("PFT数据缺失，所有PFT特征使用0填充")
        for model_feature in pft_features:
            shared_features = np.column_stack((shared_features, np.zeros(num_valid)))
            shared_feature_names.append(model_feature)
    
    logger.info(f"准备共享特征完成: {shared_features.shape} 形状, 特征数: {len(shared_feature_names)}")
    return shared_features, shared_feature_names, valid_indices

# 修复函数：添加shared_feature_names参数
def add_vod_features(model_config, vod_data, shared_features, shared_feature_names, valid_indices):
    """为共享特征矩阵添加特定模型的VOD特征"""
    band = model_config['band']
    pol = model_config['pol']
    name = model_config['name']
    
    # 1. VOD特征处理
    vod_features = []
    vod_feature_names = []
    
    if pol in ['H', 'V']:  # 单极化模型
        vod_key = f"{band.lower()}_vod_{pol.lower()}"
        if vod_key in vod_data:
            vod_val = vod_data[vod_key]
            vod_val = np.nan_to_num(vod_val, nan=0.0)
            vod_val = vod_val
            vod_features.append(vod_val[valid_indices])
            vod_feature_names.append('VOD')
        else:
            logger.warning(f"VOD特征缺失: {vod_key}")
            return None, None
    
    elif pol == 'HV':  # 双极化模型
        vod_key_h = f"{band.lower()}_vod_h"
        vod_key_v = f"{band.lower()}_vod_v"
        
        if vod_key_h in vod_data and vod_key_v in vod_data:
            vod_val_h = vod_data[vod_key_h]
            vod_val_v = vod_data[vod_key_v]
            
            vod_val_h = np.nan_to_num(vod_val_h, nan=0.0)
            vod_val_v = np.nan_to_num(vod_val_v, nan=0.0)
            
            vod_val_h = vod_val_h
            vod_val_v = vod_val_v
                
            vod_features.append(vod_val_h[valid_indices])
            vod_features.append(vod_val_v[valid_indices])
            vod_feature_names.extend(['VOD-Hpol', 'VOD-Vpol'])
        else:
            missing = []
            if vod_key_h not in vod_data: missing.append(vod_key_h)
            if vod_key_v not in vod_data: missing.append(vod_key_v)
            logger.warning(f"双极化VOD特征缺失: {', '.join(missing)}")
            return None, None
    
    # 合并所有特征
    if vod_features:
        # 如果共享特征为空（理论上不会）
        if shared_features is None:
            shared_features = np.array([]).reshape(len(vod_features[0]), 0)
        
        # 合并VOD特征
        final_features = np.column_stack((shared_features, *vod_features))
        final_feature_names = shared_feature_names + vod_feature_names
        return final_features, final_feature_names
    else:
        logger.warning(f"{name} 模型无VOD特征")
        return None, None

def predict_vwc_single_model(model, features, feature_names, rows, cols, valid_mask):
    """为单个模型预测VWC"""
    if features is None:
        return None
    
    # 初始化结果数组
    vwc_array = np.full((rows, cols), -9999.0, dtype=np.float32)
    valid_indices = np.where(valid_mask)
    num_valid = len(valid_indices[0])
    
    if num_valid == 0:
        logger.warning("无有效数据点")
        return vwc_array
    
    try:
        # 检查特征顺序是否匹配
        if hasattr(model, 'feature_names_in_'):
            model_features = model.feature_names_in_
            
            if feature_names:
                # 检查特征顺序是否一致
                if list(feature_names) != list(model_features):
                    logger.warning("特征顺序不匹配! 尝试重新排序...")
                    try:
                        # 重新排序特征以匹配模型期望
                        sorted_indices = [feature_names.index(f) for f in model_features]
                        features = features[:, sorted_indices]
                        feature_names = [feature_names[i] for i in sorted_indices]
                        logger.info("特征重新排序完成")
                    except Exception as e:
                        logger.error(f"特征重新排序失败: {str(e)}")
                        return None
        
        # 预测 (分批处理避免内存溢出)
        predictions = np.zeros(num_valid, dtype=np.float32)
        chunk_size = 500000  # 增大块大小以提升效率
        chunks = (num_valid + chunk_size - 1) // chunk_size
        
        for chunk_idx in tqdm(range(chunks), desc=f"模型预测(分块)", leave=False):
            start = chunk_idx * chunk_size
            end = min((chunk_idx + 1) * chunk_size, num_valid)
            X_chunk = features[start:end]
            predictions[start:end] = model.predict(X_chunk)
        
        # 填充结果数组
        vwc_array[valid_indices] = predictions
        
        # 添加预测统计
        valid_predictions = predictions[~np.isnan(predictions)]
        if len(valid_predictions) > 0:
            logger.info(f"预测值统计: min={np.min(valid_predictions):.4f}, "
                        f"max={np.max(valid_predictions):.4f}, "
                        f"mean={np.mean(valid_predictions):.4f}")
        else:
            logger.warning("无有效预测值")
        
        return vwc_array
        
    except Exception as e:
        logger.error(f"预测失败: {str(e)}")
        return None

def process_one_date(date, land_mask, overwrite=False):
    """处理单日数据，重构以提高效率"""
    # 1. 输出文件路径
    output_filename = f'VWC-{date.strftime("%Y%m%d")}.tif'
    output_path = os.path.join(OUTPUT_PATH, output_filename)
    
    # 检查文件是否已存在
    if os.path.exists(output_path) and not overwrite:
        logger.info(f"文件已存在: {output_path} - 跳过")
        return True
    
    # 如果存在且需覆盖则删除
    if os.path.exists(output_path):
        try:
            os.remove(output_path)
            logger.info(f"已删除现有文件: {output_path}")
        except Exception as e:
            logger.error(f"删除文件失败: {output_path} - {str(e)}")
            return False
    
    logger.info(f"处理日期: {date.strftime('%Y-%m-%d')}")
    band_predictions = {}
    
    try:
        # 2. 加载VOD数据
        vod_file = get_vod_file(date)
        if not vod_file or not os.path.exists(vod_file):
            logger.error(f"VOD文件未找到: {date}")
            return False
            
        logger.info(f"加载VOD文件: {vod_file}")
        vod_data = {}
        with h5py.File(vod_file, 'r') as f:
            # 加载所有VOD数据集
            for key in f.keys():
                if key.lower().startswith(('ku_vod', 'x_vod', 'c_vod')):
                    vod_data[key.lower()] = np.array(f[key][:]).T
            
            # 加载QC
            qc_data = np.array(f['QC'][:,:]).T
        
        # 3. 创建有效掩膜
        qc_mask = qc_data == 0  # QC=0 表示有效
        lai_data = None  # 将在下一步加载
        valid_mask = qc_mask & land_mask  # 先加入QC和陆地掩码
        
        # 4. 加载并插值LAI数据（使用缓存）
        try:
            # 获取前后月中日期
            prev_mid, next_mid = get_month_centers(date)
            
            # 检查缓存 (前月)
            prev_key = (prev_mid.year, prev_mid.month)
            if prev_key in LAI_MONTH_CACHE:
                logger.info(f"使用缓存的前月LAI: {prev_key[0]}-{prev_key[1]:02d}")
                lai_prev_data = LAI_MONTH_CACHE[prev_key]
            else:
                lai_prev_file = get_lai_file(prev_mid.year, prev_mid.month)
                lai_prev_data = load_lai_matrix(lai_prev_file)
                if lai_prev_data is None:
                    logger.error(f"前月LAI文件加载失败: {prev_key[0]}-{prev_key[1]:02d}")
                    return False
                LAI_MONTH_CACHE[prev_key] = lai_prev_data
            
            # 检查缓存 (后月)
            next_key = (next_mid.year, next_mid.month)
            if next_key in LAI_MONTH_CACHE:
                logger.info(f"使用缓存的后月LAI: {next_key[0]}-{next_key[1]:02d}")
                lai_next_data = LAI_MONTH_CACHE[next_key]
            else:
                lai_next_file = get_lai_file(next_mid.year, next_mid.month)
                lai_next_data = load_lai_matrix(lai_next_file)
                if lai_next_data is None:
                    logger.error(f"后月LAI文件加载失败: {next_key[0]}-{next_key[1]:02d}")
                    return False
                LAI_MONTH_CACHE[next_key] = lai_next_data
            
            # LAI插值
            total_days = (next_mid - prev_mid).days
            current_offset = (date - prev_mid).days
            weight = current_offset / total_days
            lai_data = lai_prev_data * (1 - weight) + lai_next_data * weight
            lai_data = np.nan_to_num(lai_data, nan=0.0)
            
            # 更新有效掩膜
            lai_mask = ~np.isnan(lai_data)  # 非NaN的LAI值有效
            valid_mask = valid_mask & lai_mask
        except Exception as e:
            logger.error(f"LAI处理失败: {str(e)}", exc_info=True)
            return False
        
        # 5. 加载PFT数据（使用缓存）
        year_key = date.year
        if year_key in PFT_YEAR_CACHE:
            logger.info(f"使用缓存的{year_key}年PFT数据")
            pft_data = PFT_YEAR_CACHE[year_key]
        else:
            pft_file = get_pft_file(year_key)
            pft_data = load_pft_matrix(pft_file)
            if pft_data is None:
                logger.error(f"PFT文件加载失败: {year_key}")
                return False
            PFT_YEAR_CACHE[year_key] = pft_data
        
        # 6. 准备共享特征（只需一次）
        num_valid = np.count_nonzero(valid_mask)
        logger.info(f"有效数据点数量: {num_valid} (比例: {num_valid/(1800 * 3600)*100:.2f}%)")
        
        shared_features, shared_feature_names, valid_indices = prepare_shared_features(
            vod_data, lai_data, pft_data, valid_mask
        )
        
        if shared_features is None or num_valid == 0:
            logger.warning("无有效数据点，跳过预测")
            return True
        
        # 7. 为每个模型预测
        for config in MODEL_CONFIGS:
            band = config['band']
            pol = config['pol']
            name = config['name']
            
            # 检查模型可用性
            if (band, pol) not in MODELS:
                logger.warning(f"模型不可用: {name} - 跳过")
                continue
            
            logger.info(f"为模型 {name} 准备特征...")
            
            # 添加模型特定的VOD特征
            # 修复：传入shared_feature_names参数
            features, feature_names = add_vod_features(
                config, vod_data, shared_features, shared_feature_names, valid_indices
            )
            
            if features is None:
                continue
            
            # 确保特征顺序匹配模型期望
            model = MODELS[(band, pol)]
            if hasattr(model, 'feature_names_in_'):
                model_features = model.feature_names_in_
                try:
                    if feature_names != list(model_features):
                        logger.warning(f"特征顺序不匹配! 为模型 {name} 重新排序...")
                        # 重新排序特征以匹配模型期望
                        sorted_indices = [feature_names.index(f) for f in model_features]
                        features = features[:, sorted_indices]
                        feature_names = [feature_names[i] for i in sorted_indices]
                        logger.info("特征重新排序完成")
                except Exception as e:
                    logger.error(f"特征重新排序失败: {str(e)}")
                    continue
            
            # 预测
            prediction = predict_vwc_single_model(
                model, features, feature_names, 1800, 3600, valid_mask
            )
            
            if prediction is not None:
                band_predictions[name] = prediction
                logger.info(f"模型 {name} 预测完成")
            else:
                logger.warning(f"模型 {name} 预测失败")
        
        # 8. 创建多波段TIFF
        if band_predictions:
            bands = []
            band_names = []
            for config in MODEL_CONFIGS:
                name = config['name']
                if name in band_predictions:
                    bands.append(band_predictions[name])
                    band_names.append(name)
            
            if bands:
                success = create_multiband_geotiff(bands, band_names, output_path)
                if success:
                    logger.info(f"成功保存多波段TIFF: {output_path}")
                    return True
                else:
                    logger.error(f"保存失败: {output_path}")
                    return False
            else:
                logger.warning("无有效波段数据可保存")
                return False
        else:
            logger.warning("无预测数据可保存")
            return False
            
    except Exception as e:
        logger.error(f"处理日期 {date} 错误: {str(e)}", exc_info=True)
        return False

# ============================== 主处理流程 ==============================
def generate_vwc(overwrite=False):
    """生成VWC影像（并行优化版本）"""
    if not MODELS:
        logger.error("无可用模型，程序终止")
        return
    
    # 创建陆地掩膜（一次性）
    logger.info("创建陆地掩膜...")
    land_mask = create_land_mask()
    logger.info(f"陆地掩膜创建完成: 有效点={np.count_nonzero(land_mask)}({np.count_nonzero(land_mask)/(1800 * 3600)*100:.2f}%)")
    
    # 日期序列 (2015-2016)
    dates = [START_DATE + timedelta(days=i) 
             for i in range((END_DATE - START_DATE).days + 1)]
    
    # 并行处理（线程池）
    max_workers = 4  # 根据CPU物理核心数调整
    futures = []
    completed, failed = 0, 0
    
    logger.info(f"开始并行处理{len(dates)}天数据，最大线程数: {max_workers}")
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # 提交任务
        for date in dates:
            futures.append(executor.submit(process_one_date, date, land_mask, overwrite))
        
        # 进度监控
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="生成VWC影像"):
            try:
                success = future.result()
                if success:
                    completed += 1
                else:
                    failed += 1
            except Exception as e:
                logger.error(f"任务失败: {str(e)}", exc_info=True)
                failed += 1
    
    # 清理缓存
    LAI_MONTH_CACHE.clear()
    PFT_YEAR_CACHE.clear()
    
    logger.info(f"处理完成: {completed}天成功, {failed}天失败")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"当前目录: {os.getcwd()}")
    logger.info(f"输出目录: {OUTPUT_PATH}")
    logger.info(f"开始处理: {START_DATE.strftime('%Y-%m-%d')} 到 {END_DATE.strftime('%Y-%m-%d')}")
    
    # 设置并行处理的线程数（可根据需要调整）
    max_workers = 4  # 建议设置为物理核心数或略多
    
    # 运行主处理流程
    generate_vwc(overwrite=False)
    logger.info("VWC影像生成完成")

2025-08-20 14:35:53,007 - INFO - 成功加载 Ku波段H极化模型（初始实验设置）
2025-08-20 14:35:54,161 - INFO - 成功加载 Ku波段V极化模型（初始实验设置）
2025-08-20 14:35:54,925 - INFO - 成功加载 X波段H极化模型（初始实验设置）
2025-08-20 14:35:55,826 - INFO - 成功加载 X波段V极化模型（初始实验设置）
2025-08-20 14:35:56,523 - INFO - 成功加载 C波段H极化模型（初始实验设置）
2025-08-20 14:35:57,276 - INFO - 成功加载 C波段V极化模型（初始实验设置）
2025-08-20 14:35:58,949 - INFO - 成功加载 Ku波段HV极化模型（初始实验设置）
2025-08-20 14:36:00,946 - INFO - 成功加载 X波段HV极化模型（初始实验设置）
2025-08-20 14:36:03,657 - INFO - 成功加载 C波段HV极化模型（初始实验设置）
2025-08-20 14:36:03,660 - INFO - 加载Hveg数据: E:\data\CanopyHeight\CH.mat
2025-08-20 14:36:04,238 - INFO - 成功加载Hveg数据: Hveg
2025-08-20 14:36:04,241 - INFO - Hveg数据加载完成
2025-08-20 14:36:04,250 - INFO - 当前目录: D:\Python\jupyter\VWC_RFRegression
2025-08-20 14:36:04,251 - INFO - 输出目录: E:\data\VWC\VWCMap\Daily
2025-08-20 14:36:04,255 - INFO - 开始处理: 2015-01-01 到 2015-12-31
2025-08-20 14:36:04,257 - INFO - 创建陆地掩膜...
2025-08-20 14:36:04,264 - INFO - 陆地掩膜创建完成: 有效点=5040000(77.78%)
2025-08-20 14:36:04,266 - I

In [2]:
# 地物类型掩膜
import os
import numpy as np
from osgeo import gdal, osr
import logging
from tqdm import tqdm
import concurrent.futures
import multiprocessing
import time

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_mask_water_bare_snowice_optimized.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
VWC_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
MAIN_TYPE_FILE = r'E:\data\ESACCI PFT\Resample\Data\mainType.tif'  # 主地物类型文件
NODATA_VALUE = -9999.0  # 无效值
NUM_WORKERS = multiprocessing.cpu_count()  # 使用所有CPU核心

# 需要掩膜的地物类型 (水、裸地、冰雪)
MASK_TYPES = [0, 1, 2]  # 0=water, 1=bare, 2=snowice

# ============================== 辅助函数 ==============================
def load_main_type_mask():
    """加载主地物类型掩膜 (优化版本)"""
    try:
        # 使用内存映射提高读取速度
        ds = gdal.Open(MAIN_TYPE_FILE, gdal.GA_ReadOnly)
        if ds is None:
            logger.error(f"无法打开主地物类型文件: {MAIN_TYPE_FILE}")
            return None
        
        # 读取整个波段到内存
        band = ds.GetRasterBand(1)
        main_type = band.ReadAsArray()
        
        # 创建掩膜 (需要掩膜的类型为True)
        mask = np.isin(main_type, MASK_TYPES)
        
        # 获取NoData值并处理
        nodata = band.GetNoDataValue()
        if nodata is not None:
            # 排除NoData区域
            mask = np.logical_and(mask, main_type != nodata)
        
        ds = None
        logger.info(f"成功加载主地物类型掩膜: {mask.shape}")
        return mask
    except Exception as e:
        logger.error(f"加载主地物类型文件失败: {str(e)}")
        return None

def process_vwc_file(file_path, mask):
    """处理单个VWC文件 (优化版本)"""
    try:
        start_time = time.time()
        
        # 以读写模式打开文件
        ds = gdal.Open(file_path, gdal.GA_Update)
        if ds is None:
            logger.warning(f"无法打开VWC文件: {file_path}")
            return False
        
        # 获取波段数量
        num_bands = ds.RasterCount
        rows = ds.RasterYSize
        cols = ds.RasterXSize
        
        # 验证尺寸匹配
        if rows != mask.shape[0] or cols != mask.shape[1]:
            logger.error(f"尺寸不匹配: VWC文件({rows}x{cols}) vs 掩膜({mask.shape[0]}x{mask.shape[1]})")
            ds = None
            return False
        
        # 一次性读取所有波段数据
        band_data = []
        for band_idx in range(1, num_bands + 1):
            band = ds.GetRasterBand(band_idx)
            data = band.ReadAsArray()
            band_data.append(data)
        
        # 应用掩膜到所有波段
        for i in range(num_bands):
            band_data[i][mask] = NODATA_VALUE
        
        # 一次性写回所有波段
        for band_idx in range(1, num_bands + 1):
            band = ds.GetRasterBand(band_idx)
            band.WriteArray(band_data[band_idx-1])
        
        # 清理
        ds.FlushCache()
        ds = None
        
        process_time = time.time() - start_time
        logger.debug(f"处理完成: {file_path} (耗时: {process_time:.2f}s)")
        return True
    except Exception as e:
        logger.error(f"处理文件 {file_path} 失败: {str(e)}")
        return False

def process_vwc_files_parallel(mask):
    """并行处理所有VWC文件"""
    # 获取所有VWC文件
    vwc_files = []
    for filename in os.listdir(VWC_DIR):
        if filename.startswith('VWC-') and filename.endswith('.tif'):
            file_path = os.path.join(VWC_DIR, filename)
            vwc_files.append(file_path)
    
    if not vwc_files:
        logger.error("未找到任何VWC文件")
        return 0
    
    logger.info(f"找到 {len(vwc_files)} 个VWC文件，使用 {NUM_WORKERS} 个线程并行处理")
    
    # 使用线程池并行处理
    success_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        # 提交所有任务
        future_to_file = {executor.submit(process_vwc_file, file_path, mask): file_path 
                          for file_path in vwc_files}
        
        # 使用tqdm显示进度
        for future in tqdm(concurrent.futures.as_completed(future_to_file), 
                          total=len(vwc_files), desc="处理VWC文件"):
            file_path = future_to_file[future]
            try:
                result = future.result()
                if result:
                    success_count += 1
            except Exception as e:
                logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
    
    return success_count

# ============================== 主处理流程 ==============================
def main():
    logger.info(f"VWC目录: {VWC_DIR}")
    logger.info(f"主地物类型文件: {MAIN_TYPE_FILE}")
    logger.info(f"需要掩膜的地物类型: {MASK_TYPES} (水、裸地、冰雪)")
    logger.info(f"使用 {NUM_WORKERS} 个线程进行并行处理")
    
    start_time = time.time()
    
    # 加载主地物类型掩膜
    mask = load_main_type_mask()
    if mask is None:
        logger.error("无法加载掩膜，程序终止")
        return
    
    # 处理所有VWC文件
    success_count = process_vwc_files_parallel(mask)
    
    total_time = time.time() - start_time
    logger.info(f"处理完成! 成功处理 {success_count}/{len(os.listdir(VWC_DIR))} 个文件")
    logger.info(f"总耗时: {total_time:.2f}秒")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    try:
        main()
        logger.info("VWC文件掩膜处理完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-20 16:36:59,822 - INFO - VWC目录: E:\data\VWC\VWCMap\Daily
2025-08-20 16:36:59,824 - INFO - 主地物类型文件: E:\data\ESACCI PFT\Resample\Data\mainType.tif
2025-08-20 16:36:59,825 - INFO - 需要掩膜的地物类型: [0, 1, 2] (水、裸地、冰雪)
2025-08-20 16:36:59,827 - INFO - 使用 16 个线程进行并行处理
2025-08-20 16:36:59,930 - INFO - 成功加载主地物类型掩膜: (1800, 3600)
2025-08-20 16:36:59,934 - INFO - 找到 365 个VWC文件，使用 16 个线程并行处理
处理VWC文件: 100%|███████████████████████████████████████████████████████████████████| 365/365 [27:36<00:00,  4.54s/it]
2025-08-20 17:04:36,599 - INFO - 处理完成! 成功处理 365/365 个文件
2025-08-20 17:04:36,603 - INFO - 总耗时: 1656.77秒
2025-08-20 17:04:36,605 - INFO - VWC文件掩膜处理完成


## 8日合成数据

In [1]:
import os
import numpy as np
from osgeo import gdal, osr
import datetime
import re
import logging
from tqdm import tqdm

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_8day_composite_corrected.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
INPUT_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
OUTPUT_DIR = r'E:\data\VWC\VWCMap\8Day'  # 8日合成输出目录
START_YEAR = 2015
END_YEAR = 2015
NODATA_VALUE = -9999.0

# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================== 辅助函数 ==============================
def get_daily_files():
    """获取所有每日文件，并按日期排序"""
    files = []
    pattern = re.compile(r'VWC-(\d{4})(\d{2})(\d{2})\.tif$')  # 匹配YYYYMMDD格式的文件名
    
    for filename in os.listdir(INPUT_DIR):
        if not filename.endswith('.tif'):
            continue
            
        match = pattern.match(filename)
        if not match:
            continue
            
        year = int(match.group(1))
        month = int(match.group(2))
        day = int(match.group(3))
        file_date = datetime.date(year, month, day)
        
        # 只处理指定年份范围内的文件
        if START_YEAR <= year <= END_YEAR:
            file_path = os.path.join(INPUT_DIR, filename)
            files.append({
                'path': file_path,
                'date': file_date,
                'year': year
            })
    
    # 按日期排序
    files.sort(key=lambda x: x['date'])
    return files

def create_composite_geotiff(data, output_path, start_date, end_date, nodata=NODATA_VALUE):
    """创建8日合成的地理参考TIFF文件"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols, bands = data.shape
        
        # 创建输出数据集
        out_ds = driver.Create(
            output_path,
            cols,
            rows,
            bands,
            gdal.GDT_Float32,
            options=['COMPRESS=LZW', 'TILED=YES']
        )
        
        # 设置地理变换
        out_ds.SetGeoTransform((-180, 0.1, 0, 90, 0, -0.1))
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加日期元数据
        out_ds.SetMetadata({
            'START_DATE': start_date.strftime('%Y-%m-%d'),
            'END_DATE': end_date.strftime('%Y-%m-%d'),
            'NUM_DAYS': str((end_date - start_date).days + 1)
        })
        
        # 写入每个波段
        for band_idx in range(bands):
            band = out_ds.GetRasterBand(band_idx+1)
            band.WriteArray(data[:, :, band_idx])
            band.SetNoDataValue(nodata)
            band.SetDescription(f"Model_{band_idx+1}")
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        logger.info(f"成功创建8日合成GeoTIFF: {output_path}")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def generate_8day_composites():
    """生成8日合成数据 (按年处理)"""
    # 获取所有每日文件
    daily_files = get_daily_files()
    if not daily_files:
        logger.error("未找到任何每日文件")
        return
    
    logger.info(f"找到 {len(daily_files)} 个每日文件")
    
    # 按年份分组
    files_by_year = {}
    for year in range(START_YEAR, END_YEAR + 1):
        files_by_year[year] = [f for f in daily_files if f['year'] == year]
    
    # 处理每一年
    for year in range(START_YEAR, END_YEAR + 1):
        year_files = files_by_year[year]
        if not year_files:
            logger.warning(f"{year}年没有每日文件")
            continue
            
        logger.info(f"处理 {year} 年 (共 {len(year_files)} 个每日文件)")
        
        # 确定该年的起始和结束日期
        start_date = datetime.date(year, 1, 1)
        end_date = datetime.date(year, 12, 31)
        
        # 创建8日周期
        current_start = start_date
        composite_groups = []
        
        # 构建8日周期 (严格限制在该年内)
        while current_start <= end_date:
            # 计算周期结束日期 (最多7天后)
            current_end = current_start + datetime.timedelta(days=7)
            
            # 如果结束日期超出该年范围，调整到12月31日
            if current_end > end_date:
                current_end = end_date
                
            # 收集当前周期内的文件
            group_files = []
            for file_info in year_files:
                if current_start <= file_info['date'] <= current_end:
                    group_files.append(file_info)
            
            if group_files:
                composite_groups.append({
                    'start_date': current_start,
                    'end_date': current_end,
                    'files': group_files
                })
            
            # 移动到下一个周期 (从结束日期的下一天开始)
            current_start = current_end + datetime.timedelta(days=1)
            
            # 如果下一个起始日期已超出该年范围，结束循环
            if current_start > end_date:
                break
        
        logger.info(f"{year}年共生成 {len(composite_groups)} 个8日周期")
        
        # 处理该年的每个8日周期
        for group in tqdm(composite_groups, desc=f"处理{year}年周期"):
            # 使用起始日期作为文件名标识
            start_str = group['start_date'].strftime('%Y%m%d')
            end_str = group['end_date'].strftime('%Y%m%d')
            
            # 文件名格式：VWC-YYYYMMDD.tif（使用起始日期）
            output_filename = f'VWC-{start_str}.tif'
            output_path = os.path.join(OUTPUT_DIR, output_filename)
            
            if os.path.exists(output_path):
                logger.info(f"文件已存在: {output_path}")
                continue
            
            num_days = len(group['files'])
            logger.info(f"处理周期: {start_str} 到 {end_str} ({num_days}天)")
            
            # 初始化数据立方体
            num_bands = None
            composite_sum = None
            valid_count = None
            rows, cols = 0, 0
            
            # 处理周期内的每个文件
            for file_info in group['files']:
                try:
                    ds = gdal.Open(file_info['path'])
                    if ds is None:
                        logger.warning(f"无法打开文件: {file_info['path']}")
                        continue
                    
                    # 获取图像尺寸和波段数（只在第一次确定）
                    if rows == 0:
                        rows = ds.RasterYSize
                        cols = ds.RasterXSize
                        num_bands = ds.RasterCount
                        composite_sum = np.zeros((rows, cols, num_bands), dtype=np.float32)
                        valid_count = np.zeros((rows, cols, num_bands), dtype=np.uint16)
                    
                    # 读取每个波段
                    for band_idx in range(num_bands):
                        band = ds.GetRasterBand(band_idx+1)
                        data = band.ReadAsArray()
                        nodata = band.GetNoDataValue()
                        
                        if nodata is None:
                            nodata = NODATA_VALUE
                        
                        # 创建有效值掩膜
                        valid_mask = (data != nodata) & (~np.isnan(data))
                        
                        # 累加有效值
                        composite_sum[:, :, band_idx] += np.where(valid_mask, data, 0)
                        valid_count[:, :, band_idx] += valid_mask.astype(np.uint16)
                    
                    ds = None
                except Exception as e:
                    logger.error(f"处理文件 {file_info['path']} 错误: {str(e)}")
            
            # 计算平均值
            if composite_sum is not None and valid_count is not None:
                # 避免除以零
                with np.errstate(divide='ignore', invalid='ignore'):
                    composite_avg = np.where(
                        valid_count > 0,
                        composite_sum / valid_count,
                        NODATA_VALUE  # 无效值
                    )
                
                # 保存合成结果
                success = create_composite_geotiff(
                    composite_avg, 
                    output_path,
                    group['start_date'],
                    group['end_date']
                )
                if success:
                    # 计算有效数据百分比
                    valid_percent = (valid_count > 0).sum() / (rows * cols * num_bands) * 100
                    logger.info(f"8日合成完成: {start_str} 到 {end_str} (有效数据: {valid_percent:.2f}%)")
                else:
                    logger.error(f"保存失败: {output_path}")
            else:
                logger.warning(f"无有效数据可合成: {start_str} 到 {end_str}")
    
    logger.info("8日合成处理完成")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"输入目录: {INPUT_DIR}")
    logger.info(f"输出目录: {OUTPUT_DIR}")
    logger.info(f"生成8日合成数据: {START_YEAR} 到 {END_YEAR}")
    
    try:
        generate_8day_composites()
        logger.info("8日合成数据生成完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-20 17:38:17,291 - INFO - 输入目录: E:\data\VWC\VWCMap\Daily
2025-08-20 17:38:17,292 - INFO - 输出目录: E:\data\VWC\VWCMap\8Day
2025-08-20 17:38:17,293 - INFO - 生成8日合成数据: 2015 到 2015
2025-08-20 17:38:17,294 - INFO - 找到 365 个每日文件
2025-08-20 17:38:17,299 - INFO - 处理 2015 年 (共 365 个每日文件)
2025-08-20 17:38:17,301 - INFO - 2015年共生成 46 个8日周期
2025-08-20 17:38:17,324 - INFO - 处理周期: 20150101 到 20150108 (8天)                            | 0/46 [00:00<?, ?it/s]
2025-08-20 17:38:29,565 - INFO - 成功创建8日合成GeoTIFF: E:\data\VWC\VWCMap\8Day\VWC-20150101.tif
2025-08-20 17:38:29,651 - INFO - 8日合成完成: 20150101 到 20150108 (有效数据: 9.68%)
2025-08-20 17:38:29,654 - INFO - 处理周期: 20150109 到 20150116 (8天)                    | 1/46 [00:12<09:14, 12.33s/it]
2025-08-20 17:38:40,987 - INFO - 成功创建8日合成GeoTIFF: E:\data\VWC\VWCMap\8Day\VWC-20150109.tif
2025-08-20 17:38:41,058 - INFO - 8日合成完成: 20150109 到 20150116 (有效数据: 9.57%)
2025-08-20 17:38:41,058 - INFO - 处理周期: 20150117 到 20150124 (8天)                    | 2/46 [00:23<08:38

## 月度数据合成

In [2]:
# 月度数据合成
import os
import numpy as np
from osgeo import gdal, osr
import datetime
import re
import logging
from tqdm import tqdm

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_monthly_composite.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
INPUT_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
OUTPUT_DIR = r'E:\data\VWC\VWCMap\Monthly'  # 月度合成输出目录

# 起始和结束年份
START_YEAR = 2015
END_YEAR = 2015

# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================== 辅助函数 ==============================
def get_daily_files_for_month(year, month):
    """获取指定年月的所有每日文件（仅TIFF格式）"""
    files = []
    # 修正：只匹配真正的TIFF文件
    pattern = re.compile(r'VWC-(\d{4})(\d{2})(\d{2})\.tif$')  # 添加$确保只匹配.tif结尾
    
    # 列出所有每日文件
    for filename in os.listdir(INPUT_DIR):
        # 跳过辅助文件（.aux.xml等）
        if not filename.endswith('.tif'):
            continue
            
        match = pattern.match(filename)
        if not match:
            continue
            
        file_year = int(match.group(1))
        file_month = int(match.group(2))
        file_day = int(match.group(3))
        
        if file_year == year and file_month == month:
            file_path = os.path.join(INPUT_DIR, filename)
            files.append({
                'path': file_path,
                'date': datetime.date(file_year, file_month, file_day)
            })
    
    # 按日期排序
    files.sort(key=lambda x: x['date'])
    return files

def create_monthly_geotiff(data, output_path, nodata=-9999.0):
    """创建地理参考的月度TIFF文件"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols, bands = data.shape
        
        # 创建数据集
        out_ds = driver.Create(
            output_path, 
            cols, 
            rows, 
            bands, 
            gdal.GDT_Float32,
            options=['COMPRESS=LZW']
        )
        
        # 设置地理变换
        out_ds.SetGeoTransform((-180, 0.1, 0, 90, 0, -0.1))
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加日期元数据
        out_ds.SetMetadata({'PRODUCTION_DATE': datetime.date.today().isoformat()})
        
        # 写入每个波段数据
        for band_idx in range(bands):
            band = out_ds.GetRasterBand(band_idx+1)
            band.WriteArray(data[:, :, band_idx])
            band.SetNoDataValue(nodata)
            band.SetDescription(f"Model_{band_idx+1}")
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        logger.info(f"成功创建月度合成GeoTIFF: {output_path}")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def generate_monthly_composites():
    """生成月度合成数据"""
    total_months = (END_YEAR - START_YEAR + 1) * 12
    processed = 0
    
    # 进度条初始化
    pbar = tqdm(total=total_months, desc="月度合成进度")
    
    for year in range(START_YEAR, END_YEAR + 1):
        for month in range(1, 13):
            # 获取该月的所有每日文件
            files = get_daily_files_for_month(year, month)
            
            if not files:
                logger.warning(f"在{year}年{month}月未找到任何每日文件")
                pbar.update(1)
                processed += 1
                continue
            
            # 输出文件名 - 修正：只包含年月
            output_filename = f'VWC-{year}{str(month).zfill(2)}.tif'
            output_path = os.path.join(OUTPUT_DIR, output_filename)
            
            # 检查文件是否已存在
            if os.path.exists(output_path):
                logger.info(f"月度合成已存在: {output_path}")
                pbar.update(1)
                processed += 1
                continue
            
            logger.info(f"生成{year}年{month}月合成 ({len(files)}个每日文件)")
            
            # 初始化数据立方体 (高度, 宽度, 波段) 和有效计数
            num_bands = None
            monthly_sum = None
            valid_count = None
            rows, cols = 0, 0
            
            # 用于存储每日数据的列表（用于中值合成）
            daily_arrays = {}
            
            # 处理每个文件
            for file_info in files:
                try:
                    # 打开文件 - 添加额外检查确保是TIFF文件
                    if not file_info['path'].endswith('.tif'):
                        logger.warning(f"跳过非TIFF文件: {file_info['path']}")
                        continue
                    
                    ds = gdal.Open(file_info['path'])
                    if ds is None:
                        logger.warning(f"无法打开文件: {file_info['path']}")
                        continue
                    
                    # 获取图像尺寸（只在第一次确定）
                    if rows == 0:
                        rows = ds.RasterYSize
                        cols = ds.RasterXSize
                    
                    # 获取波段数量（只在第一次确定）
                    if num_bands is None:
                        num_bands = ds.RasterCount
                        monthly_sum = np.zeros((rows, cols, num_bands), dtype=np.float32)
                        valid_count = np.zeros((rows, cols, num_bands), dtype=np.uint16)
                        # 初始化每日数据存储
                        daily_arrays = {i: [] for i in range(num_bands)}
                    
                    # 读取每个波段
                    for band_idx in range(num_bands):
                        band = ds.GetRasterBand(band_idx+1)
                        data = band.ReadAsArray()
                        nodata = band.GetNoDataValue()
                        
                        if nodata is None:
                            nodata = -9999.0
                        
                        # 创建有效值掩膜
                        valid_mask = (data != nodata) & (~np.isnan(data))
                        
                        # 累加有效值
                        monthly_sum[:, :, band_idx] += np.where(valid_mask, data, 0)
                        valid_count[:, :, band_idx] += valid_mask.astype(np.uint16)
                        
                        # 收集每日数据用于中值计算
                        daily_arrays[band_idx].append(data)
                    
                    # 关闭数据集
                    ds = None
                    
                except Exception as e:
                    logger.error(f"处理文件 {file_info['path']} 错误: {str(e)}")
            
            # 计算月度平均值 (避免除以零)
            if monthly_sum is not None and valid_count is not None:
                # 计算平均值
                with np.errstate(divide='ignore', invalid='ignore'):
                    monthly_avg = np.where(
                        valid_count > 0,
                        monthly_sum / valid_count,
                        -9999.0  # 无效值
                    )
                
                # 计算中值（更健壮的指标）
                monthly_median = np.full((rows, cols, num_bands), -9999.0, dtype=np.float32)
                for band_idx in range(num_bands):
                    if daily_arrays[band_idx]:
                        # 计算每个像素的中值
                        stack = np.stack(daily_arrays[band_idx], axis=0)
                        valid_stack = np.ma.masked_equal(stack, -9999.0)
                        monthly_median[:, :, band_idx] = np.ma.median(valid_stack, axis=0).filled(-9999.0)
                    
                    # 如果中值计算失败，使用平均值
                    if np.all(monthly_median[:, :, band_idx] == -9999.0):
                        monthly_median[:, :, band_idx] = monthly_avg[:, :, band_idx]
                
                # 保存合成结果 - 使用中值作为更健壮的指标
                success = create_monthly_geotiff(monthly_median, output_path)
                
                if success:
                    # 输出质量报告
                    valid_percent = (valid_count > 0).sum() / (rows * cols * num_bands) * 100
                    logger.info(f"月度合成完成: {year}年{month}月, 有效数据: {valid_percent:.2f}%")
                else:
                    logger.error(f"保存月度合成失败: {output_path}")
            
            # 更新进度条
            pbar.update(1)
            processed += 1
    
    # 关闭进度条
    pbar.close()
    logger.info(f"月度合成处理完成! 共处理 {processed} 个月份")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"输入目录: {INPUT_DIR}")
    logger.info(f"输出目录: {OUTPUT_DIR}")
    logger.info(f"生成月度合成数据: {START_YEAR} 到 {END_YEAR}")
    
    try:
        generate_monthly_composites()
        logger.info("月度合成数据生成完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-20 17:46:33,142 - INFO - 输入目录: E:\data\VWC\VWCMap\Daily
2025-08-20 17:46:33,142 - INFO - 输出目录: E:\data\VWC\VWCMap\Monthly
2025-08-20 17:46:33,142 - INFO - 生成月度合成数据: 2015 到 2015
2025-08-20 17:46:33,157 - INFO - 生成2015年1月合成 (31个每日文件)                                    | 0/12 [00:00<?, ?it/s]
2025-08-20 17:48:35,343 - INFO - 成功创建月度合成GeoTIFF: E:\data\VWC\VWCMap\Monthly\VWC-201501.tif
2025-08-20 17:48:35,433 - INFO - 月度合成完成: 2015年1月, 有效数据: 10.55%
2025-08-20 17:48:35,438 - INFO - 生成2015年2月合成 (28个每日文件)                           | 1/12 [02:02<22:25, 122.28s/it]
2025-08-20 17:50:20,923 - INFO - 成功创建月度合成GeoTIFF: E:\data\VWC\VWCMap\Monthly\VWC-201502.tif
2025-08-20 17:50:21,012 - INFO - 月度合成完成: 2015年2月, 有效数据: 10.93%
2025-08-20 17:50:21,017 - INFO - 生成2015年3月合成 (31个每日文件)                           | 2/12 [03:47<18:44, 112.46s/it]
2025-08-20 17:52:20,117 - INFO - 成功创建月度合成GeoTIFF: E:\data\VWC\VWCMap\Monthly\VWC-201503.tif
2025-08-20 17:52:20,209 - INFO - 月度合成完成: 2015年3月, 有效数据: 13.51%
2025-08-

## VWC季节性制图 

In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from osgeo import gdal, osr
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
from cartopy.feature import NaturalEarthFeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import matplotlib as mpl
import logging

# ==========================================================
# 配置参数
# ==========================================================
VWC_DIR = r"E:\data\VWC\VWCMap\Monthly"
SPECIFIC_MONTHS = {
    "2015-01": "201501",
    "2015-04": "201504",
    "2015-07": "201507",
    "2015-10": "201510"
}
BAND = 1  # 读取第一个波段(KuH)
OUTPUT_DIR = r"E:\文章\HUITU\Fig"
OUTPUT_NAME = "Global_VWC_KuH_Seasonal_2015_Primary"

# 掩膜相关参数
MAIN_TYPE_FILE = r'E:\data\ESACCI PFT\Resample\Data\mainType.tif'
MASK_TYPES = [0, 1, 2]  # 0=water, 1=bare, 2=snowice
NODATA_VALUE = -9999.0

# ==========================================================
# 掩膜处理函数（使用修复后的方法）
# ==========================================================
def load_main_type_mask():
    """加载主地物类型掩膜（含空间参考校验）"""
    try:
        ds = gdal.Open(MAIN_TYPE_FILE, gdal.GA_ReadOnly)
        if ds is None:
            raise ValueError(f"无法打开主地物类型文件: {MAIN_TYPE_FILE}")
        
        band = ds.GetRasterBand(1)
        main_type = band.ReadAsArray()
        
        # 创建掩膜（需要掩膜的类型为True）
        mask = np.isin(main_type, MASK_TYPES)
        
        # 获取NoData值并处理
        nodata = band.GetNoDataValue()
        if nodata is not None:
            mask = np.logical_and(mask, main_type != nodata)
        
        # 获取空间参考信息
        geotransform = ds.GetGeoTransform()
        projection = ds.GetProjection()
        ds = None
        
        print("成功加载主地物类型掩膜")
        return mask, geotransform, projection
    except Exception as e:
        print(f"加载主地物类型文件失败: {str(e)}")
        return None, None, None

def apply_mask_to_vwc_file(file_path, mask, ref_geotransform, ref_projection):
    """对单个VWC文件应用掩膜（仅处理第一个波段）"""
    try:
        print(f"正在掩膜处理: {os.path.basename(file_path)}")
        
        # 以读写模式打开文件
        ds = gdal.Open(file_path, gdal.GA_Update)
        if ds is None:
            raise ValueError(f"无法打开VWC文件: {file_path}")
        
        # 验证空间参考是否匹配
        if (ds.GetGeoTransform() != ref_geotransform or 
            ds.GetProjection() != ref_projection):
            raise ValueError("空间参考不匹配，请确保文件投影一致")
        
        # 仅处理第一个波段
        band = ds.GetRasterBand(BAND)
        data = band.ReadAsArray()
        
        # 应用掩膜
        data[mask] = NODATA_VALUE
        
        # 写回数据并设置NoData属性
        band.WriteArray(data)
        band.SetNoDataValue(NODATA_VALUE)
        
        # 强制写入磁盘
        ds.FlushCache()
        ds = None
        
        print(f"成功掩膜处理: {os.path.basename(file_path)}")
        return True
    except Exception as e:
        print(f"掩膜处理失败: {os.path.basename(file_path)}，错误: {str(e)}")
        return False

# ==========================================================
# 数据读取函数（保持不变）
# ==========================================================
def read_vwc_tif(file_path, band=1, no_data=-9999):
    """读取VWC TIFF文件并处理无效值"""
    ds = gdal.Open(file_path)
    if ds is None:
        raise FileNotFoundError(f"Cannot open file: {file_path}")
    
    # 读取指定波段
    band = ds.GetRasterBand(band)
    data = band.ReadAsArray()
    
    # 获取地理信息
    geotransform = ds.GetGeoTransform()
    projection = ds.GetProjection()
    
    # 计算地理范围
    x_size = ds.RasterXSize
    y_size = ds.RasterYSize
    lon_min = geotransform[0]
    lat_max = geotransform[3]
    lon_max = lon_min + geotransform[1] * x_size
    lat_min = lat_max + geotransform[5] * y_size
    
    # 处理无效值
    data = data.astype(np.float32)
    data[data == no_data] = np.nan
    
    ds = None
    return data, (lon_min, lon_max, lat_min, lat_max), projection

# ==========================================================
# 创建自定义颜色映射（保持不变）
# ==========================================================
def create_custom_cmap():
    """创建自定义颜色映射"""
    # 使用指定的五种颜色
    colors = [
        '#fe3c19',  # 0 kg/m²
        '#ffac18',  # 5 kg/m²
        '#f2fe2a',  # 10 kg/m²
        '#7cb815',  # 15 kg/m²
        '#147218'   # 20 kg/m²
    ]
    
    # 创建颜色映射
    cmap = LinearSegmentedColormap.from_list('custom_vwc', colors, N=256)
    return cmap

# ==========================================================
# 地图绘制函数（保持不变）
# ==========================================================
def plot_vwc_map(ax, data, extent, month_label, vmin=0, vmax=20):
    """在指定轴对象上绘制VWC地图"""
    # 创建自定义颜色映射
    cmap = create_custom_cmap()
    
    # 添加地图特征
    ax.coastlines(linewidth=0.5, color='gray')
    ax.add_feature(NaturalEarthFeature(category='physical', name='ocean', scale='50m', 
                                      facecolor='lightblue', alpha=0.3))
    ax.add_feature(NaturalEarthFeature(category='cultural', name='admin_0_countries', 
                                      scale='50m', edgecolor='gray', facecolor='none', linewidth=0.3))
    
    # 添加经纬度网格
    gl = ax.gridlines(draw_labels=True, linestyle='--', alpha=0.7)
    gl.top_labels = False
    gl.right_labels = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    
    # 绘制VWC数据
    im = ax.imshow(data, origin='upper', 
                  extent=extent,
                  transform=ccrs.PlateCarree(),
                  cmap=cmap, vmin=vmin, vmax=vmax,
                  interpolation='nearest')
    
    # 添加标题 (格式为YYYY-MM)
    ax.set_title(f"{month_label}", fontsize=12, pad=10)
    
    return im

# ==========================================================
# 主执行函数（添加掩膜处理步骤）
# ==========================================================
def main():
    print("======== 开始掩膜处理 ========")
    
    # 加载主地物类型掩膜
    mask, ref_geotransform, ref_projection = load_main_type_mask()
    if mask is None:
        print("错误：无法加载主掩膜，程序终止")
        return
    
    # 检查掩膜尺寸
    print(f"掩膜尺寸: {mask.shape[0]}x{mask.shape[1]}")
    
    # 对SPECIFIC_MONTHS中的每个文件应用掩膜
    for label, month_code in SPECIFIC_MONTHS.items():
        file_path = os.path.join(VWC_DIR, f"VWC-{month_code}.tif")
        if os.path.exists(file_path):
            apply_mask_to_vwc_file(file_path, mask, ref_geotransform, ref_projection)
        else:
            print(f"警告: 文件不存在 - {file_path}")
    
    print("======== 掩膜处理完成 ========")
    print("开始读取特定月度VWC数据...")
    
    # 准备存储数据
    vwc_data = []
    
    # 读取已掩膜处理的数据
    for label, month_code in SPECIFIC_MONTHS.items():
        file_path = os.path.join(VWC_DIR, f"VWC-{month_code}.tif")
        print(f"读取掩膜后的文件: {file_path}")
        
        try:
            data, extent, projection = read_vwc_tif(file_path, band=BAND, no_data=NODATA_VALUE)
            vwc_data.append({
                "data": data,
                "extent": extent,
                "label": label  # 使用YYYY-MM格式的标签
            })
        except Exception as e:
            print(f"错误读取文件: {file_path}\n{str(e)}")
            continue
    
    # 检查数据完整性
    if not vwc_data:
        print("错误: 无有效数据，无法绘制")
        return
    
    # 创建图形
    fig = plt.figure(figsize=(14, 12))
    
    # 创建2x2网格布局
    gs = gridspec.GridSpec(2, 2, figure=fig, 
                          wspace=0.1, hspace=0.15,
                          top=0.95, bottom=0.1,
                          left=0.05, right=0.95)
    
    axs = [
        fig.add_subplot(gs[0], projection=ccrs.PlateCarree()),
        fig.add_subplot(gs[1], projection=ccrs.PlateCarree()),
        fig.add_subplot(gs[2], projection=ccrs.PlateCarree()),
        fig.add_subplot(gs[3], projection=ccrs.PlateCarree())
    ]
    
    # 绘制四个季度的地图
    images = []
    for i in range(4):
        if i < len(vwc_data):
            ax = axs[i]
            data = vwc_data[i]['data']
            extent = vwc_data[i]['extent']
            label = vwc_data[i]['label']
            
            # 设置全球视图
            ax.set_global()
            
            # 绘制地图
            img = plot_vwc_map(ax, data, extent, label)
            images.append(img)
    
    # 添加共享颜色条
    cax = fig.add_axes([0.25, 0.05, 0.5, 0.02])
    
    # 创建颜色条，使用自定义颜色映射
    norm = mpl.colors.Normalize(vmin=0, vmax=20)
    cmap = create_custom_cmap()
    
    # 创建颜色条标签
    ticks = [0, 5, 10, 15, 20]  # 对应五种颜色的位置
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                     cax=cax, orientation='horizontal',
                     ticks=ticks)
    cbar.set_label('VWC (kg/m²)', fontsize=11)
    
    # 添加主标题
    plt.suptitle("Global VWC Map - 2015", 
                 fontsize=14, y=0.98)
    
    # 保存图像（路径不变）
    output_path = os.path.join(OUTPUT_DIR, f"{OUTPUT_NAME}.png")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    
    print(f"结果已保存至: {output_path}")
    plt.close()

if __name__ == "__main__":
    main()

成功加载主地物类型掩膜
掩膜尺寸: 1800x3600
正在掩膜处理: VWC-201501.tif
成功掩膜处理: VWC-201501.tif
正在掩膜处理: VWC-201504.tif
成功掩膜处理: VWC-201504.tif
正在掩膜处理: VWC-201507.tif
成功掩膜处理: VWC-201507.tif
正在掩膜处理: VWC-201510.tif
成功掩膜处理: VWC-201510.tif
开始读取特定月度VWC数据...
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201501.tif
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201504.tif
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201507.tif
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201510.tif
结果已保存至: E:\文章\HUITU\Fig\Global_VWC_KuH_Seasonal_2015_Primary.png


# 总结：

# 补一下第一块数据的验证

## 散点图输出

In [9]:
# 散点图（4个数据画在一块，写出n，按照波段-极化组合绘制为3 * 3）
# 点形状及颜色：
# SMEX02：*；CLASIC07：^；SMAPVEX08：+；SMAPVEX16：o

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import matplotlib.font_manager as fm
import joblib
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import mean_squared_error

# 常量定义
BANDS = ['Ku', 'X', 'C']
POLS = ['H', 'V', 'HV']
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']
VWC_COLUMNS = {
    'SMEX02': 'VWC-Field',
    'CLASIC07': 'VWC (kg/m²)',
    'SMAPVEX08': 'VWC',
    'SMAPVEX16': 'PLANT_WATER_CONTENT_AREA'
}

# 标记和颜色设置
MARKER_STYLES = {
    'SMEX02': {'marker': '*', 'color': '#F8766D'},
    'CLASIC07': {'marker': '^', 'facecolor': 'none', 'edgecolor': '#00BFC4'},
    'SMAPVEX08': {'marker': '+', 'color': '#C77CFF'},
    'SMAPVEX16': {'marker': 'o', 'facecolor': 'none', 'edgecolor': '#7CAE00'}
}

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

def load_and_preprocess_data(file_path):
    """
    加载并预处理Excel文件中的所有sheet
    
    参数:
    file_path (str): Excel文件路径
    
    返回:
    dict: 包含预处理后数据的字典，键为sheet名称
    """
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    # 定义所有必需的PFT特征
    pft_features = [
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    # 定义位置和日期特征
    location_features = ['Center_Latitude', 'Center_Longitude', 'Date']
    
    # 定义其他必需特征
    required_features = ['LAI', 'ku_vod_H', 'ku_vod_V', 'x_vod_H', 'x_vod_V', 'c_vod_H', 'c_vod_V']
    
    for sheet in SHEET_NAMES:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # ========== 确保位置和日期特征存在 ==========
            for feature in location_features:
                if feature not in df.columns:
                    df[feature] = np.nan  # 初始化为NaN
                    print(f"    创建列 {feature} 并初始化为NaN")
            
            # ========== 确保所有必需特征都存在 ==========
            # 1. 创建缺失的PFT特征并设置值
            for feature in pft_features:
                if feature not in df.columns:
                    df[feature] = 0.0  # 初始化为0
                    print(f"    创建列 {feature} 并初始化为0")
            
            # 设置'Grass_man'为1，其他PFT特征为0
            df['Grass_man'] = 1.0
            for feature in pft_features[1:]:  # 跳过Grass_man
                df[feature] = 0.0
            print(f"    设置Grass_man=1, 其他PFT特征=0")
            
            # 2. 确保Hveg_Satellite存在
            if 'Hveg_Satellite' not in df.columns:
                df['Hveg_Satellite'] = 0.0  # 初始化为0
                print(f"    创建列 Hveg_Satellite 并初始化为0")
            
            # 3. 确保其他必需特征存在
            for feature in required_features:
                if feature not in df.columns:
                    df[feature] = 0.0  # 初始化为0
                    print(f"    创建列 {feature} 并初始化为0")
            
            # ========== 替换卫星数据（如果存在实测数据） ==========            
            if 'LAI' in df.columns:
                mask = df['LAI'].notna()
                df.loc[mask, 'LAI_Satellite'] = df.loc[mask, 'LAI']
                print(f"    替换了 {mask.sum()} 行LAI_Satellite数据")
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def get_features_for_model(band, pol):
    """
    根据波段和极化类型获取特征列表（使用模型训练时的名称）
    
    参数:
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    list: 特征列名列表
    """
    # 使用模型训练时的特征名称（添加Hveg变量）
    features = [
        'LAI',  # 注意：训练时使用"LAI"而不是"LAI_Satellite"
        'Hveg', # 添加Hveg变量
        'Grass_man', 
        'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    # 添加VOD特征 - 根据模型类型
    if pol == 'H' or pol == 'V':
        # 单极化模型使用"VOD"
        features.append('VOD')
    elif pol == 'HV':
        # 双极化模型使用"VOD-Hpol"和"VOD-Vpol"
        features.extend(['VOD-Hpol', 'VOD-Vpol'])
    
    return features

def predict_vwc(data_dict, band, pol):
    """
    使用指定模型预测VWC
    
    参数:
    data_dict (dict): 包含所有sheet数据的字典
    band (str): 波段 ('Ku', 'X', 'C')
    pol (str): 极化类型 ('H', 'V', 'HV')
    
    返回:
    dict: 包含每个sheet预测结果的字典
    """
    # 加载模型 - 使用新模型命名规则（添加Hveg后缀）
    model_path = f"models/RFR_{band}_{pol}-pol_Type1_Primary.pkl"
    print(f"加载模型: {model_path}")
    
    if not os.path.exists(model_path):
        print(f"  模型文件不存在: {model_path}")
        return {}
    
    try:
        model = joblib.load(model_path)
        # 打印模型训练时的特征名称（如果可用）
        if hasattr(model, 'feature_names_in_'):
            print(f"  模型训练特征: {list(model.feature_names_in_)}")
    except Exception as e:
        print(f"  加载模型失败: {str(e)}")
        return {}
    
    # 获取特征列表
    features = get_features_for_model(band, pol)
    
    # 存储预测结果
    predictions = {}
    
    for sheet, df in data_dict.items():
        if df.empty:
            continue
        
        # 创建特征映射（将数据列名映射到模型期望的特征名）
        feature_mapping = {}
        for feature in features:
            # 特殊处理VOD特征
            if feature == 'VOD':
                # 单极化模型
                if pol == 'H':
                    if band == 'Ku':
                        feature_mapping['ku_vod_H'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_H'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_H'] = 'VOD'
                elif pol == 'V':
                    if band == 'Ku':
                        feature_mapping['ku_vod_V'] = 'VOD'
                    elif band == 'X':
                        feature_mapping['x_vod_V'] = 'VOD'
                    elif band == 'C':
                        feature_mapping['c_vod_V'] = 'VOD'
            elif feature == 'VOD-Hpol':
                # 双极化模型中的H极化
                if band == 'Ku':
                    feature_mapping['ku_vod_H'] = 'VOD-Hpol'
                elif band == 'X':
                    feature_mapping['x_vod_H'] = 'VOD-Hpol'
                elif band == 'C':
                    feature_mapping['c_vod_H'] = 'VOD-Hpol'
            elif feature == 'VOD-Vpol':
                # 双极化模型中的V极化
                if band == 'Ku':
                    feature_mapping['ku_vod_V'] = 'VOD-Vpol'
                elif band == 'X':
                    feature_mapping['x_vod_V'] = 'VOD-Vpol'
                elif band == 'C':
                    feature_mapping['c_vod_V'] = 'VOD-Vpol'
            # 添加Hveg特征映射
            elif feature == 'Hveg':
                feature_mapping['Hveg_Satellite'] = 'Hveg'
            # 其他特征映射
            elif feature == 'LAI':
                feature_mapping['LAI_Satellite'] = 'LAI'
            elif feature == 'Grass_man':
                feature_mapping['Grass_man'] = 'Grass_man'
            elif feature == 'Grass_nat':
                feature_mapping['Grass_nat'] = 'Grass_nat'
            elif feature == 'Shrub_bd':
                feature_mapping['Shrub_bd'] = 'Shrub_bd'
            elif feature == 'Shrub_be':
                feature_mapping['Shrub_be'] = 'Shrub_be'
            elif feature == 'Shrub_nd':
                feature_mapping['Shrub_nd'] = 'Shrub_nd'
            elif feature == 'Shrub_ne':
                feature_mapping['Shrub_ne'] = 'Shrub_ne'
            elif feature == 'Tree_bd':
                feature_mapping['Tree_bd'] = 'Tree_bd'
            elif feature == 'Tree_be':
                feature_mapping['Tree_be'] = 'Tree_be'
            elif feature == 'Tree_nd':
                feature_mapping['Tree_nd'] = 'Tree_nd'
            elif feature == 'Tree_ne':
                feature_mapping['Tree_ne'] = 'Tree_ne'
        
        # 检查是否包含所有必要特征
        missing_features = []
        for data_feature in feature_mapping.keys():
            if data_feature not in df.columns:
                missing_features.append(data_feature)
        
        if missing_features:
            print(f"  {sheet} 缺少特征: {', '.join(missing_features)}")
            continue
        
        # 准备数据（使用重命名的特征）
        X = df[list(feature_mapping.keys())].copy()
        X.columns = [feature_mapping[col] for col in X.columns]
        
        # 确保特征顺序与模型期望一致
        if hasattr(model, 'feature_names_in_'):
            X = X[list(model.feature_names_in_)]
        
        # 移除缺失值
        initial_count = len(X)
        X = X.dropna()
        removed_count = initial_count - len(X)
        if removed_count > 0:
            print(f"  {sheet} 移除了 {removed_count} 行包含缺失值的数据")
        
        if X.empty:
            print(f"  {sheet} 无有效数据可用于预测")
            continue
        
        # 预测VWC
        y_pred = model.predict(X)
        predictions[sheet] = {
            'actual': df.loc[X.index, VWC_COLUMNS[sheet]],
            'predicted': y_pred,
            'source': sheet,
            'lat': df.loc[X.index, 'Center_Latitude'],
            'lon': df.loc[X.index, 'Center_Longitude'],
            'date': df.loc[X.index, 'Date']
        }
        print(f"  {sheet} 预测完成: {len(y_pred)} 个样本")
    
    return predictions

def calculate_rmse(actual, predicted):
    """
    计算RMSE
    
    参数:
    actual (array-like): 实际值
    predicted (array-like): 预测值
    
    返回:
    float: RMSE值
    """
    return np.sqrt(np.mean((actual - predicted)**2))

def create_scatter_plots(all_predictions):
    """
    创建3x3散点子图
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    print("创建散点图...")
    
    # 创建图形
    fig = plt.figure(figsize=(18, 18))
    gs = gridspec.GridSpec(3, 3, figure=fig)
    
    # 遍历所有波段和极化组合
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            ax = fig.add_subplot(gs[i, j])
            
            # 获取当前组合的预测结果
            predictions = all_predictions.get((band, pol), {})
            
            # 收集所有数据点
            all_actual = []
            all_predicted = []
            
            # 绘制每个sheet的数据点
            for sheet in SHEET_NAMES:
                if sheet in predictions:
                    actual = predictions[sheet]['actual']
                    predicted = predictions[sheet]['predicted']
                    
                    # 添加到总集合
                    all_actual.extend(actual)
                    all_predicted.extend(predicted)
                    
                    # 绘制当前sheet的点
                    if sheet in ['CLASIC07', 'SMAPVEX16']:
                        # 对CLASIC07、SMAPVEX16特殊处理：空心
                        ax.scatter(
                            actual, predicted,
                            marker=MARKER_STYLES[sheet]['marker'],
                            facecolor=MARKER_STYLES[sheet]['facecolor'],  # 内部无填充
                            edgecolor=MARKER_STYLES[sheet]['edgecolor'],  # 使用边缘颜色
                            s=50,
                            alpha=0.7,
                            linewidths=1.0,  # 确保边框可见
                            label=sheet
                        )
                    else:
                        # 其他数据集保持原样
                        ax.scatter(
                            actual, predicted,
                            marker=MARKER_STYLES[sheet]['marker'],
                            color=MARKER_STYLES[sheet].get('color', MARKER_STYLES[sheet].get('edgecolor', None)),
                            s=50,
                            alpha=0.7,
                            label=sheet
                        )
            
            # 如果没有数据，跳过
            if not all_actual:
                ax.text(0.5, 0.5, '无数据', 
                        horizontalalignment='center', 
                        verticalalignment='center', 
                        transform=ax.transAxes,
                        fontsize=16)
                ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
                continue
            
            # 计算整体RMSE
            rmse = calculate_rmse(np.array(all_actual), np.array(all_predicted))
            
            # 添加1:1参考线
            max_val = max(max(all_actual), max(all_predicted)) * 1.05
            ax.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7)
            
            # 设置坐标轴范围
            ax.set_xlim(0, max_val)
            ax.set_ylim(0, max_val)
            
            # 设置坐标轴标签
            if i == 2:  # 最后一行
                ax.set_xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
            if j == 0:  # 第一列
                ax.set_ylabel('Predicted VWC (kg/m²)', fontsize=14, fontweight='bold')
            
            # 添加标题和RMSE
            ax.set_title(f"{band}-{pol}", fontsize=16, fontweight='bold')
            ax.text(0.05, 0.95, f"RMSE: {rmse:.3f} kg/m²", 
                    transform=ax.transAxes,
                    fontsize=16,
                    fontweight='bold',
                    verticalalignment='top')
            
            # 添加样本数量
            ax.text(0.05, 0.88, f"n = {len(all_actual)}", 
                    transform=ax.transAxes,
                    fontsize=14,
                    fontweight='bold',
                    verticalalignment='top')
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
    
    # 添加图例
    handles, labels = [], []
    for sheet in SHEET_NAMES:
        style = MARKER_STYLES[sheet]
        
        if sheet in ['CLASIC07', 'SMAPVEX16']:
            # 为CLASIC07、SMAPVEX16创建空心图例
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w',
                                     markerfacecolor=style['facecolor'],  # 内部白色
                                     markeredgecolor=style['edgecolor'],  # 边缘颜色
                                     markersize=10,
                                     markeredgewidth=1.0))  # 边框宽度
        else:
            handles.append(plt.Line2D([0], [0], 
                                     marker=style['marker'], 
                                     color='w', 
                                     markerfacecolor=style.get('color', style.get('edgecolor')),
                                     markeredgecolor=style.get('color', style.get('edgecolor')), 
                                     markersize=10))
        labels.append(sheet)
    
    fig.legend(handles, labels, 
               loc='lower center', 
               ncol=4, 
               fontsize=12,
               frameon=True,
               fancybox=True,
               shadow=True,
               bbox_to_anchor=(0.5, 0.02))
    
    # 调整布局
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    
    # 保存图像 - 使用新路径
    fig_path = "figures/AllSMAPInsituData_PointVWC_Scatter_Primary.png"
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"散点图已保存至: {fig_path}")
    plt.close()

def save_prediction_details(all_predictions):
    """
    将预测结果保存到Excel文件中
    
    参数:
    all_predictions (dict): 包含所有波段和极化组合预测结果的字典
    """
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_file = output_dir / "details_Point_Primary.xlsx"  # 修改输出文件名
    
    # 创建Excel写入器
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        # 遍历所有波段和极化组合
        for (band, pol), predictions in all_predictions.items():
            if not predictions:
                continue
                
            # 创建当前组合的数据框
            all_data = []
            
            # 收集所有sheet的数据
            for sheet, data in predictions.items():
                # 创建当前sheet的数据框
                sheet_df = pd.DataFrame({
                    'Date': data['date'],
                    'Latitude': data['lat'],
                    'Longitude': data['lon'],
                    'Actual_VWC': data['actual'],
                    'Predicted_VWC': data['predicted'],
                    'Source': data['source']
                })
                
                # 添加波段和极化信息
                sheet_df['Band'] = band
                sheet_df['Polarization'] = pol
                
                all_data.append(sheet_df)
            
            # 合并所有数据
            if all_data:
                combined_df = pd.concat(all_data, ignore_index=True)
                
                # 保存到Excel
                sheet_name = f"{band}_{pol}"
                combined_df.to_excel(writer, sheet_name=sheet_name, index=False)
                print(f"保存预测结果到: {sheet_name} ({len(combined_df)}行)")
    
    print(f"所有预测结果已保存至: {output_file}")

def main():
    # 输入文件路径 - 使用新路径
    input_file = r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx"
    
    # 加载并预处理数据
    data_dict = load_and_preprocess_data(input_file)
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 遍历所有波段和极化组合
    for band in BANDS:
        for pol in POLS:
            print(f"\n处理波段-极化组合: {band}-{pol}")
            predictions = predict_vwc(data_dict, band, pol)
            all_predictions[(band, pol)] = predictions
    
    # 创建散点图
    create_scatter_plots(all_predictions)
    
    # 保存预测结果到Excel
    save_prediction_details(all_predictions)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx
  - SMEX02: 16行
    创建列 Grass_man 并初始化为0
    创建列 Grass_nat 并初始化为0
    创建列 Shrub_bd 并初始化为0
    创建列 Shrub_be 并初始化为0
    创建列 Shrub_nd 并初始化为0
    创建列 Shrub_ne 并初始化为0
    创建列 Tree_bd 并初始化为0
    创建列 Tree_be 并初始化为0
    创建列 Tree_nd 并初始化为0
    创建列 Tree_ne 并初始化为0
    设置Grass_man=1, 其他PFT特征=0
    创建列 LAI 并初始化为0
    替换了 16 行LAI_Satellite数据
  - CLASIC07: 18行
    创建列 Grass_man 并初始化为0
    创建列 Grass_nat 并初始化为0
    创建列 Shrub_bd 并初始化为0
    创建列 Shrub_be 并初始化为0
    创建列 Shrub_nd 并初始化为0
    创建列 Shrub_ne 并初始化为0
    创建列 Tree_bd 并初始化为0
    创建列 Tree_be 并初始化为0
    创建列 Tree_nd 并初始化为0
    创建列 Tree_ne 并初始化为0
    设置Grass_man=1, 其他PFT特征=0
    创建列 LAI 并初始化为0
    替换了 18 行LAI_Satellite数据
  - SMAPVEX08: 6行
    创建列 Grass_man 并初始化为0
    创建列 Grass_nat 并初始化为0
    创建列 Shrub_bd 并初始化为0
    创建列 Shrub_be 并初始化为0
    创建列 Shrub_nd 并初始化为0
    创建列 Shrub_ne 并初始化为0
    创建列 Tree_bd 并初始化为0
    创建列 Tree_be 并初始化为0
    创建列 Tree_nd 并初始化为0
    创建列 Tree_ne 并初始化

# 中国那两个数据的验证

## 模型预测值、保存结果，绘制折线图

In [10]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.gridspec as gridspec
import joblib
import os
from pathlib import Path
import warnings
from datetime import datetime
from sklearn.metrics import mean_squared_error, r2_score
from scipy.interpolate import make_interp_spline  # 导入样条插值函数
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义
BANDS = ['Ku', 'X', 'C']
BAND_COLORS = {
    'Ku': '#1f77b4',  # 蓝色
    'X': '#ff7f0e',   # 橙色
    'C': '#2ca02c'    # 绿色
}
POLS = ['H', 'V', 'HV']
POL_LINESTYLES = {
    'H': '-',     # 实线
    'V': '--',    # 虚线
    'HV': ':'     # 点线
}
POL_MARKERS = {
    'H': '+',  # 加号
    'V': '^',  # 三角形
    'HV': 's'  # 正方形
}
POL_LABELS = {
    'H': 'H-Pol',
    'V': 'V-Pol',
    'HV': 'H&V-Pol'
}

# 植被类型映射
VEGETATION_TYPES = {
    'CornVegMeasured': 'Corn (2017)',
    'OatVegMeasured': 'Oat (2017)',
    'GrassVWC': 'Grass (2018)'
}

# 实测与拟合数据映射
FITTING_MAPPING = {
    'CornVegMeasured': 'CornVegFitting',
    'OatVegMeasured': 'OatVegFitting',
    'GrassVWC': 'GrassVWC'  # 2018年没有拟合数据
}

# 实测VWC列名映射
ACTUAL_COL_MAPPING = {
    'CornVegMeasured': 'total_VWC(kg/m2)',
    'OatVegMeasured': 'total_VWC(kg/m2)',
    'GrassVWC': 'vegetation water content(kg/m2)'
}

# 实测数据样式
ACTUAL_STYLE = {
    'color': 'black',
    'marker': 'o',
    'markersize': 8,
    'markerfacecolor': 'none',
    'markeredgewidth': 1.5,
    'label': 'Measured'
}

def load_data(file_path):
    """加载Excel文件中的所有工作表"""
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    # 获取所有工作表名称
    xl = pd.ExcelFile(file_path)
    sheet_names = xl.sheet_names
    
    for sheet in sheet_names:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def predict_vwc_for_sheet(df, band, pol):
    """
    使用机器学习模型预测VWC，确保特征名称匹配
    """
    # 加载模型 - 使用新模型名称（添加Hveg后缀）
    model_path = f"models/RFR_{band}_{pol}-pol_Type1_Primary.pkl"
    if not os.path.exists(model_path):
        print(f"警告: 模型文件不存在: {model_path}")
        return pd.Series(np.nan, index=df.index)
    
    try:
        model = joblib.load(model_path)
        print(f"加载模型: {model_path}")
        
        # 获取模型期望的特征名称
        if hasattr(model, 'feature_names_in_'):
            expected_features = list(model.feature_names_in_)
            print(f"  模型期望特征: {expected_features}")
        else:
            print("  警告: 模型没有feature_names_in_属性")
            expected_features = []
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)
    
    # 1. 优先使用地面实测数据替换卫星数据
    
    # LAI替换
    if 'LAI' in df.columns:
        lai_mask = df['LAI'].notna() & (df['LAI'] > 0)
        if lai_mask.any():
            df.loc[lai_mask, 'LAI_Satellite'] = df.loc[lai_mask, 'LAI']
            print(f"  使用实测LAI替换了 {lai_mask.sum()} 行数据")
    
    # Hveg替换（仅Duolun表格）
    if 'Height(cm)' in df.columns:
        hveg_mask = df['Height(cm)'].notna() & (df['Height(cm)'] > 0)
        if hveg_mask.any():
            # 将高度从cm转换为m（除以100）
            df.loc[hveg_mask, 'Hveg_Satellite'] = df.loc[hveg_mask, 'Height(cm)'] / 100.0
            print(f"  使用实测Hveg替换了 {hveg_mask.sum()} 行数据（高度单位转换为m）")
    
    # 2. 根据波段和极化组合确定特征映射
    feature_mapping = {}
    
    # Ku波段
    if band == 'Ku':
        if pol == 'H':
            feature_mapping = {
                'ku_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'ku_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'ku_vod_H': 'VOD-Hpol',
                'ku_vod_V': 'VOD-Vpol',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # X波段
    elif band == 'X':
        if pol == 'H':
            feature_mapping = {
                'x_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'x_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'x_vod_H': 'VOD-Hpol',
                'x_vod_V': 'VOD-Vpol',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # C波段
    elif band == 'C':
        if pol == 'H':
            feature_mapping = {
                'c_vod_H': 'VOD',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'V':
            feature_mapping = {
                'c_vod_V': 'VOD',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
        elif pol == 'HV':
            feature_mapping = {
                'c_vod_H': 'VOD-Hpol',
                'c_vod_V': 'VOD-Vpol',
                'LAI_Satellite': 'LAI',
                'Hveg_Satellite': 'Hveg',  # 添加Hveg变量
                'grassman': 'Grass_man',
                'grassnat': 'Grass_nat',
                'shrubbd': 'Shrub_bd',
                'shrubbe': 'Shrub_be',
                'shrubnd': 'Shrub_nd',
                'shrubne': 'Shrub_ne',
                'treebd': 'Tree_bd',
                'treebe': 'Tree_be',
                'treend': 'Tree_nd',
                'treene': 'Tree_ne'
            }
    
    # 时间序列插值
    if 'Date' in df.columns and not df.empty:
        # 确保按日期排序
        df = df.sort_values('Date')
        
        # 确定需要插值的特征列
        interpolate_cols = list(feature_mapping.keys())
        valid_cols = [col for col in interpolate_cols if col in df.columns]
        
        # 设置时间索引
        date_index = pd.DatetimeIndex(df['Date'])
        df_temp = df.set_index('Date')
        
        # 生成完整的时间序列范围
        full_range = pd.date_range(start=date_index.min(), end=date_index.max(), freq='D')
        df_full = df_temp.reindex(full_range)
        
        # 对特征列进行线性插值
        for col in valid_cols:
            df_full[col] = df_full[col].interpolate(method='time', limit_direction='both')
            print(f"  已完成{col}的时间序列插值")
        
        # 重置索引
        df = df_full.reset_index().rename(columns={'index': 'Date'})
    else:
        print("  无日期列或数据为空，跳过插值")
    
    # 3. 检查是否所有映射后的特征都存在
    missing_features = []
    for data_col in feature_mapping.keys():
        if data_col not in df.columns:
            missing_features.append(data_col)
    
    if missing_features:
        print(f"  缺少特征: {', '.join(missing_features)}")
        return pd.Series(np.nan, index=df.index)
    
    # 4. 准备特征数据
    X = pd.DataFrame()
    for data_col, model_feature in feature_mapping.items():
        X[model_feature] = df[data_col]
    
    # 5. 应用特征归一化
    # VOD特征归一化（除以2）
    vod_features = ['VOD', 'VOD-Hpol', 'VOD-Vpol']
    for vod_feature in vod_features:
        if vod_feature in X.columns:
            X[vod_feature] = X[vod_feature]
    
    # LAI特征归一化（除以6）
    if 'LAI' in X.columns:
        X['LAI'] = X['LAI']
    
    # Hveg特征归一化（除以40）
    if 'Hveg' in X.columns:
        X['Hveg'] = X['Hveg']
    
    # PFT特征归一化（除以100）
    pft_features = [
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    for pft_feature in pft_features:
        if pft_feature in X.columns:
            X[pft_feature] = X[pft_feature] / 100.0
    
    # 6. 移除缺失值
    initial_count = len(X)
    X = X.dropna()
    removed_count = initial_count - len(X)
    if removed_count > 0:
        print(f"  移除了 {removed_count} 行包含缺失值的数据")
    
    if X.empty:
        print("  无有效数据可用于预测")
        return pd.Series(np.nan, index=df.index)
    
    # 7. 确保特征顺序与模型期望一致
    if hasattr(model, 'feature_names_in_'):
        X = X[expected_features]
    
    # 8. 预测VWC
    try:
        y_pred = model.predict(X)
        
        # 创建完整长度的预测序列
        full_pred = pd.Series(np.nan, index=df.index)
        full_pred.loc[X.index] = y_pred
        
        return full_pred
    except Exception as e:
        print(f"  预测失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)

def create_combined_plots(data_dict_2017, data_dict_2018):
    """创建组合时间序列图并保存预测结果"""
    print("创建组合时间序列图...")
    
    # 创建输出目录
    output_dir = Path("prediction_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # 创建图形 - 增加底部空间用于多行图例
    fig = plt.figure(figsize=(15, 18))
    gs = gridspec.GridSpec(4, 1, figure=fig, height_ratios=[1, 1, 1, 0.4], hspace=0.3)
    
    # 设置全局标题
    fig.suptitle('Vegetation Water Content Time Series (Hveg Added)', fontsize=20, fontweight='bold', y=0.95)
    
    # 植被类型列表
    vegetation_types = [
        ('CornVegMeasured', data_dict_2017),  # 玉米
        ('OatVegMeasured', data_dict_2017),   # 燕麦
        ('GrassVWC', data_dict_2018)           # 草
    ]
    
    # 实测VWC列名映射
    ACTUAL_COL_MAPPING = {
        'CornVegMeasured': 'total_VWC(kg/m2)',
        'OatVegMeasured': 'total_VWC(kg/m2)',
        'GrassVWC': 'vegetation water content(kg/m2)'
    }
    
    # 存储所有评估指标
    all_metrics = {}
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 根据图片定义极化标记样式
    POL_MARKERS = {
        'H': '+',   # 加号
        'V': '^',   # 三角形
        'HV': 's'   # 正方形
    }
    
    # 波段颜色
    BAND_COLORS = {
        'Ku': 'blue',
        'X': 'green',
        'C': 'red'
    }
    
    # 极化线型
    POL_LINESTYLES = {
        'H': '-',
        'V': '--',
        'HV': '-.'
    }
    
    # 极化名称映射
    POL_NAMES = {
        'H': 'H-Pol',
        'V': 'V-Pol',
        'HV': 'H&V-Pol'
    }
    
    # 波段名称映射
    BAND_NAMES = {
        'Ku': 'Ku-Band',
        'X': 'X-Band',
        'C': 'C-Band'
    }
    
    # 植被类型显示名称
    VEGETATION_TYPES = {
        'CornVegMeasured': 'Corn',
        'OatVegMeasured': 'Oat',
        'GrassVWC': 'Grass'
    }
    
    # 遍历所有植被类型
    for idx, (veg_type, data_dict) in enumerate(vegetation_types):
        ax = fig.add_subplot(gs[idx])
        
        # 获取当前植被类型的实测列名
        actual_col = ACTUAL_COL_MAPPING[veg_type]
        
        # 初始化Y轴范围
        y_min = float('inf')
        y_max = float('-inf')
        
        # 获取实测数据
        if veg_type in data_dict:
            df_measured = data_dict[veg_type].copy()
            
            # 确保日期列存在
            if 'Date' not in df_measured.columns:
                print(f"警告: {veg_type} 中没有 'Date' 列")
                continue
            
            # 按日期排序
            df_measured = df_measured.sort_values('Date')
            
            # 更新Y轴范围（实测值）
            if actual_col in df_measured.columns:
                measured_values = df_measured[actual_col].dropna()
                if not measured_values.empty:
                    y_min = min(y_min, measured_values.min())
                    y_max = max(y_max, measured_values.max())
            
            # 获取拟合数据用于预测
            fitting_sheet = FITTING_MAPPING.get(veg_type, veg_type)
            if fitting_sheet in data_dict:
                df_fitting = data_dict[fitting_sheet].copy()
                
                # 确保日期列存在
                if 'Date' not in df_fitting.columns:
                    print(f"警告: {fitting_sheet} 中没有 'Date' 列")
                    continue
                
                # 按日期排序
                df_fitting = df_fitting.sort_values('Date')
            else:
                # 2018年没有单独的拟合数据
                df_fitting = df_measured.copy()
            
            # 存储评估指标
            metrics = []
            
            # 为每个波段和极化组合预测VWC
            for band in BAND_COLORS.keys():
                for pol in POL_LINESTYLES.keys():
                    # 生成列名
                    col_name = f"Predicted_VWC_{band}_{pol}"
                    
                    # 如果列不存在，使用模型预测
                    if col_name not in df_fitting.columns:
                        print(f"为 {fitting_sheet} 预测 {band}-{pol} VWC...")
                        df_fitting[col_name] = predict_vwc_for_sheet(df_fitting, band, pol)
                    
                    # 只在有有效预测值的点进行绘制和评估
                    if col_name in df_fitting.columns:
                        # 更新Y轴范围（预测值）
                        pred_values = df_fitting[col_name].dropna()
                        if not pred_values.empty:
                            y_min = min(y_min, pred_values.min())
                            y_max = max(y_max, pred_values.max())
                        
                        # 获取有效预测数据点
                        valid_mask = df_fitting[col_name].notna()
                        valid_dates = df_fitting['Date'][valid_mask]
                        valid_values = df_fitting[col_name][valid_mask]
                        
                        # 如果数据点足够多，使用样条插值生成平滑曲线
                        if len(valid_dates) > 3:
                            try:
                                # 将日期转换为数值（从最小日期开始的天数）
                                date_numeric = (valid_dates - valid_dates.min()).dt.days
                                
                                # 创建样条插值对象
                                spline = make_interp_spline(date_numeric, valid_values, k=3)
                                
                                # 生成更密集的时间点
                                dense_dates = np.linspace(date_numeric.min(), date_numeric.max(), 300)
                                dense_values = spline(dense_dates)
                                
                                # 将数值日期转换回实际日期
                                dense_dates = valid_dates.min() + pd.to_timedelta(dense_dates, unit='D')
                                
                                # 绘制平滑曲线
                                ax.plot(dense_dates, dense_values,
                                        color=BAND_COLORS[band],
                                        linestyle=POL_LINESTYLES[pol],
                                        linewidth=1.5)
                            except Exception as e:
                                print(f"样条插值失败: {str(e)}")
                                # 如果插值失败，使用原始数据点绘制折线
                                ax.plot(valid_dates, valid_values,
                                        color=BAND_COLORS[band],
                                        linestyle=POL_LINESTYLES[pol],
                                        linewidth=1.5)
                        else:
                            # 数据点太少，直接绘制折线
                            ax.plot(valid_dates, valid_values,
                                    color=BAND_COLORS[band],
                                    linestyle=POL_LINESTYLES[pol],
                                    linewidth=1.5)
                        
                        # 找出同时有实测值和预测值的点
                        common_data = pd.merge(
                            df_measured[['Date', actual_col]], 
                            df_fitting[['Date', col_name]], 
                            on='Date', 
                            how='inner'
                        ).dropna(subset=[actual_col, col_name])
                        
                        if not common_data.empty:
                            # 更新Y轴范围（共同数据）
                            common_min = min(common_data[actual_col].min(), common_data[col_name].min())
                            common_max = max(common_data[actual_col].max(), common_data[col_name].max())
                            y_min = min(y_min, common_min)
                            y_max = max(y_max, common_max)
                            
                            # 在实测日期位置绘制实测值点（空心圆）
                            ax.plot(common_data['Date'], common_data[actual_col],
                                    linestyle='',  # 无线条
                                    color='black',
                                    marker='o',
                                    markersize=8,
                                    markerfacecolor='none',  # 透明填充（空心）
                                    markeredgewidth=1.5)
                            
                            # 在实测日期位置绘制预测点（空心标记）
                            ax.plot(common_data['Date'], common_data[col_name],
                                    linestyle='',  # 无线条
                                    color=BAND_COLORS[band],
                                    marker=POL_MARKERS[pol],
                                    markersize=10,
                                    markerfacecolor='none',  # 透明填充（空心）
                                    markeredgewidth=1.5)
                            
                            # 计算评估指标
                            rmse = np.sqrt(mean_squared_error(common_data[actual_col], common_data[col_name]))
                            r2 = r2_score(common_data[actual_col], common_data[col_name])
                            
                            # 添加到指标列表
                            metrics.append({
                                'band': band,
                                'pol': pol,
                                'rmse': rmse,
                                'r2': r2
                            })
                            
                            # 保存预测结果
                            model_key = f"{veg_type}_{band}_{pol}"
                            all_predictions[model_key] = {
                                'dates': common_data['Date'].tolist(),
                                'measured': common_data[actual_col].tolist(),
                                'predicted': common_data[col_name].tolist(),
                                'rmse': rmse,
                                'r2': r2
                            }
            
            # 设置子图标题
            ax.set_title(VEGETATION_TYPES.get(veg_type, veg_type), 
                         fontsize=16, fontweight='bold')
            
            # 设置坐标轴标签
            if idx == 2:  # 最后一行
                ax.set_xlabel('Date', fontsize=12, fontweight='bold')
            ax.set_ylabel('VWC (kg/m²)', fontsize=12, fontweight='bold')
            
            # 设置X轴格式
            ax.xaxis.set_major_locator(mdates.DayLocator(interval=10))
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
            
            # 动态设置Y轴范围
            if y_min != float('inf') and y_max != float('-inf'):
                # 添加10%的边距
                y_range = y_max - y_min
                padding = y_range * 0.1
                
                # 确保最小值不小于0
                y_min = max(0, y_min - padding)
                y_max = y_max + padding
                
                ax.set_ylim(y_min, y_max)
            else:
                # 默认范围
                ax.set_ylim(0, 10)
            
            # 存储指标
            all_metrics[veg_type] = metrics
    
    # ==================================
    # 创建图例（精确匹配要求）
    # ==================================
    
    # 创建图例区域的轴
    ax_legend = fig.add_subplot(gs[3])
    ax_legend.axis('off')  # 隐藏坐标轴
    
    # 定义图例行内容（标题+项目）
    legend_rows = [
        # 第一行：Ku波段
        [f"{BAND_NAMES['Ku']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],
        
        # 第二行：X波段
        [f"{BAND_NAMES['X']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],
        
        # 第三行：C波段
        [f"{BAND_NAMES['C']},{POL_NAMES[pol]}" for pol in ['H', 'V', 'HV']],

        # 第四行：实测点
        [f"Insitu VWC"]
    ]
    
    # 创建代理艺术家
    proxies = {}
    
    # Insitu VWC代理（空心圆）
    proxies['insitu'] = plt.Line2D([], [], 
                     linestyle='', 
                     marker='o',
                     markersize=10,
                     markerfacecolor='none',
                     markeredgecolor='black',
                     markeredgewidth=1.5,
                     label='Insitu VWC')
    
    # 波段-极化组合代理
    for band in ['Ku', 'X', 'C']:
        color = BAND_COLORS[band]
        for pol in ['H', 'V', 'HV']:
            proxies[f"{band}-{pol}"] = plt.Line2D([], [],
                color=color,
                linestyle=POL_LINESTYLES[pol],
                linewidth=2,
                marker=POL_MARKERS[pol],
                markersize=10,
                markerfacecolor='none',
                markeredgecolor=color,
                markeredgewidth=1.5,
                label=f"{BAND_NAMES[band]},{POL_NAMES[pol]}")
    
    # 为每行创建图例
    y_positions = [0.85, 0.60, 0.35, 0.10]  # 三行垂直位置
    
    for row_idx, row_items in enumerate(legend_rows):
        handles = []
        labels = []
        
        for item in row_items:
            # 处理Insitu项
            if item == "Insitu VWC":
                handles.append(proxies['insitu'])
                labels.append(item)
            # 处理波段-极化项
            else:
                # 解析新的标签格式
                band_part, pol_part = item.split(',')
                band = band_part.split('-')[0]  # 提取波段名称
                
                handles.append(proxies[f"{band}-{pol}"])
                labels.append(item)  # 使用完整的标签文本
        
        # 计算当前行文本宽度（均匀分布）
        n_items = len(handles)
        x_positions = np.linspace(0.05, 0.95, n_items)
        
        # 绘制当前行的图例项
        for i, (handle, label) in enumerate(zip(handles, labels)):
            ax_legend.plot([], [])  # 空白绘图以创建图例项
            
            # 创建图例句柄
            leg = ax_legend.legend([handle], [label], 
                                  loc='lower center',
                                  bbox_to_anchor=(x_positions[i], y_positions[row_idx]),
                                  frameon=False,
                                  handlelength=2,
                                  fontsize=10,
                                  handletextpad=0.8)
            
            # 添加到轴（否则会被覆盖）
            ax_legend.add_artist(leg)
    
    # 在子图中显示评估指标
    for idx, (veg_type, metrics) in enumerate(all_metrics.items()):
        if idx < 3:  # 确保索引有效（排除图例轴）
            ax = fig.axes[idx]
            
            # 创建指标文本
            if metrics:
                # 使用多列格式显示所有指标
                metric_text = "Evaluation Metrics:\n"
                
                # 按波段分组指标
                band_metrics = {}
                for metric in metrics:
                    band = metric['band']
                    if band not in band_metrics:
                        band_metrics[band] = []
                    band_metrics[band].append(metric)
                
                # 为每个波段创建一行文本
                for band in ['Ku', 'X', 'C']:
                    if band in band_metrics:
                        band_text = f"{BAND_NAMES[band]}: "
                        pol_texts = []
                        for metric in band_metrics[band]:
                            pol_texts.append(f"{POL_NAMES[metric['pol']]}(RMSE={metric['rmse']:.3f})")
                        band_text += ", ".join(pol_texts)
                        metric_text += band_text + "\n"
                
                # 添加文本框
                ax.text(0.02, 0.95, metric_text, 
                        transform=ax.transAxes,
                        fontsize=9,
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # 调整布局
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    
    # 保存图像
    figures_dir = Path("figures")
    figures_dir.mkdir(parents=True, exist_ok=True)
    fig_path = figures_dir / "Combined_VWC_Time_Series_Primary.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"组合时间序列图已保存至: {fig_path}")
    plt.close()
    
    # 保存所有预测结果到CSV文件
    for model_key, data in all_predictions.items():
        veg_type, band, pol = model_key.split('_')
        df = pd.DataFrame({
            'Date': data['dates'],
            'Measured': data['measured'],
            'Predicted': data['predicted']
        })
        csv_path = output_dir / f"{veg_type}_{band}_{pol}_predictions_Primary.csv"
        df.to_csv(csv_path, index=False)
        print(f"保存预测结果至: {csv_path}")
    
    # 保存评估指标
    metrics_path = output_dir / "model_metrics_Primary.csv"
    metrics_data = []
    for veg_type, metrics in all_metrics.items():
        for metric in metrics:
            metrics_data.append({
                'Vegetation': veg_type,
                'Band': metric['band'],
                'Polarization': metric['pol'],
                'RMSE': metric['rmse'],
                'R2': metric['r2']
            })
    
    metrics_df = pd.DataFrame(metrics_data)
    metrics_df.to_csv(metrics_path, index=False)
    print(f"保存模型评估指标至: {metrics_path}")

def main():
    # 2017年数据文件
    file_2017 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx"
    data_2017 = load_data(file_2017)
    
    # 2018年数据文件
    file_2018 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx"
    data_2018 = load_data(file_2018)
    
    # 创建组合时间序列图
    create_combined_plots(data_2017, data_2018)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML.xlsx
  - CornVegMeasured: 8行
  - CornVegFitting: 64行
  - OatVegMeasured: 7行
  - OatVegFitting: 64行
加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML.xlsx
  - GrassVWC: 13行
创建组合时间序列图...
为 CornVegFitting 预测 Ku-H VWC...
加载模型: models/RFR_Ku_H-pol_Type1_Primary.pkl
  模型期望特征: ['VOD', 'LAI', 'Hveg', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  使用实测LAI替换了 63 行数据
  使用实测Hveg替换了 63 行数据（高度单位转换为m）
  已完成ku_vod_H的时间序列插值
  已完成LAI_Satellite的时间序列插值
  已完成Hveg_Satellite的时间序列插值
  已完成grassman的时间序列插值
  已完成grassnat的时间序列插值
  已完成shrubbd的时间序列插值
  已完成shrubbe的时间序列插值
  已完成shrubnd的时间序列插值
  已完成shrubne的时间序列插值
  已完成treebd的时间序列插值
  已完成treebe的时间序列插值
  已完成treend的时间序列插值
  已完成treene的时间序列插值
为 CornVegFitting 预测 Ku-V VWC...
加载模型: models/RFR_Ku_V-pol_Type1_Primary.pkl
  模型期望特征: ['VOD', 'LAI', 'Hveg', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub

## 散点图

In [11]:
# 3 * 3 散点图结果（添加Hveg变量）
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from pathlib import Path
import warnings
from sklearn.metrics import mean_squared_error
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义
BANDS = ['Ku', 'X', 'C']
BAND_COLORS = {
    'Ku': '#1f77b4',  # 蓝色
    'X': '#ff7f0e',   # 橙色
    'C': '#2ca02c'    # 绿色
}
POLS = ['H', 'V', 'HV']

# 植被类型标记样式 - 玉米标记改为空心方形（'s'）
VEG_MARKERS = {
    'CornVegMeasured': {'marker': 's', 'size': 80, 'label': 'Corn (2017)'},  # 改为方形
    'OatVegMeasured': {'marker': '^', 'size': 80, 'label': 'Oat (2017)'},
    'GrassVWC': {'marker': 'o', 'size': 80, 'label': 'Grass (2018)'}
}

def load_prediction_data(prediction_dir):
    """从CSV文件加载预测结果（添加Hveg后缀）"""
    print(f"加载预测结果: {prediction_dir}")
    all_predictions = {}
    
    # 遍历所有CSV文件（匹配新文件名）
    for csv_file in prediction_dir.glob("*_predictions_Primary.csv"):
        # 解析文件名获取模型信息
        filename = csv_file.stem
        parts = filename.split('_')
        
        if len(parts) >= 4:  # 格式: {植被类型}_{波段}_{极化}_predictions_Primary
            veg_type = parts[0]
            band = parts[1]
            pol = parts[2]
            model_key = f"{band}_{pol}"
            
            # 加载数据
            df = pd.read_csv(csv_file)
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            # 存储数据
            if model_key not in all_predictions:
                all_predictions[model_key] = {}
            
            all_predictions[model_key][veg_type] = df
    
    return all_predictions

def get_model_title(band, pol):
    """根据波段和极化返回自定义标题"""
    band_names = {
        'Ku': 'Ku-Band',
        'X': 'X-Band',
        'C': 'C-Band'
    }
    pol_names = {
        'H': 'H-Pol',
        'V': 'V-Pol',
        'HV': 'H&V-Pol'  # 修改这里
    }
    return f"{band_names.get(band, band)},{pol_names.get(pol, pol)}"

def create_scatter_plots_from_predictions(prediction_dir):
    """从预测结果文件创建9个模型的真值与预测值散点图（3x3网格）"""
    # 加载预测结果
    all_predictions = load_prediction_data(prediction_dir)
    
    if not all_predictions:
        print("警告: 没有找到预测结果文件")
        return
    
    # 创建3x3网格图
    fig = plt.figure(figsize=(15, 15))
    # fig.suptitle('VWC预测散点图（添加Hveg变量）', fontsize=20, y=0.95)  # 添加标题
    gs = gridspec.GridSpec(3, 3, wspace=0.3, hspace=0.3)
    
    # 收集所有散点的最小和最大值（用于统一坐标轴）
    all_actual_min, all_actual_max = np.inf, -np.inf
    all_pred_min, all_pred_max = np.inf, -np.inf
    
    # 收集所有评估指标
    all_metrics = {}

    # 处理每个模型（波段和极化组合）
    for i, band in enumerate(BANDS):
        for j, pol in enumerate(POLS):
            model_key = f"{band}_{pol}"
            ax = plt.subplot(gs[i, j])
            print(f"处理模型: {model_key}")
            
            # 检查该模型是否有预测数据
            if model_key not in all_predictions:
                print(f"警告: {model_key} 模型没有预测数据")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 收集该模型的所有植被类型的数据
            all_actual = []
            all_predicted = []
            all_veg_types = []
            
            # 存储各植被类型的数据点
            veg_data = {
                'CornVegMeasured': {'actual': [], 'predicted': []},
                'OatVegMeasured': {'actual': [], 'predicted': []},
                'GrassVWC': {'actual': [], 'predicted': []}
            }
            
            # 处理玉米数据
            veg_type = 'CornVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 玉米数据点: {len(df)}")
            
            # 处理燕麦数据
            veg_type = 'OatVegMeasured'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 燕麦数据点: {len(df)}")
            
            # 处理草数据
            veg_type = 'GrassVWC'
            if veg_type in all_predictions[model_key]:
                df = all_predictions[model_key][veg_type]
                if 'Measured' in df.columns and 'Predicted' in df.columns:
                    # 添加数据点
                    veg_data[veg_type]['actual'] = df['Measured'].tolist()
                    veg_data[veg_type]['predicted'] = df['Predicted'].tolist()
                    
                    # 添加到总数据
                    all_actual.extend(df['Measured'])
                    all_predicted.extend(df['Predicted'])
                    all_veg_types.extend([veg_type] * len(df))
                    print(f"  - 草数据点: {len(df)}")
            
            # 如果没有数据点，跳过
            if len(all_actual) == 0:
                print(f"警告: {model_key} 模型没有有效数据点")
                ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                        verticalalignment='center', transform=ax.transAxes,
                        fontsize=14, color='red')
                ax.set_title(get_model_title(band, pol), fontsize=14)
                continue
                
            # 转换为numpy数组
            all_actual = np.array(all_actual)
            all_predicted = np.array(all_predicted)
            
            # 更新全局最小/最大值
            all_actual_min = min(all_actual_min, np.min(all_actual))
            all_actual_max = max(all_actual_max, np.max(all_actual))
            all_pred_min = min(all_pred_min, np.min(all_predicted))
            all_pred_max = max(all_pred_max, np.max(all_predicted))
            
            # 计算各植被类型的RMSE
            rmse_corn = None
            rmse_oat = None
            rmse_grass = None
            
            if veg_data['CornVegMeasured']['actual']:
                actual_corn = np.array(veg_data['CornVegMeasured']['actual'])
                predicted_corn = np.array(veg_data['CornVegMeasured']['predicted'])
                rmse_corn = np.sqrt(mean_squared_error(actual_corn, predicted_corn))
            
            if veg_data['OatVegMeasured']['actual']:
                actual_oat = np.array(veg_data['OatVegMeasured']['actual'])
                predicted_oat = np.array(veg_data['OatVegMeasured']['predicted'])
                rmse_oat = np.sqrt(mean_squared_error(actual_oat, predicted_oat))
            
            if veg_data['GrassVWC']['actual']:
                actual_grass = np.array(veg_data['GrassVWC']['actual'])
                predicted_grass = np.array(veg_data['GrassVWC']['predicted'])
                rmse_grass = np.sqrt(mean_squared_error(actual_grass, predicted_grass))
            
            # 计算整体RMSE
            rmse_total = np.sqrt(mean_squared_error(all_actual, all_predicted))
            
            # 存储评估指标
            all_metrics[model_key] = {
                'RMSE_Corn': rmse_corn,
                'RMSE_Oat': rmse_oat,
                'RMSE_Grass': rmse_grass,
                'RMSE_Total': rmse_total
            }
            
            # 绘制散点图 - 按植被类型区分标记
            # 先绘制草和燕麦，最后绘制玉米（确保玉米在最上层）
            for veg_type in ['GrassVWC', 'OatVegMeasured', 'CornVegMeasured']:
                if veg_data[veg_type]['actual']:
                    actual_values = np.array(veg_data[veg_type]['actual'])
                    predicted_values = np.array(veg_data[veg_type]['predicted'])
                    
                    marker_style = VEG_MARKERS[veg_type]
                    
                    # 为玉米标记使用更大的尺寸和线宽
                    if veg_type == 'CornVegMeasured':
                        size = 100  # 增加大小
                        edgewidth = 1.5  # 更粗的线宽
                        alpha = 0.9  # 更高的不透明度
                    else:
                        size = marker_style['size']
                        edgewidth = 1.0
                        alpha = 0.8
                    
                    # 所有标记使用相同的波段颜色
                    ax.scatter(actual_values, predicted_values, 
                              marker=marker_style['marker'], 
                              s=size,
                              alpha=alpha,  # 调整透明度
                              facecolor='none', 
                              edgecolor=BAND_COLORS[band],  # 使用波段颜色
                              linewidths=edgewidth,
                              label=marker_style['label'])
            
            # 添加1:1参考线
            ax.plot([0, 4], [0, 4], 'k--', linewidth=1, label='1:1 Line')
            
            # 设置标题和坐标轴标签
            ax.set_title(get_model_title(band, pol), fontsize=14)
            if j == 0:  # 第一列添加y轴标签
                ax.set_ylabel('RF VWC (kg/m²)', fontsize=12)
            if i == 2:  # 最后一行添加x轴标签
                ax.set_xlabel('In Situ VWC (kg/m²)', fontsize=12)
            
            # 添加网格
            ax.grid(True, linestyle='--', alpha=0.3)
            
            # 显示评估指标
            metric_text = ""
            if rmse_corn is not None:
                metric_text += f"Corn RMSE = {rmse_corn:.3f}\n"
            if rmse_oat is not None:
                metric_text += f"Oat RMSE = {rmse_oat:.3f}\n"
            if rmse_grass is not None:
                metric_text += f"Grass RMSE = {rmse_grass:.3f}\n"
            metric_text += f"Total RMSE = {rmse_total:.3f}"
            
            ax.text(0.05, 0.95, metric_text, transform=ax.transAxes, 
                   fontsize=9, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 设置所有子图的坐标轴范围一致
    max_val = 4
    min_val = 0
    for ax in fig.get_axes():
        ax.set_xlim(min_val, max_val)
        ax.set_ylim(min_val, max_val)
    
    # 添加图例
    # 创建代理艺术家用于图例
    handles = []
    labels = []
    
    # 添加植被类型标记
    for veg_type, style in VEG_MARKERS.items():
        # 为玉米标记使用特殊大小
        if veg_type == 'CornVegMeasured':
            markersize = 10  # 图例中保持相同大小
        else:
            markersize = 8
            
        handles.append(
            plt.Line2D([], [], marker=style['marker'], linestyle='None', 
                       markersize=markersize, alpha=0.7, markerfacecolor='none', 
                       markeredgecolor='gray', label=style['label'])
        )
    
    # 添加1:1参考线
    handles.append(
        plt.Line2D([], [], color='k', linestyle='--', linewidth=1, label='1:1 Line')
    )
    
    plt.tight_layout(rect=[0, 0.01, 1, 0.95])  # 调整底部空间
 
    # 添加图例到整个图形
    fig.legend(handles=handles, loc='lower center', 
               bbox_to_anchor=(0.5, 0.05), ncol=4, fontsize=10, 
               title="")
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = output_dir / "Scatter_Predictions_From_Saved_Data_Primary.png"  # 添加后缀
    plt.savefig(fig_path, dpi=1000, bbox_inches='tight', pad_inches=0.1)
    print(f"散点图已保存至: {fig_path}")
    plt.close()
    
    # 打印所有模型的评估指标
    print("\n模型评估指标:")
    for model_name, metrics in all_metrics.items():
        print(f"{model_name}:")
        if metrics['RMSE_Corn'] is not None:
            print(f"  Corn RMSE = {metrics['RMSE_Corn']:.4f}")
        if metrics['RMSE_Oat'] is not None:
            print(f"  Oat RMSE = {metrics['RMSE_Oat']:.4f}")
        if metrics['RMSE_Grass'] is not None:
            print(f"  Grass RMSE = {metrics['RMSE_Grass']:.4f}")
        print(f"  Total RMSE = {metrics['RMSE_Total']:.4f}")

def main():
    # 设置预测结果目录
    prediction_dir = Path("prediction_results")
    
    # 创建散点图
    create_scatter_plots_from_predictions(prediction_dir)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载预测结果: prediction_results
处理模型: Ku_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: Ku_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: Ku_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: X_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_H
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_V
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
处理模型: C_HV
  - 玉米数据点: 8
  - 燕麦数据点: 7
  - 草数据点: 13
散点图已保存至: figures\Scatter_Predictions_From_Saved_Data_Primary.png

模型评估指标:
Ku_H:
  Corn RMSE = 0.5561
  Oat RMSE = 0.3788
  Grass RMSE = 0.2649
  Total RMSE = 0.3960
Ku_V:
  Corn RMSE = 1.2323
  Oat RMSE = 0.4925
  Grass RMSE = 0.3885
  Total RMSE = 0.7514
Ku_HV:
  Corn RMSE = 1.0768
  Oat RMSE = 0.4933
  Grass RMSE = 0.4472
  Total RMSE = 0.6964
X_H:
  Corn RMSE = 0.6267
  Oat RMSE = 0.3992
  Grass RMSE = 0.3589
  Total RMSE = 0.4602
X_V:
  Corn RMSE = 0.7077
  Oat RMSE = 0.6552
  Grass RMSE = 0.2363
  Total RMSE = 0.525