# 模型生成图像

## 单日数据生成及掩膜（这里把主程序的overwrite改为true即可，以方便直接复写）

In [1]:
# 单日数据的生成（修复VOD变量名称问题）
import pandas as pd
import numpy as np
import time
import random
import os
import joblib
from datetime import datetime, timedelta
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import h5py
from osgeo import gdal, osr
import logging
from tqdm import tqdm
import concurrent.futures
import warnings
import sys

# 忽略警告
warnings.filterwarnings('ignore')

# 设置随机种子保证可重复性
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# 全局文件路径常量
DATA_FILE_PATH = r"E:\Matlab\EX2025\AuxiliaryData\LFMC-gridMean-ML.xlsx"
VOD_BASE_PATH = r"E:\data\VOD\mat\kuxcVOD\ASC"
LAI_BASE_PATH = r"E:\data\GLASS LAI\mat\0.1Deg\Dataset"
PFT_BASE_PATH = r"E:\data\ESACCI PFT\Resample\Data"
OUTPUT_PATH = r"E:\data\VWC\VWCMap\Daily"

# 确保输出目录存在
os.makedirs(OUTPUT_PATH, exist_ok=True)

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_generation.log')
    ]
)
logger = logging.getLogger()

# 加载模型
MODEL_PATH = os.path.join(os.getcwd(), 'models', 'RFR_6VOD_LAI_PFTs.pkl')
try:
    if os.path.exists(MODEL_PATH):
        MODEL = joblib.load(MODEL_PATH)
        logger.info(f"成功加载新模型: {MODEL_PATH}")
        logger.info(f"模型特征数: {MODEL.n_features_in_}")
        
        # 记录模型期望的特征名称
        if hasattr(MODEL, 'feature_names_in_'):
            logger.info(f"模型期望的特征顺序: {', '.join(MODEL.feature_names_in_)}")
    else:
        logger.error(f"模型文件不存在: {MODEL_PATH}")
        sys.exit(1)
except Exception as e:
    logger.error(f"加载模型失败: {str(e)}")
    sys.exit(1)

# 缓存机制
LAI_MONTH_CACHE = {}  # 缓存LAI数据（按年月）
PFT_YEAR_CACHE = {}    # 缓存PFT数据（按年）

# ============================== 辅助函数 ==============================
def create_singleband_geotiff(data, output_path, nodata=-9999.0):
    """创建单波段地理参考的TIFF文件"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols = data.shape
        
        # 创建数据集
        out_ds = driver.Create(
            output_path, 
            cols, 
            rows, 
            1, 
            gdal.GDT_Float32,
            options=['COMPRESS=LZW', 'BIGTIFF=YES']
        )
        
        # 设置地理变换
        geotransform = (-180.0, 0.1, 0.0, 90.0, 0.0, -0.1)
        out_ds.SetGeoTransform(geotransform)
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 写入波段数据
        band = out_ds.GetRasterBand(1)
        band.WriteArray(data)
        band.SetNoDataValue(nodata)
        band.SetDescription('VWC')
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        
        logger.info(f"成功创建GeoTIFF: {output_path}")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def get_vod_file(date):
    """获取VOD文件路径"""
    date_str = date.strftime('%Y%m%d')
    
    # 查找VOD文件
    possible_files = [
        f'MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat',
        f'MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.mat',
        f'MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V1.nc4.mat',
        f'MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(VOD_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到VOD文件: {date_str}")
    return None

def get_lai_file(year, month):
    """获取LAI文件路径"""
    month_str = str(month).zfill(2)
    
    # 尝试文件名格式
    possible_files = [
        f'{year}-{month_str}-01.tif.mat',
        f'{year}-{month_str}-01.mat',
        f'LAI_{year}{month_str}.mat',
        f'{year}_{month_str}_LAI.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(LAI_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到LAI文件: {year}-{month_str}")
    return None

def get_pft_file(year):
    """获取PFT文件路径"""
    # 尝试不同文件名
    possible_files = [
        f'{year}.mat',
        f'PFT_{year}.mat',
        f'ESACCI_PFT_{year}.mat',
        f'pft_{year}.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(PFT_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到PFT文件: {year}")
    return None

def get_month_centers(date):
    """获取用于插值的前后月中日期"""
    current_mid = date.replace(day=15)
    
    if date.day <= 15:
        prev_mid = (current_mid - timedelta(days=30)).replace(day=15)
        next_mid = current_mid
    else:
        prev_mid = current_mid
        next_mid = (current_mid + timedelta(days=30)).replace(day=15)
    return prev_mid, next_mid

def load_lai_matrix(file_path):
    """加载LAI矩阵数据"""
    if not file_path or not os.path.exists(file_path): 
        return None
    
    try:
        with h5py.File(file_path, 'r') as f:
            # 尝试不同数据集名称
            for key in ['lai', 'Layer', 'data']:
                if key in f:
                    lai_data = np.array(f[key][:])
                    break
            else:
                # 使用第一个数据集
                keys = list(f.keys())
                if keys:
                    lai_data = np.array(f[keys[0]][:])
                    logger.warning(f"使用默认数据集 '{keys[0]}' 作为LAI数据: {file_path}")
                else:
                    logger.warning(f"未找到数据集: {file_path}")
                    return None
            
            # 转置为(1800,3600)
            if lai_data.shape == (3600, 1800):
                lai_data = lai_data.T
            elif lai_data.shape != (1800, 3600):
                logger.error(f"不支持的LAI数据形状: {lai_data.shape}")
                return None
            
            return lai_data
    except Exception as e:
        logger.error(f"加载LAI文件失败: {file_path} - {str(e)}")
        return None

def load_pft_matrix(file_path):
    """加载并旋转PFT矩阵数据"""
    if not file_path or not os.path.exists(file_path): 
        return None
    
    try:
        pft_data = {}
        with h5py.File(file_path, 'r') as f:
            # 定义植被类型映射
            pft_mapping = {
                'water': ['water', 'WATER'],
                'bare': ['bare', 'bareland', 'BARE'],
                'snowice': ['snowice', 'snow', 'ice', 'SNOWICE'],
                'built': ['built', 'urban', 'BUILT'],
                'grassnat': ['grassnat', 'grass_natural', 'GRASSNAT'],
                'grassman': ['grassman', 'grass_managed', 'GRASSMAN'],
                'shrubbd': ['shrubbd', 'shrub_bd', 'SHRUBBD'],
                'shrubbe': ['shrubbe', 'shrub_be', 'SHRUBBE'],
                'shrubnd': ['shrubnd', 'shrub_nd', 'SHRUBND'],
                'shrubne': ['shrubne', 'shrub_ne', 'SHRUBNE'],
                'treebd': ['treebd', 'tree_bd', 'TREEBD'],
                'treebe': ['treebe', 'tree_be', 'TREEBE'],
                'treend': ['treend', 'tree_nd', 'TREEND'],
                'treene': ['treene', 'tree_ne', 'TREENE']
            }
            
            # 查找匹配的数据集
            matched_datasets = {}
            for target, aliases in pft_mapping.items():
                for alias in aliases:
                    if alias in f:
                        data = np.array(f[alias][:])
                        # 旋转和转置
                        if data.shape == (3600, 7200):
                            # 从0.05度到0.1度的降采样
                            data = np.rot90(data, k=-1)
                            # 平均聚合到0.1度
                            data = (data[:, ::2] + data[:, 1::2]) / 2.0
                            pft_data[target] = data.T
                        elif data.shape == (3600, 1800):
                            pft_data[target] = np.rot90(data, k=-1)
                        elif data.shape == (1800, 3600):
                            pft_data[target] = data
                        else:
                            logger.warning(f"未知PFT形状: {data.shape} for {target}")
                            pft_data[target] = np.zeros((1800, 3600))
                        matched_datasets[alias] = True
                        break
                else:
                    logger.warning(f"未找到 {target} 的PFT数据")
                    pft_data[target] = np.zeros((1800, 3600))
                
        return pft_data
    except Exception as e:
        logger.error(f"加载PFT文件失败: {file_path} - {str(e)}")
        return None

def create_land_mask(rows=1800, cols=3600):
    """创建简单的陆地掩膜"""
    # 创建纬度数组
    lats = np.linspace(90, -90, rows)
    
    # 创建陆地掩膜 (排除极地和海洋)
    land_mask = np.zeros((rows, cols), dtype=bool)
    for i in range(rows):
        if -60 <= lats[i] <= 80:  # 排除南极和北极
            land_mask[i, :] = True
    
    return land_mask

def prepare_features(vod_data, lai_data, pft_data, valid_mask):
    """准备特征矩阵（6VOD + LAI + PFTs） - 修复特征名称问题"""
    # 获取有效索引
    valid_indices = np.where(valid_mask)
    num_valid = len(valid_indices[0])
    
    if num_valid == 0:
        logger.warning("无有效数据点")
        return None, None, None
    
    # 初始化特征矩阵
    features = np.zeros((num_valid, 0), dtype=np.float32)
    feature_names = []
    
    # 1. VOD特征 (6个波段) - 使用训练代码中的名称
    vod_keys = [
        'VOD_Ku_Hpol_Asc', 'VOD_Ku_Vpol_Asc',
        'VOD_X_Hpol_Asc', 'VOD_X_Vpol_Asc',
        'VOD_C_Hpol_Asc', 'VOD_C_Vpol_Asc'
    ]
    
    # 定义MAT文件中VOD变量的实际名称
    vod_mat_mapping = {
        'VOD_Ku_Hpol_Asc': 'ku_vod_H',
        'VOD_Ku_Vpol_Asc': 'ku_vod_V',
        'VOD_X_Hpol_Asc': 'x_vod_H',
        'VOD_X_Vpol_Asc': 'x_vod_V',
        'VOD_C_Hpol_Asc': 'c_vod_H',
        'VOD_C_Vpol_Asc': 'c_vod_V'
    }
    
    for key in vod_keys:
        # 获取MAT文件中的实际变量名
        mat_key = vod_mat_mapping.get(key)
        
        # 检查是否在VOD数据中
        if mat_key in vod_data:
            vod_val = vod_data[mat_key] / 1.5
            features = np.column_stack((features, vod_val[valid_indices]))
            feature_names.append(key)
            logger.debug(f"成功加载VOD特征: {key} (MAT变量: {mat_key})")
        elif key in vod_data:
            # 如果使用标准名称存在
            vod_val = vod_data[key] / 1.5
            features = np.column_stack((features, vod_val[valid_indices]))
            feature_names.append(key)
            logger.debug(f"成功加载VOD特征: {key}")
        else:
            logger.warning(f"VOD特征缺失: {key}，使用0填充")
            features = np.column_stack((features, np.zeros(num_valid)))
            feature_names.append(key)
    
    # 2. LAI特征 (归一化到0-1)
    feature_names.append('LAI')
    if lai_data is not None:
        lai_val = lai_data / 6.0
        features = np.column_stack((features, lai_val[valid_indices]))
    else:
        logger.warning("LAI特征缺失，使用0填充")
        features = np.column_stack((features, np.zeros(num_valid)))
    
    # 3. PFT特征 (归一化)
    pft_features = [
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    for feature in pft_features:
        feature_names.append(feature)
        if pft_data is not None:
            # PFT特征映射
            pft_mapping = {
                'Grass_man': 'grassman',
                'Grass_nat': 'grassnat',
                'Shrub_bd': 'shrubbd',
                'Shrub_be': 'shrubbe',
                'Shrub_nd': 'shrubnd',
                'Shrub_ne': 'shrubne',
                'Tree_bd': 'treebd',
                'Tree_be': 'treebe',
                'Tree_nd': 'treend',
                'Tree_ne': 'treene'
            }
            
            pft_key = pft_mapping.get(feature)
            if pft_key and pft_key in pft_data:
                pft_val = pft_data[pft_key] / 100.0
                features = np.column_stack((features, pft_val[valid_indices]))
            else:
                logger.warning(f"PFT特征缺失: {feature}，使用0填充")
                features = np.column_stack((features, np.zeros(num_valid)))
        else:
            logger.warning("PFT数据缺失，使用0填充")
            features = np.column_stack((features, np.zeros(num_valid)))
    
    logger.info(f"准备特征完成: {features.shape} 形状, 特征数: {len(feature_names)}")
    return features, feature_names, valid_indices

def predict_vwc(features, feature_names, rows, cols, valid_mask):
    """预测VWC"""
    if features is None:
        return None
    
    # 初始化结果数组
    vwc_array = np.full((rows, cols), -9999.0, dtype=np.float32)
    valid_indices = np.where(valid_mask)
    num_valid = len(valid_indices[0])
    
    if num_valid == 0:
        logger.warning("无有效数据点")
        return vwc_array
    
    try:
        # 检查特征顺序是否匹配
        if hasattr(MODEL, 'feature_names_in_'):
            model_features = list(MODEL.feature_names_in_)
            
            # 记录模型期望的特征名称
            logger.info(f"模型期望的特征 ({len(model_features)}): {', '.join(model_features)}")
            logger.info(f"实际提供的特征 ({len(feature_names)}): {', '.join(feature_names)}")
            
            if feature_names:
                # 检查特征顺序是否一致
                if feature_names != model_features:
                    logger.warning("特征顺序不匹配! 尝试重新排序...")
                    try:
                        # 验证所有特征都存在
                        missing_features = set(model_features) - set(feature_names)
                        if missing_features:
                            logger.error(f"缺失特征: {', '.join(missing_features)}")
                            return None
                        
                        # 重新排序特征以匹配模型期望
                        sorted_indices = [feature_names.index(f) for f in model_features]
                        features = features[:, sorted_indices]
                        feature_names = [feature_names[i] for i in sorted_indices]
                        logger.info("特征重新排序完成")
                    except Exception as e:
                        logger.error(f"特征重新排序失败: {str(e)}")
                        return None
        
        # 预测 (分批处理避免内存溢出)
        predictions = np.zeros(num_valid, dtype=np.float32)
        chunk_size = 500000  # 增大块大小以提升效率
        chunks = (num_valid + chunk_size - 1) // chunk_size
        
        for chunk_idx in tqdm(range(chunks), desc="模型预测(分块)", leave=False):
            start = chunk_idx * chunk_size
            end = min((chunk_idx + 1) * chunk_size, num_valid)
            X_chunk = features[start:end]
            predictions[start:end] = MODEL.predict(X_chunk)
        
        # 填充结果数组
        vwc_array[valid_indices] = predictions
        
        # 添加预测统计
        valid_predictions = predictions[~np.isnan(predictions)]
        if len(valid_predictions) > 0:
            logger.info(f"预测值统计: min={np.min(valid_predictions):.4f}, "
                        f"max={np.max(valid_predictions):.4f}, "
                        f"mean={np.mean(valid_predictions):.4f}")
        else:
            logger.warning("无有效预测值")
        
        return vwc_array
        
    except Exception as e:
        logger.error(f"预测失败: {str(e)}")
        return None

def process_one_date(date, land_mask, overwrite=False):
    """处理单日数据，使用新模型（6VOD+LAI+PFTs）"""
    # 1. 输出文件路径
    output_filename = f'VWC-{date.strftime("%Y%m%d")}.tif'
    output_path = os.path.join(OUTPUT_PATH, output_filename)
    
    # 检查文件是否已存在
    if os.path.exists(output_path) and not overwrite:
        logger.info(f"文件已存在: {output_path} - 跳过")
        return True
    
    # 如果存在且需覆盖则删除
    if os.path.exists(output_path):
        try:
            os.remove(output_path)
            logger.info(f"已删除现有文件: {output_path}")
        except Exception as e:
            logger.error(f"删除文件失败: {output_path} - {str(e)}")
            return False
    
    logger.info(f"处理日期: {date.strftime('%Y-%m-%d')}")
    
    try:
        # 2. 加载VOD数据
        vod_file = get_vod_file(date)
        if not vod_file or not os.path.exists(vod_file):
            logger.error(f"VOD文件未找到: {date}")
            return False
            
        logger.info(f"加载VOD文件: {vod_file}")
        vod_data = {}
        with h5py.File(vod_file, 'r') as f:
            # 定义MAT文件中VOD变量的实际名称
            vod_mat_mapping = {
                'ku_vod_H': 'VOD_Ku_Hpol_Asc',
                'ku_vod_V': 'VOD_Ku_Vpol_Asc',
                'x_vod_H': 'VOD_X_Hpol_Asc',
                'x_vod_V': 'VOD_X_Vpol_Asc',
                'c_vod_H': 'VOD_C_Hpol_Asc',
                'c_vod_V': 'VOD_C_Vpol_Asc'
            }
            
            # 加载所有VOD数据集
            for key in f.keys():
                # 检查是否在映射中
                if key in vod_mat_mapping:
                    # 使用映射后的名称
                    new_key = vod_mat_mapping[key]
                    vod_data[new_key] = np.array(f[key][:]).T
                    logger.debug(f"加载VOD变量: {key} -> {new_key}")
                elif key.lower() in vod_mat_mapping:
                    # 处理小写变体
                    new_key = vod_mat_mapping[key.lower()]
                    vod_data[new_key] = np.array(f[key][:]).T
                    logger.debug(f"加载VOD变量: {key} -> {new_key}")
                else:
                    # 其他变量暂时不加载
                    pass
            
            # 加载QC数据
            qc_data = np.array(f['QC'][:,:]).T
        
        # 3. 创建有效掩膜
        qc_mask = qc_data == 0  # QC=0 表示有效
        lai_data = None  # 将在下一步加载
        valid_mask = qc_mask & land_mask  # 先加入QC和陆地掩码
        
        # 4. 加载并插值LAI数据（使用缓存）
        try:
            # 获取前后月中日期
            prev_mid, next_mid = get_month_centers(date)
            
            # 检查缓存 (前月)
            prev_key = (prev_mid.year, prev_mid.month)
            if prev_key in LAI_MONTH_CACHE:
                logger.info(f"使用缓存的前月LAI: {prev_key[0]}-{prev_key[1]:02d}")
                lai_prev_data = LAI_MONTH_CACHE[prev_key]
            else:
                lai_prev_file = get_lai_file(prev_mid.year, prev_mid.month)
                lai_prev_data = load_lai_matrix(lai_prev_file)
                if lai_prev_data is None:
                    logger.error(f"前月LAI文件加载失败: {prev_key[0]}-{prev_key[1]:02d}")
                    return False
                LAI_MONTH_CACHE[prev_key] = lai_prev_data
            
            # 检查缓存 (后月)
            next_key = (next_mid.year, next_mid.month)
            if next_key in LAI_MONTH_CACHE:
                logger.info(f"使用缓存的后月LAI: {next_key[0]}-{next_key[1]:02d}")
                lai_next_data = LAI_MONTH_CACHE[next_key]
            else:
                lai_next_file = get_lai_file(next_mid.year, next_mid.month)
                lai_next_data = load_lai_matrix(lai_next_file)
                if lai_next_data is None:
                    logger.error(f"后月LAI文件加载失败: {next_key[0]}-{next_key[1]:02d}")
                    return False
                LAI_MONTH_CACHE[next_key] = lai_next_data
            
            # LAI插值
            total_days = (next_mid - prev_mid).days
            current_offset = (date - prev_mid).days
            weight = current_offset / total_days
            lai_data = lai_prev_data * (1 - weight) + lai_next_data * weight
            lai_data = np.nan_to_num(lai_data, nan=0.0)
            
            # 更新有效掩膜
            lai_mask = ~np.isnan(lai_data)  # 非NaN的LAI值有效
            valid_mask = valid_mask & lai_mask
        except Exception as e:
            logger.error(f"LAI处理失败: {str(e)}", exc_info=True)
            return False
        
        # 5. 加载PFT数据（使用缓存）
        year_key = date.year
        if year_key in PFT_YEAR_CACHE:
            logger.info(f"使用缓存的{year_key}年PFT数据")
            pft_data = PFT_YEAR_CACHE[year_key]
        else:
            pft_file = get_pft_file(year_key)
            pft_data = load_pft_matrix(pft_file)
            if pft_data is None:
                logger.error(f"PFT文件加载失败: {year_key}")
                return False
            PFT_YEAR_CACHE[year_key] = pft_data
        
        # 6. 准备特征（6VOD+LAI+PFTs）
        num_valid = np.count_nonzero(valid_mask)
        logger.info(f"有效数据点数量: {num_valid} (比例: {num_valid/(1800 * 3600)*100:.2f}%)")
        
        features, feature_names, valid_indices = prepare_features(
            vod_data, lai_data, pft_data, valid_mask
        )
        
        if features is None or num_valid == 0:
            logger.warning("无有效数据点，跳过预测")
            return True
        
        # 7. 预测VWC
        vwc_array = predict_vwc(features, feature_names, 1800, 3600, valid_mask)
        
        if vwc_array is None:
            logger.error("VWC预测失败")
            return False
        
        # 8. 创建单波段TIFF
        success = create_singleband_geotiff(vwc_array, output_path)
        if success:
            logger.info(f"成功保存单波段TIFF: {output_path}")
            return True
        else:
            logger.error(f"保存失败: {output_path}")
            return False
            
    except Exception as e:
        logger.error(f"处理日期 {date} 错误: {str(e)}", exc_info=True)
        return False

# ============================== 主处理流程 ==============================
def generate_vwc(start_date, end_date, overwrite=False):
    """生成VWC影像（并行优化版本）"""
    # 创建陆地掩膜（一次性）
    logger.info("创建陆地掩膜...")
    land_mask = create_land_mask()
    logger.info(f"陆地掩膜创建完成: 有效点={np.count_nonzero(land_mask)}({np.count_nonzero(land_mask)/(1800 * 3600)*100:.2f}%)")
    
    # 日期序列
    dates = [start_date + timedelta(days=i) 
             for i in range((end_date - start_date).days + 1)]
    
    # 并行处理（线程池）
    max_workers = 4  # 根据CPU物理核心数调整
    futures = []
    completed, failed = 0, 0
    
    logger.info(f"开始并行处理{len(dates)}天数据，最大线程数: {max_workers}")
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # 提交任务
        for date in dates:
            futures.append(executor.submit(process_one_date, date, land_mask, overwrite))
        
        # 进度监控
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="生成VWC影像"):
            try:
                result = future.result()
                if result:
                    completed += 1
                else:
                    failed += 1
            except Exception as e:
                logger.error(f"任务失败: {str(e)}", exc_info=True)
                failed += 1
    
    # 清理缓存
    LAI_MONTH_CACHE.clear()
    PFT_YEAR_CACHE.clear()
    
    logger.info(f"处理完成: {completed}天成功, {failed}天失败")
    return completed, failed

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"当前目录: {os.getcwd()}")
    logger.info(f"输出目录: {OUTPUT_PATH}")
    
    # 设置时间范围
    start_date = datetime(2015, 1, 1)
    end_date = datetime(2015, 12, 31)
    
    logger.info(f"开始处理: {start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}")
    
    # 运行主处理流程
    completed, failed = generate_vwc(start_date, end_date, overwrite=True)
    logger.info(f"VWC影像生成完成: 成功 {completed} 天, 失败 {failed} 天")

2025-08-18 16:24:59,127 - INFO - 成功加载新模型: D:\Python\jupyter\VWC_RFRegression\models\RFR_6VOD_LAI_PFTs.pkl
2025-08-18 16:24:59,130 - INFO - 模型特征数: 17
2025-08-18 16:24:59,131 - INFO - 模型期望的特征顺序: VOD_Ku_Hpol_Asc, VOD_Ku_Vpol_Asc, VOD_X_Hpol_Asc, VOD_X_Vpol_Asc, VOD_C_Hpol_Asc, VOD_C_Vpol_Asc, LAI, Grass_man, Grass_nat, Shrub_bd, Shrub_be, Shrub_nd, Shrub_ne, Tree_bd, Tree_be, Tree_nd, Tree_ne
2025-08-18 16:24:59,138 - INFO - 当前目录: D:\Python\jupyter\VWC_RFRegression
2025-08-18 16:24:59,138 - INFO - 输出目录: E:\data\VWC\VWCMap\Daily
2025-08-18 16:24:59,138 - INFO - 开始处理: 2015-01-01 到 2015-12-31
2025-08-18 16:24:59,141 - INFO - 创建陆地掩膜...
2025-08-18 16:24:59,145 - INFO - 陆地掩膜创建完成: 有效点=5040000(77.78%)
2025-08-18 16:24:59,147 - INFO - 开始并行处理365天数据，最大线程数: 4
2025-08-18 16:24:59,149 - INFO - 处理日期: 2015-01-01
2025-08-18 16:24:59,152 - INFO - 处理日期: 2015-01-02
2025-08-18 16:24:59,152 - INFO - 处理日期: 2015-01-03
2025-08-18 16:24:59,158 - INFO - 处理日期: 2015-01-04
2025-08-18 16:24:59,159 - INFO - 加载VOD文件: E:\

In [2]:
# 地物类型掩膜（适配单波段VWC文件）
import os
import numpy as np
from osgeo import gdal, osr
import logging
from tqdm import tqdm
import concurrent.futures
import multiprocessing
import time

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_mask_water_bare_snowice_optimized.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
VWC_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
MAIN_TYPE_FILE = r'E:\data\ESACCI PFT\Resample\Data\mainType.tif'  # 主地物类型文件
NODATA_VALUE = -9999.0  # 无效值
NUM_WORKERS = multiprocessing.cpu_count()  # 使用所有CPU核心

# 需要掩膜的地物类型 (水、裸地、冰雪)
MASK_TYPES = [0, 1, 2]  # 0=water, 1=bare, 2=snowice

# ============================== 辅助函数 ==============================
def load_main_type_mask():
    """加载主地物类型掩膜 (优化版本)"""
    try:
        # 使用内存映射提高读取速度
        ds = gdal.Open(MAIN_TYPE_FILE, gdal.GA_ReadOnly)
        if ds is None:
            logger.error(f"无法打开主地物类型文件: {MAIN_TYPE_FILE}")
            return None
        
        # 验证分辨率
        expected_res = 0.1
        actual_res_x = ds.GetGeoTransform()[1]
        actual_res_y = -ds.GetGeoTransform()[5]  # 取绝对值
        
        if abs(actual_res_x - expected_res) > 1e-6 or abs(actual_res_y - expected_res) > 1e-6:
            logger.warning(f"主地物类型文件分辨率不匹配! 期望: {expected_res}, 实际: X={actual_res_x}, Y={actual_res_y}")
        
        # 读取整个波段到内存
        band = ds.GetRasterBand(1)
        main_type = band.ReadAsArray()
        
        # 创建掩膜 (需要掩膜的类型为True)
        mask = np.isin(main_type, MASK_TYPES)
        
        # 获取NoData值并处理
        nodata = band.GetNoDataValue()
        if nodata is not None:
            # 排除NoData区域
            mask = np.logical_and(mask, main_type != nodata)
        
        ds = None
        logger.info(f"成功加载主地物类型掩膜: {mask.shape}")
        return mask
    except Exception as e:
        logger.error(f"加载主地物类型文件失败: {str(e)}")
        return None

def process_vwc_file(file_path, mask):
    """处理单个VWC文件 (适配单波段)"""
    try:
        start_time = time.time()
        logger.debug(f"开始处理: {file_path}")
        
        # 以读写模式打开文件
        ds = gdal.Open(file_path, gdal.GA_Update)
        if ds is None:
            logger.warning(f"无法打开VWC文件: {file_path}")
            return False
        
        # 获取文件信息
        rows = ds.RasterYSize
        cols = ds.RasterXSize
        num_bands = ds.RasterCount
        
        # 验证为单波段文件
        if num_bands != 1:
            logger.warning(f"文件波段数异常: {num_bands} (应为1), 文件: {file_path}")
        
        # 验证尺寸匹配
        if rows != mask.shape[0] or cols != mask.shape[1]:
            logger.error(f"尺寸不匹配: VWC文件({rows}x{cols}) vs 掩膜({mask.shape[0]}x{mask.shape[1]})")
            ds = None
            return False
        
        # 处理唯一波段
        band = ds.GetRasterBand(1)
        data = band.ReadAsArray()
        
        # 应用掩膜
        data[mask] = NODATA_VALUE
        
        # 写回数据
        band.WriteArray(data)
        
        # 设置NoData值
        band.SetNoDataValue(NODATA_VALUE)
        
        # 清理
        ds.FlushCache()
        ds = None
        
        process_time = time.time() - start_time
        logger.info(f"处理完成: {file_path} (耗时: {process_time:.2f}s)")
        return True
    except Exception as e:
        logger.error(f"处理文件 {file_path} 失败: {str(e)}", exc_info=True)
        return False

def process_vwc_files_parallel(mask):
    """并行处理所有VWC文件"""
    # 获取所有VWC文件
    vwc_files = []
    for filename in os.listdir(VWC_DIR):
        # 只处理新格式的VWC文件
        if filename.startswith('VWC-') and filename.endswith('.tif'):
            file_path = os.path.join(VWC_DIR, filename)
            vwc_files.append(file_path)
    
    if not vwc_files:
        logger.error("未找到任何VWC文件")
        return 0
    
    total_files = len(vwc_files)
    logger.info(f"找到 {total_files} 个VWC文件，使用 {NUM_WORKERS} 个线程并行处理")
    
    # 使用线程池并行处理
    success_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        # 提交所有任务
        future_to_file = {executor.submit(process_vwc_file, file_path, mask): file_path 
                          for file_path in vwc_files}
        
        # 使用tqdm显示进度
        for future in tqdm(concurrent.futures.as_completed(future_to_file), 
                          total=len(vwc_files), desc="处理VWC文件"):
            file_path = future_to_file[future]
            try:
                result = future.result()
                if result:
                    success_count += 1
            except Exception as e:
                logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
    
    return success_count, total_files

# ============================== 主处理流程 ==============================
def main():
    logger.info(f"VWC目录: {VWC_DIR}")
    logger.info(f"主地物类型文件: {MAIN_TYPE_FILE}")
    logger.info(f"需要掩膜的地物类型: {MASK_TYPES} (水、裸地、冰雪)")
    logger.info(f"使用 {NUM_WORKERS} 个线程进行并行处理")
    
    start_time = time.time()
    
    # 加载主地物类型掩膜
    mask = load_main_type_mask()
    if mask is None:
        logger.error("无法加载掩膜，程序终止")
        return
    
    # 处理所有VWC文件
    success_count, total_files = process_vwc_files_parallel(mask)
    
    total_time = time.time() - start_time
    logger.info(f"处理完成! 成功处理 {success_count}/{total_files} 个文件")
    logger.info(f"总耗时: {total_time:.2f}秒")
    logger.info(f"平均每文件耗时: {total_time/total_files:.4f}秒" if total_files > 0 else "")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    try:
        main()
        logger.info("VWC文件掩膜处理完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-18 16:45:52,887 - INFO - VWC目录: E:\data\VWC\VWCMap\Daily
2025-08-18 16:45:52,887 - INFO - 主地物类型文件: E:\data\ESACCI PFT\Resample\Data\mainType.tif
2025-08-18 16:45:52,889 - INFO - 需要掩膜的地物类型: [0, 1, 2] (水、裸地、冰雪)
2025-08-18 16:45:52,889 - INFO - 使用 16 个线程进行并行处理
2025-08-18 16:46:04,081 - INFO - 成功加载主地物类型掩膜: (1800, 3600)
2025-08-18 16:46:04,086 - INFO - 找到 362 个VWC文件，使用 16 个线程并行处理
2025-08-18 16:46:04,456 - INFO - 处理完成: E:\data\VWC\VWCMap\Daily\VWC-20150101.tif (耗时: 0.37s)
2025-08-18 16:46:04,467 - INFO - 处理完成: E:\data\VWC\VWCMap\Daily\VWC-20150110.tif (耗时: 0.36s)
2025-08-18 16:46:04,471 - INFO - 处理完成: E:\data\VWC\VWCMap\Daily\VWC-20150103.tif (耗时: 0.39s)
2025-08-18 16:46:04,477 - INFO - 处理完成: E:\data\VWC\VWCMap\Daily\VWC-20150111.tif (耗时: 0.35s)
2025-08-18 16:46:04,589 - INFO - 处理完成: E:\data\VWC\VWCMap\Daily\VWC-20150107.tif (耗时: 0.49s)
2025-08-18 16:46:04,611 - INFO - 处理完成: E:\data\VWC\VWCMap\Daily\VWC-20150102.tif (耗时: 0.52s)
2025-08-18 16:46:04,619 - INFO - 处理完成: E:\data\VWC\VWCMa

## 8日合成数据

In [3]:
import os
import numpy as np
from osgeo import gdal, osr
import datetime
import re
import logging
from tqdm import tqdm

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_8day_composite_corrected.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
INPUT_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
OUTPUT_DIR = r'E:\data\VWC\VWCMap\8Day'  # 8日合成输出目录
START_YEAR = 2015
END_YEAR = 2015
NODATA_VALUE = -9999.0
OVERWRITE_EXISTING = True  # 是否覆盖已存在的合成文件

# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================== 辅助函数 ==============================
def get_daily_files():
    """获取所有每日文件，并按日期排序"""
    files = []
    pattern = re.compile(r'VWC-(\d{4})(\d{2})(\d{2})\.tif$')  # 匹配YYYYMMDD格式的文件名
    
    for filename in os.listdir(INPUT_DIR):
        if not filename.endswith('.tif'):
            continue
            
        match = pattern.match(filename)
        if not match:
            continue
            
        year = int(match.group(1))
        month = int(match.group(2))
        day = int(match.group(3))
        file_date = datetime.date(year, month, day)
        
        # 只处理指定年份范围内的文件
        if START_YEAR <= year <= END_YEAR:
            file_path = os.path.join(INPUT_DIR, filename)
            files.append({
                'path': file_path,
                'date': file_date,
                'year': year
            })
    
    # 按日期排序
    files.sort(key=lambda x: x['date'])
    return files

def create_composite_geotiff(data, output_path, start_date, end_date, nodata=NODATA_VALUE):
    """创建8日合成的地理参考TIFF文件（单波段）"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols = data.shape
        
        # 创建输出数据集
        out_ds = driver.Create(
            output_path,
            cols,
            rows,
            1,  # 单波段
            gdal.GDT_Float32,
            options=['COMPRESS=LZW', 'TILED=YES']
        )
        
        # 设置地理变换
        out_ds.SetGeoTransform((-180, 0.1, 0, 90, 0, -0.1))
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加日期元数据
        out_ds.SetMetadata({
            'START_DATE': start_date.strftime('%Y-%m-%d'),
            'END_DATE': end_date.strftime('%Y-%m-%d'),
            'NUM_DAYS': str((end_date - start_date).days + 1)
        })
        
        # 写入波段数据
        band = out_ds.GetRasterBand(1)
        band.WriteArray(data)
        band.SetNoDataValue(nodata)
        band.SetDescription('VWC')
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        logger.info(f"成功创建8日合成GeoTIFF: {output_path}")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def generate_8day_composites():
    """生成8日合成数据 (按年处理)"""
    # 获取所有每日文件
    daily_files = get_daily_files()
    if not daily_files:
        logger.error("未找到任何每日文件")
        return
    
    logger.info(f"找到 {len(daily_files)} 个每日文件")
    
    # 按年份分组
    files_by_year = {}
    for year in range(START_YEAR, END_YEAR + 1):
        files_by_year[year] = [f for f in daily_files if f['year'] == year]
    
    # 处理每一年
    for year in range(START_YEAR, END_YEAR + 1):
        year_files = files_by_year[year]
        if not year_files:
            logger.warning(f"{year}年没有每日文件")
            continue
            
        logger.info(f"处理 {year} 年 (共 {len(year_files)} 个每日文件)")
        
        # 确定该年的起始和结束日期
        start_date = datetime.date(year, 1, 1)
        end_date = datetime.date(year, 12, 31)
        
        # 创建8日周期
        current_start = start_date
        composite_groups = []
        
        # 构建8日周期 (严格限制在该年内)
        while current_start <= end_date:
            # 计算周期结束日期 (最多7天后)
            current_end = current_start + datetime.timedelta(days=7)
            
            # 如果结束日期超出该年范围，调整到12月31日
            if current_end > end_date:
                current_end = end_date
                
            # 收集当前周期内的文件
            group_files = []
            for file_info in year_files:
                if current_start <= file_info['date'] <= current_end:
                    group_files.append(file_info)
            
            if group_files:
                composite_groups.append({
                    'start_date': current_start,
                    'end_date': current_end,
                    'files': group_files
                })
            
            # 移动到下一个周期 (从结束日期的下一天开始)
            current_start = current_end + datetime.timedelta(days=1)
            
            # 如果下一个起始日期已超出该年范围，结束循环
            if current_start > end_date:
                break
        
        logger.info(f"{year}年共生成 {len(composite_groups)} 个8日周期")
        
        # 处理该年的每个8日周期
        for group in tqdm(composite_groups, desc=f"处理{year}年周期"):
            # 使用起始日期作为文件名标识
            start_str = group['start_date'].strftime('%Y%m%d')
            end_str = group['end_date'].strftime('%Y%m%d')
            
            # 文件名格式：VWC-YYYYMMDD.tif（使用起始日期）
            output_filename = f'VWC-{start_str}.tif'
            output_path = os.path.join(OUTPUT_DIR, output_filename)
            
            # 检查文件是否已存在
            if os.path.exists(output_path):
                if OVERWRITE_EXISTING:
                    try:
                        os.remove(output_path)
                        logger.info(f"已删除现有文件: {output_path}")
                    except Exception as e:
                        logger.error(f"删除文件失败: {output_path} - {str(e)}")
                        continue
                else:
                    logger.info(f"文件已存在: {output_path} - 跳过")
                    continue
            
            num_days = len(group['files'])
            logger.info(f"处理周期: {start_str} 到 {end_str} ({num_days}天)")
            
            # 初始化数据数组
            composite_sum = None
            valid_count = None
            rows, cols = 0, 0
            
            # 处理周期内的每个文件
            for file_info in group['files']:
                try:
                    ds = gdal.Open(file_info['path'])
                    if ds is None:
                        logger.warning(f"无法打开文件: {file_info['path']}")
                        continue
                    
                    # 获取图像尺寸（只在第一次确定）
                    if rows == 0:
                        rows = ds.RasterYSize
                        cols = ds.RasterXSize
                        composite_sum = np.zeros((rows, cols), dtype=np.float32)
                        valid_count = np.zeros((rows, cols), dtype=np.uint16)
                    
                    # 读取唯一波段
                    band = ds.GetRasterBand(1)
                    data = band.ReadAsArray()
                    nodata = band.GetNoDataValue()
                    
                    if nodata is None:
                        nodata = NODATA_VALUE
                    
                    # 创建有效值掩膜
                    valid_mask = (data != nodata) & (~np.isnan(data))
                    
                    # 累加有效值
                    composite_sum += np.where(valid_mask, data, 0)
                    valid_count += valid_mask.astype(np.uint16)
                    
                    ds = None
                except Exception as e:
                    logger.error(f"处理文件 {file_info['path']} 错误: {str(e)}")
            
            # 计算平均值
            if composite_sum is not None and valid_count is not None:
                # 避免除以零
                with np.errstate(divide='ignore', invalid='ignore'):
                    composite_avg = np.where(
                        valid_count > 0,
                        composite_sum / valid_count,
                        NODATA_VALUE  # 无效值
                    )
                
                # 保存合成结果
                success = create_composite_geotiff(
                    composite_avg, 
                    output_path,
                    group['start_date'],
                    group['end_date']
                )
                if success:
                    # 计算有效数据百分比
                    valid_percent = (valid_count > 0).sum() / (rows * cols) * 100
                    logger.info(f"8日合成完成: {start_str} 到 {end_str} (有效数据: {valid_percent:.2f}%)")
                else:
                    logger.error(f"保存失败: {output_path}")
            else:
                logger.warning(f"无有效数据可合成: {start_str} 到 {end_str}")
    
    logger.info("8日合成处理完成")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"输入目录: {INPUT_DIR}")
    logger.info(f"输出目录: {OUTPUT_DIR}")
    logger.info(f"生成8日合成数据: {START_YEAR} 到 {END_YEAR}")
    logger.info(f"覆盖模式: {'开启' if OVERWRITE_EXISTING else '关闭'}")
    
    try:
        generate_8day_composites()
        logger.info("8日合成数据生成完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-18 16:46:14,441 - INFO - 输入目录: E:\data\VWC\VWCMap\Daily
2025-08-18 16:46:14,442 - INFO - 输出目录: E:\data\VWC\VWCMap\8Day
2025-08-18 16:46:14,443 - INFO - 生成8日合成数据: 2015 到 2016
2025-08-18 16:46:14,444 - INFO - 覆盖模式: 开启
2025-08-18 16:46:14,448 - INFO - 找到 362 个每日文件
2025-08-18 16:46:14,449 - INFO - 处理 2015 年 (共 362 个每日文件)
2025-08-18 16:46:14,451 - INFO - 2015年共生成 46 个8日周期
2025-08-18 16:46:14,454 - INFO - 处理周期: 20150101 到 20150108 (8天)/46 [00:00<?, ?it/s]
2025-08-18 16:46:15,470 - INFO - 成功创建8日合成GeoTIFF: E:\data\VWC\VWCMap\8Day\VWC-20150101.tif
2025-08-18 16:46:15,489 - INFO - 8日合成完成: 20150101 到 20150108 (有效数据: 9.68%)
2025-08-18 16:46:15,493 - INFO - 处理周期: 20150109 到 20150116 (8天)01<00:46,  1.04s/it]
2025-08-18 16:46:16,501 - INFO - 成功创建8日合成GeoTIFF: E:\data\VWC\VWCMap\8Day\VWC-20150109.tif
2025-08-18 16:46:16,512 - INFO - 8日合成完成: 20150109 到 20150116 (有效数据: 9.57%)
2025-08-18 16:46:16,516 - INFO - 处理周期: 20150117 到 20150124 (8天)02<00:45,  1.03s/it]
2025-08-18 16:46:17,522 - INFO - 成功创建8

## 月度数据合成

In [4]:
# 月度数据合成（适配单波段VWC图像）
import os
import numpy as np
from osgeo import gdal, osr
import datetime
import re
import logging
from tqdm import tqdm

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_monthly_composite.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
INPUT_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
OUTPUT_DIR = r'E:\data\VWC\VWCMap\Monthly'  # 月度合成输出目录
START_YEAR = 2015
END_YEAR = 2015
NODATA_VALUE = -9999.0
OVERWRITE_EXISTING = True  # 是否覆盖已存在的月度合成文件

# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================== 辅助函数 ==============================
def get_daily_files_for_month(year, month):
    """获取指定年月的所有每日文件（仅TIFF格式）"""
    files = []
    pattern = re.compile(r'VWC-(\d{4})(\d{2})(\d{2})\.tif$')  # 匹配YYYYMMDD格式的文件名
    
    for filename in os.listdir(INPUT_DIR):
        if not filename.endswith('.tif'):
            continue
            
        match = pattern.match(filename)
        if not match:
            continue
            
        file_year = int(match.group(1))
        file_month = int(match.group(2))
        file_day = int(match.group(3))
        
        if file_year == year and file_month == month:
            file_path = os.path.join(INPUT_DIR, filename)
            files.append({
                'path': file_path,
                'date': datetime.date(file_year, file_month, file_day)
            })
    
    # 按日期排序
    files.sort(key=lambda x: x['date'])
    return files

def create_monthly_geotiff(data, output_path, year, month, nodata=NODATA_VALUE):
    """创建地理参考的月度TIFF文件（单波段）"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols = data.shape
        
        # 创建数据集
        out_ds = driver.Create(
            output_path, 
            cols, 
            rows, 
            1,  # 单波段
            gdal.GDT_Float32,
            options=['COMPRESS=LZW', 'TILED=YES']
        )
        
        # 设置地理变换
        out_ds.SetGeoTransform((-180, 0.1, 0, 90, 0, -0.1))
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加日期元数据
        out_ds.SetMetadata({
            'YEAR': str(year),
            'MONTH': str(month),
            'PRODUCTION_DATE': datetime.date.today().isoformat()
        })
        
        # 写入波段数据
        band = out_ds.GetRasterBand(1)
        band.WriteArray(data)
        band.SetNoDataValue(nodata)
        band.SetDescription('VWC')
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        logger.info(f"成功创建月度合成GeoTIFF: {output_path}")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def generate_monthly_composites():
    """生成月度合成数据"""
    total_months = (END_YEAR - START_YEAR + 1) * 12
    processed = 0
    
    # 进度条初始化
    pbar = tqdm(total=total_months, desc="月度合成进度")
    
    for year in range(START_YEAR, END_YEAR + 1):
        for month in range(1, 13):
            # 输出文件名 - 格式：VWC-YYYYMM.tif
            output_filename = f'VWC-{year}{str(month).zfill(2)}.tif'
            output_path = os.path.join(OUTPUT_DIR, output_filename)
            
            # 检查文件是否已存在
            if os.path.exists(output_path):
                if OVERWRITE_EXISTING:
                    try:
                        os.remove(output_path)
                        logger.info(f"已删除现有文件: {output_path}")
                    except Exception as e:
                        logger.error(f"删除文件失败: {output_path} - {str(e)}")
                        pbar.update(1)
                        processed += 1
                        continue
                else:
                    logger.info(f"月度合成已存在: {output_path} - 跳过")
                    pbar.update(1)
                    processed += 1
                    continue
            
            # 获取该月的所有每日文件
            files = get_daily_files_for_month(year, month)
            
            if not files:
                logger.warning(f"在{year}年{month}月未找到任何每日文件")
                pbar.update(1)
                processed += 1
                continue
            
            logger.info(f"生成{year}年{month}月合成 ({len(files)}个每日文件)")
            
            # 初始化数据数组
            monthly_sum = None
            valid_count = None
            rows, cols = 0, 0
            
            # 用于存储每日数据的列表（用于中值计算）
            daily_arrays = []
            
            # 处理每个文件
            for file_info in files:
                try:
                    # 打开文件
                    ds = gdal.Open(file_info['path'])
                    if ds is None:
                        logger.warning(f"无法打开文件: {file_info['path']}")
                        continue
                    
                    # 获取图像尺寸（只在第一次确定）
                    if rows == 0:
                        rows = ds.RasterYSize
                        cols = ds.RasterXSize
                        monthly_sum = np.zeros((rows, cols), dtype=np.float32)
                        valid_count = np.zeros((rows, cols), dtype=np.uint16)
                    
                    # 读取唯一波段
                    band = ds.GetRasterBand(1)
                    data = band.ReadAsArray()
                    nodata = band.GetNoDataValue()
                    
                    if nodata is None:
                        nodata = NODATA_VALUE
                    
                    # 创建有效值掩膜
                    valid_mask = (data != nodata) & (~np.isnan(data))
                    
                    # 累加有效值
                    monthly_sum += np.where(valid_mask, data, 0)
                    valid_count += valid_mask.astype(np.uint16)
                    
                    # 收集每日数据用于中值计算
                    daily_arrays.append(data)
                    
                    # 关闭数据集
                    ds = None
                    
                except Exception as e:
                    logger.error(f"处理文件 {file_info['path']} 错误: {str(e)}")
            
            # 计算月度平均值 (避免除以零)
            if monthly_sum is not None and valid_count is not None:
                # 计算平均值
                with np.errstate(divide='ignore', invalid='ignore'):
                    monthly_avg = np.where(
                        valid_count > 0,
                        monthly_sum / valid_count,
                        NODATA_VALUE  # 无效值
                    )
                
                # 计算中值（更健壮的指标）
                if daily_arrays:
                    # 创建数据堆栈
                    stack = np.stack(daily_arrays, axis=0)
                    
                    # 创建有效值掩膜
                    valid_mask = (stack != NODATA_VALUE) & (~np.isnan(stack))
                    
                    # 计算中值
                    monthly_median = np.full((rows, cols), NODATA_VALUE, dtype=np.float32)
                    for i in range(rows):
                        for j in range(cols):
                            # 获取该像素的所有有效值
                            pixel_values = stack[:, i, j][valid_mask[:, i, j]]
                            if len(pixel_values) > 0:
                                monthly_median[i, j] = np.median(pixel_values)
                
                # 保存合成结果 - 使用中值作为更健壮的指标
                success = create_monthly_geotiff(monthly_median, output_path, year, month)
                
                if success:
                    # 输出质量报告
                    valid_percent = (valid_count > 0).sum() / (rows * cols) * 100
                    logger.info(f"月度合成完成: {year}年{month}月, 有效数据: {valid_percent:.2f}%")
                else:
                    logger.error(f"保存月度合成失败: {output_path}")
            
            # 更新进度条
            pbar.update(1)
            processed += 1
    
    # 关闭进度条
    pbar.close()
    logger.info(f"月度合成处理完成! 共处理 {processed} 个月份")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"输入目录: {INPUT_DIR}")
    logger.info(f"输出目录: {OUTPUT_DIR}")
    logger.info(f"生成月度合成数据: {START_YEAR} 到 {END_YEAR}")
    logger.info(f"覆盖模式: {'开启' if OVERWRITE_EXISTING else '关闭'}")
    
    try:
        generate_monthly_composites()
        logger.info("月度合成数据生成完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-18 16:49:09,349 - INFO - 输入目录: E:\data\VWC\VWCMap\Daily
2025-08-18 16:49:09,351 - INFO - 输出目录: E:\data\VWC\VWCMap\Monthly
2025-08-18 16:49:09,353 - INFO - 生成月度合成数据: 2015 到 2015
2025-08-18 16:49:09,354 - INFO - 覆盖模式: 开启
2025-08-18 16:49:09,359 - INFO - 生成2015年1月合成 (31个每日文件)     | 0/12 [00:00<?, ?it/s]
2025-08-18 16:49:44,507 - INFO - 成功创建月度合成GeoTIFF: E:\data\VWC\VWCMap\Monthly\VWC-201501.tif
2025-08-18 16:49:44,525 - INFO - 月度合成完成: 2015年1月, 有效数据: 10.55%
2025-08-18 16:49:44,531 - INFO - 生成2015年2月合成 (28个每日文件)/12 [00:35<06:26, 35.17s/it]
2025-08-18 16:50:08,669 - INFO - 成功创建月度合成GeoTIFF: E:\data\VWC\VWCMap\Monthly\VWC-201502.tif
2025-08-18 16:50:08,679 - INFO - 月度合成完成: 2015年2月, 有效数据: 10.93%
2025-08-18 16:50:08,685 - INFO - 生成2015年3月合成 (31个每日文件)/12 [00:59<04:46, 28.69s/it]
2025-08-18 16:50:37,381 - INFO - 成功创建月度合成GeoTIFF: E:\data\VWC\VWCMap\Monthly\VWC-201503.tif
2025-08-18 16:50:37,396 - INFO - 月度合成完成: 2015年3月, 有效数据: 13.51%
2025-08-18 16:50:37,404 - INFO - 生成2015年4月合成 (29个每日文件)/12 [

## VWC季节性制图 

In [5]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from osgeo import gdal, osr
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
from cartopy.feature import NaturalEarthFeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import matplotlib as mpl
import logging

# ==========================================================
# 配置参数
# ==========================================================
VWC_DIR = r"E:\data\VWC\VWCMap\Monthly"
SPECIFIC_MONTHS = {
    "2015-01": "201501",
    "2015-04": "201504",
    "2015-07": "201507",
    "2015-10": "201510"
}
BAND = 1  # 读取第一个波段(KuH)
OUTPUT_DIR = r"E:\文章\HUITU\Fig"
OUTPUT_NAME = "Global_VWC_KuH_Seasonal_2015_6VOD_LAI_PFT"

# 掩膜相关参数
MAIN_TYPE_FILE = r'E:\data\ESACCI PFT\Resample\Data\mainType.tif'
MASK_TYPES = [0, 1, 2]  # 0=water, 1=bare, 2=snowice
NODATA_VALUE = -9999.0

# ==========================================================
# 掩膜处理函数（使用修复后的方法）
# ==========================================================
def load_main_type_mask():
    """加载主地物类型掩膜（含空间参考校验）"""
    try:
        ds = gdal.Open(MAIN_TYPE_FILE, gdal.GA_ReadOnly)
        if ds is None:
            raise ValueError(f"无法打开主地物类型文件: {MAIN_TYPE_FILE}")
        
        band = ds.GetRasterBand(1)
        main_type = band.ReadAsArray()
        
        # 创建掩膜（需要掩膜的类型为True）
        mask = np.isin(main_type, MASK_TYPES)
        
        # 获取NoData值并处理
        nodata = band.GetNoDataValue()
        if nodata is not None:
            mask = np.logical_and(mask, main_type != nodata)
        
        # 获取空间参考信息
        geotransform = ds.GetGeoTransform()
        projection = ds.GetProjection()
        ds = None
        
        print("成功加载主地物类型掩膜")
        return mask, geotransform, projection
    except Exception as e:
        print(f"加载主地物类型文件失败: {str(e)}")
        return None, None, None

def apply_mask_to_vwc_file(file_path, mask, ref_geotransform, ref_projection):
    """对单个VWC文件应用掩膜（仅处理第一个波段）"""
    try:
        print(f"正在掩膜处理: {os.path.basename(file_path)}")
        
        # 以读写模式打开文件
        ds = gdal.Open(file_path, gdal.GA_Update)
        if ds is None:
            raise ValueError(f"无法打开VWC文件: {file_path}")
        
        # 验证空间参考是否匹配
        if (ds.GetGeoTransform() != ref_geotransform or 
            ds.GetProjection() != ref_projection):
            raise ValueError("空间参考不匹配，请确保文件投影一致")
        
        # 仅处理第一个波段
        band = ds.GetRasterBand(BAND)
        data = band.ReadAsArray()
        
        # 应用掩膜
        data[mask] = NODATA_VALUE
        
        # 写回数据并设置NoData属性
        band.WriteArray(data)
        band.SetNoDataValue(NODATA_VALUE)
        
        # 强制写入磁盘
        ds.FlushCache()
        ds = None
        
        print(f"成功掩膜处理: {os.path.basename(file_path)}")
        return True
    except Exception as e:
        print(f"掩膜处理失败: {os.path.basename(file_path)}，错误: {str(e)}")
        return False

# ==========================================================
# 数据读取函数（保持不变）
# ==========================================================
def read_vwc_tif(file_path, band=1, no_data=-9999):
    """读取VWC TIFF文件并处理无效值"""
    ds = gdal.Open(file_path)
    if ds is None:
        raise FileNotFoundError(f"Cannot open file: {file_path}")
    
    # 读取指定波段
    band = ds.GetRasterBand(band)
    data = band.ReadAsArray()
    
    # 获取地理信息
    geotransform = ds.GetGeoTransform()
    projection = ds.GetProjection()
    
    # 计算地理范围
    x_size = ds.RasterXSize
    y_size = ds.RasterYSize
    lon_min = geotransform[0]
    lat_max = geotransform[3]
    lon_max = lon_min + geotransform[1] * x_size
    lat_min = lat_max + geotransform[5] * y_size
    
    # 处理无效值
    data = data.astype(np.float32)
    data[data == no_data] = np.nan
    
    ds = None
    return data, (lon_min, lon_max, lat_min, lat_max), projection

# ==========================================================
# 创建自定义颜色映射（保持不变）
# ==========================================================
def create_custom_cmap():
    """创建自定义颜色映射"""
    # 使用指定的五种颜色
    colors = [
        '#fe3c19',  # 0 kg/m²
        '#ffac18',  # 5 kg/m²
        '#f2fe2a',  # 10 kg/m²
        '#7cb815',  # 15 kg/m²
        '#147218'   # 20 kg/m²
    ]
    
    # 创建颜色映射
    cmap = LinearSegmentedColormap.from_list('custom_vwc', colors, N=256)
    return cmap

# ==========================================================
# 地图绘制函数（保持不变）
# ==========================================================
def plot_vwc_map(ax, data, extent, month_label, vmin=0, vmax=20):
    """在指定轴对象上绘制VWC地图"""
    # 创建自定义颜色映射
    cmap = create_custom_cmap()
    
    # 添加地图特征
    ax.coastlines(linewidth=0.5, color='gray')
    ax.add_feature(NaturalEarthFeature(category='physical', name='ocean', scale='50m', 
                                      facecolor='lightblue', alpha=0.3))
    ax.add_feature(NaturalEarthFeature(category='cultural', name='admin_0_countries', 
                                      scale='50m', edgecolor='gray', facecolor='none', linewidth=0.3))
    
    # 添加经纬度网格
    gl = ax.gridlines(draw_labels=True, linestyle='--', alpha=0.7)
    gl.top_labels = False
    gl.right_labels = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    
    # 绘制VWC数据
    im = ax.imshow(data, origin='upper', 
                  extent=extent,
                  transform=ccrs.PlateCarree(),
                  cmap=cmap, vmin=vmin, vmax=vmax,
                  interpolation='nearest')
    
    # 添加标题 (格式为YYYY-MM)
    ax.set_title(f"{month_label}", fontsize=12, pad=10)
    
    return im

# ==========================================================
# 主执行函数（添加掩膜处理步骤）
# ==========================================================
def main():
    print("======== 开始掩膜处理 ========")
    
    # 加载主地物类型掩膜
    mask, ref_geotransform, ref_projection = load_main_type_mask()
    if mask is None:
        print("错误：无法加载主掩膜，程序终止")
        return
    
    # 检查掩膜尺寸
    print(f"掩膜尺寸: {mask.shape[0]}x{mask.shape[1]}")
    
    # 对SPECIFIC_MONTHS中的每个文件应用掩膜
    for label, month_code in SPECIFIC_MONTHS.items():
        file_path = os.path.join(VWC_DIR, f"VWC-{month_code}.tif")
        if os.path.exists(file_path):
            apply_mask_to_vwc_file(file_path, mask, ref_geotransform, ref_projection)
        else:
            print(f"警告: 文件不存在 - {file_path}")
    
    print("======== 掩膜处理完成 ========")
    print("开始读取特定月度VWC数据...")
    
    # 准备存储数据
    vwc_data = []
    
    # 读取已掩膜处理的数据
    for label, month_code in SPECIFIC_MONTHS.items():
        file_path = os.path.join(VWC_DIR, f"VWC-{month_code}.tif")
        print(f"读取掩膜后的文件: {file_path}")
        
        try:
            data, extent, projection = read_vwc_tif(file_path, band=BAND, no_data=NODATA_VALUE)
            vwc_data.append({
                "data": data,
                "extent": extent,
                "label": label  # 使用YYYY-MM格式的标签
            })
        except Exception as e:
            print(f"错误读取文件: {file_path}\n{str(e)}")
            continue
    
    # 检查数据完整性
    if not vwc_data:
        print("错误: 无有效数据，无法绘制")
        return
    
    # 创建图形
    fig = plt.figure(figsize=(14, 12))
    
    # 创建2x2网格布局
    gs = gridspec.GridSpec(2, 2, figure=fig, 
                          wspace=0.1, hspace=0.15,
                          top=0.95, bottom=0.1,
                          left=0.05, right=0.95)
    
    axs = [
        fig.add_subplot(gs[0], projection=ccrs.PlateCarree()),
        fig.add_subplot(gs[1], projection=ccrs.PlateCarree()),
        fig.add_subplot(gs[2], projection=ccrs.PlateCarree()),
        fig.add_subplot(gs[3], projection=ccrs.PlateCarree())
    ]
    
    # 绘制四个季度的地图
    images = []
    for i in range(4):
        if i < len(vwc_data):
            ax = axs[i]
            data = vwc_data[i]['data']
            extent = vwc_data[i]['extent']
            label = vwc_data[i]['label']
            
            # 设置全球视图
            ax.set_global()
            
            # 绘制地图
            img = plot_vwc_map(ax, data, extent, label)
            images.append(img)
    
    # 添加共享颜色条
    cax = fig.add_axes([0.25, 0.05, 0.5, 0.02])
    
    # 创建颜色条，使用自定义颜色映射
    norm = mpl.colors.Normalize(vmin=0, vmax=20)
    cmap = create_custom_cmap()
    
    # 创建颜色条标签
    ticks = [0, 5, 10, 15, 20]  # 对应五种颜色的位置
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                     cax=cax, orientation='horizontal',
                     ticks=ticks)
    cbar.set_label('VWC (kg/m²)', fontsize=11)
    
    # 添加主标题
    plt.suptitle("Global VWC Map - 2015", 
                 fontsize=14, y=0.98)
    
    # 保存图像（路径不变）
    output_path = os.path.join(OUTPUT_DIR, f"{OUTPUT_NAME}.png")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    
    print(f"结果已保存至: {output_path}")
    plt.close()

if __name__ == "__main__":
    main()

成功加载主地物类型掩膜
掩膜尺寸: 1800x3600
正在掩膜处理: VWC-201501.tif
成功掩膜处理: VWC-201501.tif
正在掩膜处理: VWC-201504.tif
成功掩膜处理: VWC-201504.tif
正在掩膜处理: VWC-201507.tif
成功掩膜处理: VWC-201507.tif
正在掩膜处理: VWC-201510.tif
成功掩膜处理: VWC-201510.tif
开始读取特定月度VWC数据...
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201501.tif
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201504.tif
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201507.tif
读取掩膜后的文件: E:\data\VWC\VWCMap\Monthly\VWC-201510.tif
结果已保存至: E:\文章\HUITU\Fig\Global_VWC_KuH_Seasonal_2015_Custom.png


# 总结：那一块还是偏高

# 补一下第一块数据的验证

## 缺失数据的补充，使用前后5天内存在数据的最近的两个节点进行线性插值

In [1]:
# 数据填充（修复Sheet名称）
import pandas as pd
import numpy as np
from pathlib import Path
import os
import h5py
from datetime import datetime, timedelta
import calendar
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

# 常量定义
TOTAL_ROWS = 1800  # 纬度方向像元数
TOTAL_COLS = 3600  # 经度方向像元数
VOD_VARIABLES = ['SM', 'ku_vod_H', 'ku_vod_V', 'x_vod_H', 'x_vod_V', 'c_vod_H', 'c_vod_V']
VOD_COLUMNS = ['SM_Satellite', 'ku_vod_H', 'ku_vod_V', 'x_vod_H', 'x_vod_V', 'c_vod_H', 'c_vod_V']
PFT_VARIABLES = ['water', 'bare', 'snowice', 'built', 'grassnat', 'grassman', 
                 'shrubbd', 'shrubbe', 'shrubnd', 'shrubne', 'treebd', 'treebe', 'treend', 'treene']

# 定义Sheet名称（修复：将SMEX08改为SMAPVEX08）
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']

def read_mat_file(file_path, variable_names, silent=False):
    """
    读取MAT文件并返回所需变量的数据矩阵
    
    参数:
    file_path (str): MAT文件路径
    variable_names (list): 要读取的变量名列表
    silent (bool): 是否静默处理错误
    
    返回:
    dict: 包含变量名及其对应的矩阵数据
    """
    try:
        # 尝试使用h5py读取v7.3格式
        with h5py.File(file_path, 'r') as f:
            data = {}
            for var in variable_names:
                if var in f:
                    dataset = f[var]
                    # 如果数据是引用类型，获取实际数据
                    if isinstance(dataset, h5py.Reference):
                        dataset = f[dataset]
                    # 确保是二维数组
                    if len(dataset.shape) == 2:
                        matrix = dataset[()]
                        # 检查矩阵方向是否需要转置
                        if matrix.shape == (TOTAL_ROWS, TOTAL_COLS):
                            data[var] = matrix
                        elif matrix.shape == (TOTAL_COLS, TOTAL_ROWS):
                            data[var] = matrix.T
                        else:
                            # 尝试重塑为正确形状
                            try:
                                data[var] = matrix.reshape(TOTAL_ROWS, TOTAL_COLS)
                            except:
                                data[var] = np.full((TOTAL_ROWS, TOTAL_COLS), np.nan)
                    else:
                        data[var] = np.full((TOTAL_ROWS, TOTAL_COLS), np.nan)
            return data
    except Exception as e:
        if not silent:
            print(f"警告: 读取文件 {file_path} 时出错: {str(e)}")
        return None

def safe_date_to_str(date_val):
    """安全地将日期值转换为YYYYMMDD格式的字符串"""
    if pd.isna(date_val):
        return ""
    
    # 处理不同的日期格式
    if isinstance(date_val, datetime):
        return date_val.strftime('%Y%m%d')
    elif isinstance(date_val, np.datetime64):
        return pd.to_datetime(date_val).strftime('%Y%m%d')
    elif isinstance(date_val, (int, float)):
        # 数字日期 (如20220715.0)
        date_str = str(int(date_val))
        return date_str[:8] if len(date_str) > 8 else date_str.zfill(8)
    else:
        # 字符串日期
        date_str = str(date_val).replace('-', '').replace('/', '').replace(' ', '')
        return date_str[:8] if len(date_str) > 8 else date_str.zfill(8)

def calculate_lai_weight(date_str):
    """计算LAI插值权重（修正版）"""
    if len(date_str) != 8 or not date_str.isdigit():
        return None, None, 0.0
    
    try:
        year = int(date_str[:4])
        month = int(date_str[4:6])
        day = int(date_str[6:8])
    except:
        return None, None, 0.0
    
    # 处理无效日期
    try:
        current_date = datetime(year, month, day)
    except ValueError:
        # 处理无效日期（如2月31日）
        if month == 2 and day > 28:
            day = 28
        elif day > 30 and month in [4, 6, 9, 11]:
            day = 30
        elif day > 31 and month in [1, 3, 5, 7, 8, 10, 12]:
            day = 31
            
        try:
            current_date = datetime(year, month, day)
        except:
            return None, None, 0.0
    
    # 确定正确的月份对
    if day < 15:
        # 如果日期在15日之前，使用前一个月和当前月
        prev_month = month - 1
        prev_year = year
        if prev_month == 0:
            prev_month = 12
            prev_year = year - 1
        
        prev_month_mid = datetime(prev_year, prev_month, 15)
        current_month_mid = datetime(year, month, 15)
        
        total_days = (current_month_mid - prev_month_mid).days
        days_passed = (current_date - prev_month_mid).days
    else:
        # 如果日期在15日或之后，使用当前月和下一月
        current_month_mid = datetime(year, month, 15)
        
        next_month = month + 1
        next_year = year
        if next_month > 12:
            next_month = 1
            next_year += 1
        next_month_mid = datetime(next_year, next_month, 15)
        
        total_days = (next_month_mid - current_month_mid).days
        days_passed = (current_date - current_month_mid).days
    
    if total_days <= 0:
        weight = 0.0
    else:
        weight = max(0.0, min(1.0, days_passed / total_days))
    
    # 返回月份对和权重
    if day < 15:
        return (prev_year, prev_month), (year, month), weight
    else:
        return (year, month), (next_year, next_month), weight

def get_vod_file_path(date_str):
    """根据日期生成VOD文件路径"""
    if not date_str or len(date_str) != 8:
        return None
    try:
        year = int(date_str[:4])
        if year <= 2012:
            return f"E:\\data\\VOD\\mat\\kuxcVOD\\ASC\\MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat"
        else:
            return f"E:\\data\\VOD\\mat\\kuxcVOD\\ASC\\MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat"
    except:
        return None

def linear_interpolate(current_dt, pre_date, pre_val, post_date, post_val):
    """线性插值计算缺失值"""
    delta_pre = (current_dt - pre_date).days
    delta_post = (post_date - current_dt).days
    total_delta = delta_pre + delta_post
    return (pre_val * delta_post + post_val * delta_pre) / total_delta

def process_dataset(df, sheet_name):
    """
    处理单个数据集，添加额外的卫星数据列
    
    参数:
    df (pd.DataFrame): 输入数据集
    sheet_name (str): sheet名称
    
    返回:
    pd.DataFrame: 处理后的数据集
    """
    print(f"\n开始处理 {sheet_name} 数据集...")
    print(f"原始数据行数: {len(df)}")
    
    # ===================================================================
    # 步骤1: 添加VOD及SM数据（改进版）
    # ===================================================================
    print("添加VOD及SM数据（改进版）...")
    
    # 准备结果列
    for col in VOD_COLUMNS:
        df[col] = np.nan
    
    # 获取所有日期和像元
    df['Date_Str'] = df['Date'].apply(safe_date_to_str)
    unique_dates = df['Date_Str'].unique()
    unique_pixels = set(zip(df['row'].astype(int), df['col'].astype(int)))
    
    # 构建日期缓存
    date_cache = {}
    for date_str in unique_dates:
        if not date_str or len(date_str) != 8:
            continue
            
        # 生成前后5天的日期范围
        current_dt = datetime.strptime(date_str, '%Y%m%d')
        date_range = [current_dt + timedelta(days=i) for i in range(-5, 6)]
        date_range_str = [d.strftime('%Y%m%d') for d in date_range]
        
        # 读取相关日期的数据
        for d_str in date_range_str:
            if d_str in date_cache:  # 已缓存
                continue
                
            file_path = get_vod_file_path(d_str)
            if not file_path or not os.path.exists(file_path):
                date_cache[d_str] = None
                continue
                
            vod_data = read_mat_file(file_path, VOD_VARIABLES, silent=True)
            if not vod_data:
                date_cache[d_str] = None
                continue
                
            # 提取所有像元的数据
            pixel_data = {}
            for (row, col) in unique_pixels:
                values = {}
                for var in VOD_VARIABLES:
                    matrix = vod_data.get(var)
                    if matrix is not None:
                        values[var] = matrix[row, col]
                    else:
                        values[var] = np.nan
                pixel_data[(row, col)] = values
            date_cache[d_str] = pixel_data
    
    # 填充VOD数据（使用线性插值）
    for idx, row in df.iterrows():
        date_str = row['Date_Str']
        if not date_str or len(date_str) != 8:
            continue
            
        current_dt = datetime.strptime(date_str, '%Y%m%d')
        r, c = int(row['row']), int(row['col'])
        
        # 检查当前日期是否有数据
        current_data = date_cache.get(date_str, {})
        if current_data and (r, c) in current_data:
            for var, col_name in zip(VOD_VARIABLES, VOD_COLUMNS):
                val = current_data[(r, c)][var]
                if not np.isnan(val):
                    df.at[idx, col_name] = val
        
        # 对每个变量进行插值
        for var, col_name in zip(VOD_VARIABLES, VOD_COLUMNS):
            if not np.isnan(df.at[idx, col_name]):
                continue  # 已有数据，跳过
                
            # 寻找前后5天内的有效数据
            pre_val, post_val = None, None
            pre_date, post_date = None, None
            
            for offset in range(-5, 6):
                if offset == 0:  # 跳过当前日期
                    continue
                    
                check_date = current_dt + timedelta(days=offset)
                check_date_str = check_date.strftime('%Y%m%d')
                check_data = date_cache.get(check_date_str)
                
                if not check_data or (r, c) not in check_data:
                    continue
                    
                val = check_data[(r, c)][var]
                if np.isnan(val):
                    continue
                    
                if offset < 0:  # 前向数据
                    if pre_val is None or abs(offset) < abs((check_date - current_dt).days):
                        pre_val = val
                        pre_date = check_date
                else:  # 后向数据
                    if post_val is None or abs(offset) < abs((check_date - current_dt).days):
                        post_val = val
                        post_date = check_date
            
            # 应用插值或直接填充
            if pre_val is not None and post_val is not None:
                # 线性插值
                interpolated = linear_interpolate(current_dt, pre_date, pre_val, post_date, post_val)
                df.at[idx, col_name] = interpolated
            elif pre_val is not None:
                df.at[idx, col_name] = pre_val
            elif post_val is not None:
                df.at[idx, col_name] = post_val
    
    # 清理临时列
    df.drop(columns=['Date_Str'], inplace=True, errors='ignore')
    
    # ===================================================================
    # 步骤2: 添加PFT数据
    # ===================================================================
    print("添加PFT数据...")
    
    # 准备结果列
    for var in PFT_VARIABLES:
        df[var] = np.nan
    
    # 获取所有年份
    years = set()
    for date_str in df['Date'].apply(safe_date_to_str):
        if len(date_str) >= 4 and date_str[:4].isdigit():
            years.add(int(date_str[:4]))
    
    # 创建缓存以提高性能
    year_pft_map = {}
    
    # 处理每个年份
    for year in years:
        file_path = f"E:\\data\\ESACCI PFT\\Resample\\Data\\{year}.mat"
        if os.path.exists(file_path):
            pft_data = read_mat_file(file_path, PFT_VARIABLES, silent=True)
            if pft_data:
                year_pft_map[year] = pft_data
    
    # 填充数据
    for i in df.index:
        date_str = safe_date_to_str(df.at[i, 'Date'])
        if not date_str or len(date_str) < 4:
            continue
            
        try:
            year = int(date_str[:4])
            pft_data = year_pft_map.get(year)
            
            if pft_data is None:
                continue
                
            row_index = int(df.at[i, 'row'])
            col_index = int(df.at[i, 'col'])
            
            for var in PFT_VARIABLES:
                matrix = pft_data.get(var)
                if matrix is not None and not np.isnan(matrix[row_index, col_index]):
                    df.at[i, var] = matrix[row_index, col_index]
        except Exception as e:
            print(f"处理行 {i} 的PFT数据时出错: {str(e)}")
    
    # ===================================================================
    # 步骤3: 添加LAI数据
    # ===================================================================
    print("添加LAI数据...")
    df['LAI_Satellite'] = np.nan
    
    # 创建缓存以提高性能
    lai_cache = {}
    
    # 处理每个日期
    for i in df.index:
        date_str = safe_date_to_str(df.at[i, 'Date'])
        if not date_str or len(date_str) != 8:
            continue
            
        # 计算权重和月份
        prev_month, next_month, weight = calculate_lai_weight(date_str)
        if prev_month is None:
            continue
            
        # 检查并读取当前月份文件
        lai1 = np.nan
        file1_path = f"E:\\data\\GLASS LAI\\mat\\0.1Deg\\Dataset\\{prev_month[0]:04d}-{prev_month[1]:02d}-01.tif.mat"
    
        if file1_path in lai_cache:
            lai1 = lai_cache[file1_path]
        elif os.path.exists(file1_path):
            lai_data1 = read_mat_file(file1_path, ['lai'], silent=True)
            if lai_data1 and 'lai' in lai_data1:
                matrix = lai_data1['lai']
                row_index = int(df.at[i, 'row'])
                col_index = int(df.at[i, 'col'])
                
                try:
                    lai1 = matrix[row_index, col_index]
                    lai_cache[file1_path] = lai1
                except:
                    lai_cache[file1_path] = np.nan
            else:
                lai_cache[file1_path] = np.nan
                lai1 = np.nan
        else:
            lai1 = np.nan
            
        # 检查并读取下个月份文件
        lai2 = np.nan
        file2_path = f"E:\\data\\GLASS LAI\\mat\\0.1Deg\\Dataset\\{next_month[0]:04d}-{next_month[1]:02d}-01.tif.mat"
        
        if file2_path in lai_cache:
            lai2 = lai_cache[file2_path]
        elif os.path.exists(file2_path):
            lai_data2 = read_mat_file(file2_path, ['lai'], silent=True)
            if lai_data2 and 'lai' in lai_data2:
                matrix = lai_data2['lai']
                row_index = int(df.at[i, 'row'])
                col_index = int(df.at[i, 'col'])
                
                try:
                    lai2 = matrix[row_index, col_index]
                    lai_cache[file2_path] = lai2
                except:
                    lai_cache[file2_path] = np.nan
            else:
                lai_cache[file2_path] = np.nan
                lai2 = np.nan
        else:
            lai2 = np.nan
            
        # 线性插值计算最终LAI值
        if not np.isnan(lai1) and not np.isnan(lai2):
            # 使用权重进行线性插值
            lai_final = (1 - weight) * lai1 + weight * lai2
        elif not np.isnan(lai1):
            lai_final = lai1
        elif not np.isnan(lai2):
            lai_final = lai2
        else:
            lai_final = np.nan
            
        df.at[i, 'LAI_Satellite'] = lai_final
    
    # ===================================================================
    # 步骤4: 添加Hveg数据
    # ===================================================================
    print("添加Hveg数据...")
    df['Hveg_Satellite'] = np.nan
    
    hveg_file = "E:\\data\\CanopyHeight\\CH.mat"
    if os.path.exists(hveg_file):
        hveg_data = read_mat_file(hveg_file, ['Hveg'])
        if hveg_data and 'Hveg' in hveg_data:
            matrix = hveg_data['Hveg']
            
            # 填充数据
            for i in df.index:
                row_index = int(df.at[i, 'row'])
                col_index = int(df.at[i, 'col'])
                
                try:
                    df.at[i, 'Hveg_Satellite'] = matrix[row_index, col_index]
                except (IndexError, ValueError):
                    # 保留NaN值
                    pass
    
    print(f"处理完成, 最终数据行数: {len(df)}")
    return df

# 主处理过程
if __name__ == "__main__":
    input_dir = r'E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16'
    input_file = os.path.join(input_dir, 'InsituData_Pixel.xlsx')
    output_file = os.path.join(input_dir, 'InsituData_Pixel_ML_vodFilled.xlsx')
    
    # 确保输出目录存在
    os.makedirs(input_dir, exist_ok=True)
    
    # 处理每个sheet（修复Sheet名称）
    sheet_names = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']  # 将SMEX08改为SMAPVEX08
    output_dfs = {}
    
    # 读取输入Excel文件
    for sheet in sheet_names:
        print(f"\n{'='*50}")
        print(f"处理数据集: {sheet}")
        try:
            df = pd.read_excel(input_file, sheet_name=sheet, engine='openpyxl')
            
            # 确保有足够的行
            if len(df) == 0:
                print(f"警告: {sheet} 中没有数据")
                output_dfs[sheet] = pd.DataFrame()
                continue
            
            processed_df = process_dataset(df, sheet)
            output_dfs[sheet] = processed_df
            
            # 添加虚拟行避免保存错误
            if processed_df.empty:
                # 创建至少一行数据防止ExcelWriter错误
                processed_df = pd.DataFrame(columns=df.columns)
                processed_df.loc[0] = [None] * len(processed_df.columns)
                
        except Exception as e:
            print(f"处理 {sheet} 时出错: {str(e)}")
            # 创建空DataFrame但有列名防止保存错误
            try:
                df = pd.read_excel(input_file, sheet_name=sheet, nrows=0, engine='openpyxl')
                output_dfs[sheet] = df
            except:
                output_dfs[sheet] = pd.DataFrame(columns=['Date', 'row', 'col'])
    
    # 保存结果
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        for sheet, df in output_dfs.items():
            if not df.empty:
                print(f"保存 '{sheet}' 到Excel文件 ({len(df)} 行)")
                df.to_excel(writer, sheet_name=sheet, index=False)
            else:
                # 创建空但有列名的sheet
                print(f"{sheet} 无有效数据，创建空工作表")
                empty_df = pd.DataFrame(columns=df.columns)
                empty_df.to_excel(writer, sheet_name=sheet, index=False)
    
    print(f"\n{'='*50}")
    print(f"处理完成! 结果已保存至: {output_file}")
    if os.path.exists(output_file):
        print(f"文件大小: {os.path.getsize(output_file)/1024/1024:.2f} MB")
    print("="*50)


处理数据集: SMEX02

开始处理 SMEX02 数据集...
原始数据行数: 16
添加VOD及SM数据（改进版）...
添加PFT数据...
添加LAI数据...
添加Hveg数据...
处理完成, 最终数据行数: 16

处理数据集: CLASIC07

开始处理 CLASIC07 数据集...
原始数据行数: 18
添加VOD及SM数据（改进版）...
添加PFT数据...
添加LAI数据...
添加Hveg数据...
处理完成, 最终数据行数: 18

处理数据集: SMAPVEX08

开始处理 SMAPVEX08 数据集...
原始数据行数: 6
添加VOD及SM数据（改进版）...
添加PFT数据...
添加LAI数据...
添加Hveg数据...
处理完成, 最终数据行数: 6

处理数据集: SMAPVEX16

开始处理 SMAPVEX16 数据集...
原始数据行数: 115
添加VOD及SM数据（改进版）...
添加PFT数据...
添加LAI数据...
添加Hveg数据...
处理完成, 最终数据行数: 115
保存 'SMEX02' 到Excel文件 (16 行)
保存 'CLASIC07' 到Excel文件 (18 行)
保存 'SMAPVEX08' 到Excel文件 (6 行)
保存 'SMAPVEX16' 到Excel文件 (115 行)

处理完成! 结果已保存至: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML_vodFilled.xlsx
文件大小: 0.04 MB


## 预测结果，绘制图像

In [15]:
# 散点图绘制（完整修复版）
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import joblib
import os
from pathlib import Path
import logging
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, r2_score

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_scatter_plot_final.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
SHEET_NAMES = ['SMEX02', 'CLASIC07', 'SMAPVEX08', 'SMAPVEX16']
VWC_COLUMNS = {
    'SMEX02': 'VWC-Field',
    'CLASIC07': 'VWC (kg/m²)',
    'SMAPVEX08': 'VWC',
    'SMAPVEX16': 'PLANT_WATER_CONTENT_AREA'
}

# 标记和颜色设置
MARKER_STYLES = {
    'SMEX02': {'marker': '*', 'color': '#F8766D'},
    'CLASIC07': {'marker': '^', 'facecolor': 'none', 'edgecolor': '#00BFC4'},
    'SMAPVEX08': {'marker': '+', 'color': '#C77CFF'},
    'SMAPVEX16': {'marker': 'o', 'facecolor': 'none', 'edgecolor': '#7CAE00'}
}

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'

def load_and_preprocess_data(file_path):
    """
    加载并预处理Excel文件中的所有sheet
    
    参数:
    file_path (str): Excel文件路径
    
    返回:
    dict: 包含预处理后数据的字典，键为sheet名称
    """
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    # 定义所有必需的PFT特征
    pft_features = [
        'Grass_man', 'Grass_nat',
        'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
        'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
    ]
    
    # 定义VOD特征
    vod_features = [
        'ku_vod_H', 'ku_vod_V',
        'x_vod_H', 'x_vod_V',
        'c_vod_H', 'c_vod_V'
    ]
    
    # 定义其他必需特征
    required_features = ['LAI'] + vod_features
    
    for sheet in SHEET_NAMES:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # ========== 确保所有必需特征都存在 ==========
            # 1. 创建缺失的PFT特征并设置值
            for feature in pft_features:
                if feature not in df.columns:
                    df[feature] = 0.0  # 初始化为0
                    print(f"    创建列 {feature} 并初始化为0")
            
            # 设置'Grass_man'为1，其他PFT特征为0
            df['Grass_man'] = 1.0
            for feature in pft_features[1:]:  # 跳过Grass_man
                df[feature] = 0.0
            print(f"    设置Grass_man=1, 其他PFT特征=0")
            
            # 2. 确保VOD特征存在
            for feature in vod_features:
                if feature not in df.columns:
                    df[feature] = 0.0  # 初始化为0
                    print(f"    创建列 {feature} 并初始化为0")
            
            # 3. 确保LAI存在
            if 'LAI' not in df.columns:
                df['LAI'] = 0.0  # 初始化为0
                print(f"    创建列 LAI 并初始化为0")
            
            # ========== 替换卫星数据（如果存在实测数据） ==========
            if 'LAI' in df.columns and 'LAI_Satellite' in df.columns:
                mask = df['LAI'].notna()
                df.loc[mask, 'LAI_Satellite'] = df.loc[mask, 'LAI']
                print(f"    替换了 {mask.sum()} 行LAI_Satellite数据")
            
            # ========== 归一化处理 ==========
            # 1. VOD特征归一化（除以1.5）
            for vod_col in vod_features:
                df[vod_col] = df[vod_col] / 1.5
            
            # 2. LAI特征归一化（除以6）
            df['LAI'] = df['LAI'] / 6.0
            
            print(f"    完成特征归一化")
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def predict_vwc(data_dict):
    """
    使用新模型预测VWC
    
    参数:
    data_dict (dict): 包含所有sheet数据的字典
    
    返回:
    dict: 包含每个sheet预测结果的字典
    """
    # 加载模型
    model_path = "models/RFR_6VOD_LAI_PFTs.pkl"
    print(f"加载模型: {model_path}")
    
    if not os.path.exists(model_path):
        print(f"  模型文件不存在: {model_path}")
        return {}
    
    try:
        model = joblib.load(model_path)
        # 打印模型训练时的特征名称（如果可用）
        if hasattr(model, 'feature_names_in_'):
            model_features = list(model.feature_names_in_)
            print(f"  模型训练特征: {model_features}")
    except Exception as e:
        print(f"  加载模型失败: {str(e)}")
        return {}
    
    # 存储预测结果
    predictions = {}
    
    # 定义特征名称映射（数据框列名 -> 模型期望名称）
    feature_mapping = {
        'ku_vod_H': 'VOD_Ku_Hpol_Asc',
        'ku_vod_V': 'VOD_Ku_Vpol_Asc',
        'x_vod_H': 'VOD_X_Hpol_Asc',
        'x_vod_V': 'VOD_X_Vpol_Asc',
        'c_vod_H': 'VOD_C_Hpol_Asc',
        'c_vod_V': 'VOD_C_Vpol_Asc',
        'LAI': 'LAI',
        'Grass_man': 'Grass_man',
        'Grass_nat': 'Grass_nat',
        'Shrub_bd': 'Shrub_bd',
        'Shrub_be': 'Shrub_be',
        'Shrub_nd': 'Shrub_nd',
        'Shrub_ne': 'Shrub_ne',
        'Tree_bd': 'Tree_bd',
        'Tree_be': 'Tree_be',
        'Tree_nd': 'Tree_nd',
        'Tree_ne': 'Tree_ne'
    }
    
    for sheet, df in data_dict.items():
        if df.empty:
            continue
        
        # 准备特征
        features = [
            'ku_vod_H', 'ku_vod_V',
            'x_vod_H', 'x_vod_V',
            'c_vod_H', 'c_vod_V',
            'LAI',
            'Grass_man', 'Grass_nat',
            'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
            'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
        ]
        
        # 检查是否包含所有必要特征
        missing_features = []
        for feature in features:
            if feature not in df.columns:
                missing_features.append(feature)
        
        if missing_features:
            print(f"  {sheet} 缺少特征: {', '.join(missing_features)}")
            continue
        
        # 准备数据
        X = df[features].copy()
        
        # 重命名列以匹配模型期望的特征名称
        X = X.rename(columns=feature_mapping)
        
        # 确保特征顺序与模型期望一致
        if hasattr(model, 'feature_names_in_'):
            # 检查是否有模型期望但数据中不存在的特征
            missing_model_features = set(model_features) - set(X.columns)
            if missing_model_features:
                print(f"  {sheet} 缺少模型期望的特征: {', '.join(missing_model_features)}")
                continue
                
            X = X[model_features]
        
        # 填充缺失值为0
        X = X.fillna(0)
        print(f"  {sheet} 填充缺失值为0")
        
        # 预测VWC
        y_pred = model.predict(X)
        
        # 修复：使用正确的经纬度列名
        predictions[sheet] = {
            'actual': df.loc[X.index, VWC_COLUMNS[sheet]],
            'predicted': y_pred,
            'source': sheet,
            'lat': df.loc[X.index, 'Center_Latitude'],  # 修复列名
            'lon': df.loc[X.index, 'Center_Longitude'],  # 修复列名
            'date': df.loc[X.index, 'Date']
        }
        print(f"  {sheet} 预测完成: {len(y_pred)} 个样本")
    
    return predictions

def create_scatter_plot(predictions):
    """
    创建单一散点图（所有数据集），使用纯文本显示精度指标，移除大标题
    
    参数:
    predictions (dict): 包含所有数据集预测结果的字典
    """
    print("创建散点图...")
    
    # 创建图形
    fig, ax = plt.subplots(figsize=(10, 10))
    
    # 收集所有数据点
    all_actual = []
    all_predicted = []
    
    # 存储每个数据集的统计信息
    dataset_stats = []
    
    # 绘制每个sheet的数据点
    for sheet in SHEET_NAMES:
        if sheet in predictions:
            actual = predictions[sheet]['actual']
            predicted = predictions[sheet]['predicted']
            
            # 添加到总集合
            all_actual.extend(actual)
            all_predicted.extend(predicted)
            
            # 计算当前数据集的统计信息
            rmse = np.sqrt(mean_squared_error(actual, predicted))
            r2 = r2_score(actual, predicted)
            n = len(actual)
            
            # 存储统计信息
            dataset_stats.append({
                'dataset': sheet,
                'rmse': rmse,
                'r2': r2,
                'n': n,
                'color': MARKER_STYLES[sheet].get('color', MARKER_STYLES[sheet].get('edgecolor', '#000000'))
            })
            
            # 绘制当前sheet的点
            if sheet in ['CLASIC07', 'SMAPVEX16']:
                # 对CLASIC07、SMAPVEX16特殊处理：空心
                ax.scatter(
                    actual, predicted,
                    marker=MARKER_STYLES[sheet]['marker'],
                    facecolor='none',  # 内部无填充
                    edgecolor=MARKER_STYLES[sheet]['edgecolor'],  # 使用边缘颜色
                    s=50,
                    alpha=0.7,
                    linewidths=1.0,  # 确保边框可见
                    label=sheet
                )
            else:
                # 其他数据集保持原样
                ax.scatter(
                    actual, predicted,
                    marker=MARKER_STYLES[sheet]['marker'],
                    color=MARKER_STYLES[sheet].get('color', MARKER_STYLES[sheet].get('edgecolor', None)),
                    s=50,
                    alpha=0.7,
                    label=sheet
                )
    
    # 如果没有数据，跳过
    if not all_actual:
        ax.text(0.5, 0.5, '无数据', 
                horizontalalignment='center', 
                verticalalignment='center', 
                transform=ax.transAxes,
                fontsize=16)
        plt.savefig("figures/VWC_Scatter_NoData.png", dpi=300)
        print("无有效数据，无法创建散点图")
        return
    
    # 计算整体RMSE和R²
    rmse = np.sqrt(mean_squared_error(all_actual, all_predicted))
    r2 = r2_score(all_actual, all_predicted)
    
    # 添加1:1参考线
    max_val = max(max(all_actual), max(all_predicted)) * 1.05
    ax.plot([0, max_val], [0, max_val], 'k--', lw=1.5, alpha=0.7)
    
    # 设置坐标轴范围
    ax.set_xlim(0, max_val)
    ax.set_ylim(0, max_val)
    
    # 设置坐标轴标签
    ax.set_xlabel('Insitu VWC (kg/m²)', fontsize=14, fontweight='bold')
    ax.set_ylabel('Predicted VWC (kg/m²)', fontsize=14, fontweight='bold')
    
    # ========== 添加数据集精度文本 ==========
    if dataset_stats:
        # 创建文本内容
        text_lines = []
        
        # 添加每个数据集的统计信息
        for stats in dataset_stats:
            text_lines.append(
                f"{stats['dataset']}: RMSE={stats['rmse']:.3f}  R²={stats['r2']:.4f}  n={stats['n']}"
            )
        
        # 添加整体统计
        text_lines.append(
            f"Overall: RMSE={rmse:.3f}  R²={r2:.4f}  n={len(all_actual)}"
        )
        
        # 创建文本字符串
        text_str = "\n".join(text_lines)
        
        # 添加文本到图表
        text_box = ax.text(
            0.05, 0.95,  # x, y位置（左上角）
            text_str,
            transform=ax.transAxes,
            fontsize=12,
            verticalalignment='top',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none')
        )
        
        # 设置数据集文本颜色
        # 由于matplotlib不支持文本中不同颜色，我们使用多个文本对象
        # 删除之前的文本对象
        text_box.remove()
        
        # 创建多个文本对象，每个数据集一行
        y_pos = 0.95  # 起始y位置
        for i, stats in enumerate(dataset_stats):
            ax.text(
                0.05, y_pos,
                f"{stats['dataset']}: RMSE={stats['rmse']:.3f}  R²={stats['r2']:.4f}  n={stats['n']}",
                transform=ax.transAxes,
                fontsize=12,
                color=stats['color'],
                verticalalignment='top',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none')
            )
            y_pos -= 0.05  # 下移一行
        
        # 添加整体统计文本（黑色加粗）
        ax.text(
            0.05, y_pos,
            f"Overall: RMSE={rmse:.3f}  R²={r2:.4f}  n={len(all_actual)}",
            transform=ax.transAxes,
            fontsize=12,
            color='black',
            fontweight='bold',
            verticalalignment='top',
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none')
        )
    
    # 添加图例
    ax.legend(loc='lower right', fontsize=12, frameon=True)
    
    # 调整布局
    plt.tight_layout()
    
    # 保存图像 - 使用与之前相同的文件名
    fig_path = "figures/VWC_Scatter_6VOD_LAI_PFTs.png"
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"散点图已保存至: {fig_path}")
    plt.close()

def save_prediction_details(predictions):
    """
    将预测结果保存到Excel文件中
    
    参数:
    predictions (dict): 包含所有数据集预测结果的字典
    """
    output_dir = Path(r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16")
    output_file = output_dir / "details_6VOD_LAI_PFTs.xlsx"
    
    # 创建Excel写入器
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        # 创建主数据框
        all_data = []
        
        # 收集所有sheet的数据
        for sheet, data in predictions.items():
            # 创建当前sheet的数据框
            sheet_df = pd.DataFrame({
                'Date': data['date'],
                'Center_Latitude': data['lat'],  # 使用正确的列名
                'Center_Longitude': data['lon'],  # 使用正确的列名
                'Actual_VWC': data['actual'],
                'Predicted_VWC': data['predicted'],
                'Source': data['source']
            })
            
            all_data.append(sheet_df)
        
        # 合并所有数据
        if all_data:
            combined_df = pd.concat(all_data, ignore_index=True)
            
            # 保存到Excel
            combined_df.to_excel(writer, sheet_name='All_Predictions', index=False)
            print(f"保存预测结果到: All_Predictions ({len(combined_df)}行)")
    
    print(f"所有预测结果已保存至: {output_file}")

def main():
    # 输入文件路径
    input_file = r"E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML_vodFilled.xlsx"
    
    # 加载并预处理数据
    data_dict = load_and_preprocess_data(input_file)
    
    # 使用新模型预测VWC
    predictions = predict_vwc(data_dict)
    
    # 创建散点图
    create_scatter_plot(predictions)
    
    # 保存预测结果到Excel
    save_prediction_details(predictions)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\SMEX02_CLASIC07_SMEX08_SMAPVEX16\InsituData_Pixel_ML.xlsx
  - SMEX02: 16行
    创建列 Grass_man 并初始化为0
    创建列 Grass_nat 并初始化为0
    创建列 Shrub_bd 并初始化为0
    创建列 Shrub_be 并初始化为0
    创建列 Shrub_nd 并初始化为0
    创建列 Shrub_ne 并初始化为0
    创建列 Tree_bd 并初始化为0
    创建列 Tree_be 并初始化为0
    创建列 Tree_nd 并初始化为0
    创建列 Tree_ne 并初始化为0
    设置Grass_man=1, 其他PFT特征=0
    创建列 LAI 并初始化为0
    替换了 16 行LAI_Satellite数据
    完成特征归一化
  - CLASIC07: 18行
    创建列 Grass_man 并初始化为0
    创建列 Grass_nat 并初始化为0
    创建列 Shrub_bd 并初始化为0
    创建列 Shrub_be 并初始化为0
    创建列 Shrub_nd 并初始化为0
    创建列 Shrub_ne 并初始化为0
    创建列 Tree_bd 并初始化为0
    创建列 Tree_be 并初始化为0
    创建列 Tree_nd 并初始化为0
    创建列 Tree_ne 并初始化为0
    设置Grass_man=1, 其他PFT特征=0
    创建列 LAI 并初始化为0
    替换了 18 行LAI_Satellite数据
    完成特征归一化
  - SMAPVEX08: 6行
    创建列 Grass_man 并初始化为0
    创建列 Grass_nat 并初始化为0
    创建列 Shrub_bd 并初始化为0
    创建列 Shrub_be 并初始化为0
    创建列 Shrub_nd 并初始化为0
    创建列 Shrub_ne 并初始化为0
    创建列 Tree_bd 并初始化为0
    创建列 Tree_be 并初始化为0
    创建列 Tree_nd 并初始

In [None]:
# 还行，比较散，R2比较低

# 中国那两个数据的验证

## 重新填充

In [1]:
# 多频多角度数据填充（支持VOD插值）
import os
import numpy as np
import pandas as pd
import h5py
from datetime import datetime, timedelta
from pathlib import Path
import warnings
import openpyxl
warnings.filterwarnings('ignore')

def latlon_to_rowcol(lat, lon):
    """将经纬度转换为0.1°栅格的行列号"""
    row = int((89.95 - lat) / 0.1)
    col = int((lon + 179.95) / 0.1)
    return row, col

def get_nearest_lai_files(date):
    """获取指定日期前后两个月的LAI文件路径（精确到每月15日）"""
    # 获取当前日期所在月份的前一个月15日
    prev_month_15 = (date.replace(day=1) - timedelta(days=1)).replace(day=15)
    
    # 获取当前日期所在月份的下一个月15日
    next_month_15 = (date.replace(day=28) + timedelta(days=4)).replace(day=15)
    
    # 构建文件路径
    prev_file = Path(f"E:/data/GLASS LAI/mat/0.1Deg/Dataset/{prev_month_15.strftime('%Y-%m')}-01.tif.mat")
    next_file = Path(f"E:/data/GLASS LAI/mat/0.1Deg/Dataset/{next_month_15.strftime('%Y-%m')}-01.tif.mat")
    
    return prev_file, next_file, prev_month_15, next_month_15

def read_mat_v73(file_path, variable_names):
    """
    读取 v7.3 格式的 .mat 文件
    返回字典：{变量名: 矩阵数据}
    """
    data = {}
    try:
        with h5py.File(file_path, 'r') as f:
            for var in variable_names:
                if var in f:
                    dataset = f[var]
                    # 读取数据（不自动转置）
                    matrix = dataset[()]
                    
                    # 确保数据是二维数组
                    if len(matrix.shape) == 2:
                        # 检查形状是否匹配全局常量（1800×3600）
                        if matrix.shape == (1800, 3600):
                            data[var] = matrix
                        elif matrix.shape == (3600, 1800):
                            # 如果是转置的形状，则手动转置
                            data[var] = matrix.T
                        else:
                            # 尝试重塑为正确形状
                            try:
                                data[var] = matrix.reshape(1800, 3600)
                            except:
                                data[var] = np.full((1800, 3600), np.nan)
                    else:
                        data[var] = np.full((1800, 3600), np.nan)
    except Exception as e:
        print(f"警告: 读取文件 {file_path} 时出错: {str(e)}")
        return None
    
    return data

def get_vod_file_path(date):
    """根据日期生成VOD文件路径"""
    date_str = date.strftime('%Y%m%d')
    year = date.year
    if year <= 2012:
        return Path(f"E:/data/VOD/mat/kuxcVOD/ASC/MCCA_AMSRE_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")
    else:
        return Path(f"E:/data/VOD/mat/kuxcVOD/ASC/MCCA_AMSR2_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat")

def linear_interpolate(target_dt, pre_dt, pre_val, post_dt, post_val):
    """线性插值计算缺失值"""
    if pre_val is None or post_val is None or pre_dt is None or post_dt is None:
        return None
    delta_pre = (target_dt - pre_dt).days
    delta_post = (post_dt - target_dt).days
    total_delta = delta_pre + delta_post
    if total_delta == 0:  # 避免除零错误
        return (pre_val + post_val) / 2
    return (pre_val * delta_post + post_val * delta_pre) / total_delta

def process_sheet(df, sheet_name, lat, lon, year):
    """处理单个sheet的数据"""
    # 计算固定位置的栅格行列号
    row, col = latlon_to_rowcol(lat, lon)
    
    # 添加位置信息
    df['Latitude'] = lat
    df['Longitude'] = lon
    df['row'] = row
    df['col'] = col
    
    # 创建日期列 - 使用英文列名
    # 检查并删除有空值的行
    initial_count = len(df)
    df = df.dropna(subset=['Year', 'Month', 'Day'])
    removed_count = initial_count - len(df)
    if removed_count > 0:
        print(f"警告: {sheet_name} 中删除了 {removed_count} 行包含空值的行")
    
    # 转换为整数
    df[['Year', 'Month', 'Day']] = df[['Year', 'Month', 'Day']].astype(int)
    
    # 创建日期列
    df['Date'] = pd.to_datetime(df[['Year', 'Month', 'Day']].astype(str).agg('-'.join, axis=1))
    
    # 准备新列
    vod_columns = ['SM', 'ku_vod_H', 'ku_vod_V', 'x_vod_H', 'x_vod_V', 'c_vod_H', 'c_vod_V']
    pft_columns = ['water', 'bare', 'snowice', 'built', 'grassnat', 'grassman', 
                   'shrubbd', 'shrubbe', 'shrubnd', 'shrubne', 'treebd', 'treebe', 'treend', 'treene']
    lai_column = 'lai'  # LAI变量名
    
    for col_name in vod_columns + pft_columns + ['LAI_Satellite', 'Hveg_Satellite']:
        df[col_name] = np.nan
    
    # 加载PFT数据 (一次性加载全年的)
    pft_file = Path(f"E:/data/ESACCI PFT/Resample/Data/{year}.mat")
    if pft_file.exists():
        pft_data = read_mat_v73(pft_file, pft_columns)
        if pft_data:
            for pft_col in pft_columns:
                if pft_col in pft_data:
                    try:
                        # 关键修复：正确索引二维数组
                        # 检查是否为NumPy数组（支持[row, col]索引）
                        if isinstance(pft_data[pft_col], np.ndarray):
                            df[pft_col] = pft_data[pft_col][row, col]
                        else:
                            # 普通列表使用[row][col]
                            df[pft_col] = pft_data[pft_col][row][col]
                    except Exception as e:
                        print(f"处理PFT数据时出错: {str(e)}")
    
    # 加载Hveg数据 (不随时间变化)
    hveg_file = Path("E:/data/CanopyHeight/CH.mat")
    if hveg_file.exists():
        ch_data = read_mat_v73(hveg_file, ['Hveg'])
        if ch_data and 'Hveg' in ch_data:
            try:
                # 正确索引二维数组
                if isinstance(ch_data['Hveg'], np.ndarray):
                    df['Hveg_Satellite'] = ch_data['Hveg'][row, col]
                else:
                    df['Hveg_Satellite'] = ch_data['Hveg'][row][col]
            except Exception as e:
                print(f"处理Hveg数据时出错: {str(e)}")
    
    # 创建VOD文件缓存字典
    vod_cache = {}
    
    # 预加载所有可能需要的VOD文件（前后5天）
    unique_dates = df['Date'].unique()
    for date_val in unique_dates:
        # 生成日期范围（前后5天）
        date_range = [date_val + timedelta(days=i) for i in range(-5, 6)]
        
        for dt in date_range:
            vod_file = get_vod_file_path(dt)
            if vod_file in vod_cache:  # 已缓存
                continue
                
            if vod_file.exists():
                try:
                    vod_data = read_mat_v73(vod_file, vod_columns)
                    if vod_data:
                        # 提取所有像元的数据
                        pixel_data = {}
                        for var in vod_columns:
                            matrix = vod_data.get(var)
                            if matrix is not None:
                                pixel_data[var] = matrix[row, col]
                            else:
                                pixel_data[var] = np.nan
                        vod_cache[vod_file] = pixel_data
                    else:
                        vod_cache[vod_file] = None
                except Exception as e:
                    print(f"读取VOD文件 {vod_file} 时出错: {str(e)}")
                    vod_cache[vod_file] = None
            else:
                vod_cache[vod_file] = None
    
    # 逐行处理VOD数据
    for idx, row_data in df.iterrows():
        current_date = row_data['Date']
        
        # 处理VOD数据（改进版：支持前后5天搜索）
        for var in vod_columns:
            # 1. 首先尝试获取当天数据
            vod_file = get_vod_file_path(current_date)
            current_val = None
            
            # 尝试从缓存获取数据
            if vod_file in vod_cache and vod_cache[vod_file] is not None:
                pixel_data = vod_cache[vod_file]
                if var in pixel_data:
                    current_val = pixel_data[var]
            
            # 如果当前值有效，则填充
            if current_val is not None and not np.isnan(current_val):
                if var == 'SM':
                    df.at[idx, 'SM_Satellite'] = current_val
                else:
                    df.at[idx, var] = current_val
                continue
            
            # 2. 当天数据缺失，搜索前后5天数据
            pre_val, post_val = None, None
            pre_date, post_date = None, None
            
            # 向前搜索（最多5天）
            for i in range(1, 6):
                check_date = current_date - timedelta(days=i)
                check_file = get_vod_file_path(check_date)
                
                # 从缓存获取数据
                if check_file in vod_cache and vod_cache[check_file] is not None:
                    pixel_data = vod_cache[check_file]
                    if var in pixel_data:
                        val = pixel_data[var]
                        if not np.isnan(val):
                            pre_val = val
                            pre_date = check_date
                            break
            
            # 向后搜索（最多5天）
            for i in range(1, 6):
                check_date = current_date + timedelta(days=i)
                check_file = get_vod_file_path(check_date)
                
                # 从缓存获取数据
                if check_file in vod_cache and vod_cache[check_file] is not None:
                    pixel_data = vod_cache[check_file]
                    if var in pixel_data:
                        val = pixel_data[var]
                        if not np.isnan(val):
                            post_val = val
                            post_date = check_date
                            break
            
            # 3. 应用插值或填充逻辑
            final_val = None
            if pre_val is not None and post_val is not None and pre_date and post_date:
                # 线性插值
                final_val = linear_interpolate(current_date, pre_date, pre_val, post_date, post_val)
            elif pre_val is not None:
                # 使用最近的前向值
                final_val = pre_val
            elif post_val is not None:
                # 使用最近的后向值
                final_val = post_val
            
            # 4. 赋值
            if final_val is not None and not np.isnan(final_val):
                if var == 'SM':
                    df.at[idx, 'SM_Satellite'] = final_val
                else:
                    df.at[idx, var] = final_val
        
        # 处理LAI数据（插值）
        prev_file, next_file, prev_date, next_date = get_nearest_lai_files(row_data['Date'])
        if prev_file.exists() and next_file.exists():
            try:
                # 读取前一个月数据
                prev_data = read_mat_v73(prev_file, [lai_column])
                if prev_data and lai_column in prev_data:
                    # 正确索引二维数组
                    if isinstance(prev_data[lai_column], np.ndarray):
                        prev_lai = prev_data[lai_column][row, col]
                    else:
                        prev_lai = prev_data[lai_column][row][col]
                else:
                    prev_lai = np.nan
                
                # 读取后一个月数据
                next_data = read_mat_v73(next_file, [lai_column])
                if next_data and lai_column in next_data:
                    # 正确索引二维数组
                    if isinstance(next_data[lai_column], np.ndarray):
                        next_lai = next_data[lai_column][row, col]
                    else:
                        next_lai = next_data[lai_column][row][col]
                else:
                    next_lai = np.nan
                
                # 计算日期差（精确到天）
                total_days = (next_date - prev_date).days
                current_days = (row_data['Date'] - prev_date).days
                
                # 线性插值
                if total_days > 0 and 0 <= current_days <= total_days:
                    weight = current_days / total_days
                    df.at[idx, 'LAI_Satellite'] = (1 - weight) * prev_lai + weight * next_lai
                else:
                    # 如果日期超出范围，使用最近的一个值
                    if current_days < 0:
                        df.at[idx, 'LAI_Satellite'] = prev_lai
                    else:
                        df.at[idx, 'LAI_Satellite'] = next_lai
            except Exception as e:
                print(f"处理LAI插值失败，日期 {current_date}: {str(e)}")
        else:
            # 如果缺少LAI文件，记录警告
            missing_files = []
            if not prev_file.exists():
                missing_files.append(str(prev_file))
            if not next_file.exists():
                missing_files.append(str(next_file))
            print(f"警告: 缺少LAI文件: {', '.join(missing_files)}")
    
    return df

def process_2017_data():
    """处理2017年的数据"""
    file_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg.xlsx"
    save_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML_VODFilled.xlsx"
    
    # 创建保存目录
    save_dir = Path(save_path).parent
    save_dir.mkdir(parents=True, exist_ok=True)
    
    # 多伦位置 - 东经116.47，北纬42.18
    lat = 42.18
    lon = 116.47
    
    # 获取所有sheet名称
    xl = pd.ExcelFile(file_path)
    all_sheets = xl.sheet_names
    
    # 排除BuckwheatMeasured
    sheets_to_process = [sheet for sheet in all_sheets if "BuckwheatMeasured" not in sheet]
    
    print(f"将处理以下工作表: {', '.join(sheets_to_process)}")
    
    # 创建一个新的Excel文件
    with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
        # 添加一个空的工作表作为占位符（避免"没有可见工作表"错误）
        pd.DataFrame().to_excel(writer, sheet_name='Placeholder', index=False)
        
        for sheet_name in sheets_to_process:
            try:
                # 跳过首行标题（中文列名）
                df = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=1)
                
                # 处理数据
                print(f"处理 2017: {sheet_name}")
                df_processed = process_sheet(df, sheet_name, lat, lon, year=2017)
                
                # 保存到Excel
                df_processed.to_excel(writer, sheet_name=sheet_name, index=False)
            except Exception as e:
                print(f"处理工作表 {sheet_name} 时出错: {str(e)}")
                # 创建一个空DataFrame但有列名防止保存错误
                try:
                    df_empty = pd.read_excel(file_path, sheet_name=sheet_name, nrows=0, skiprows=1)
                    df_empty.to_excel(writer, sheet_name=sheet_name, index=False)
                except:
                    # 如果连列名都读不到，则创建一个默认列
                    df_empty = pd.DataFrame(columns=['Year', 'Month', 'Day'])
                    df_empty.to_excel(writer, sheet_name=sheet_name, index=False)
    
    # 删除占位符工作表，但确保至少有一个工作表
    wb = openpyxl.load_workbook(save_path)
    if 'Placeholder' in wb.sheetnames:
        if len(wb.sheetnames) > 1:
            del wb['Placeholder']
        else:
            # 重命名占位符工作表，避免空工作簿
            ws = wb['Placeholder']
            ws.title = 'EmptyData'
            ws['A1'] = "无有效数据，请检查原始文件或错误日志"
    wb.save(save_path)
    
    print(f"2017年数据处理完成，保存至: {save_path}")

def process_2018_data():
    """处理2018年的数据"""
    file_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC.xlsx"
    save_path = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML_VODFilled.xlsx"
    
    # 正蓝旗位置 - 东经115.93，北纬42.04
    lat = 42.04
    lon = 115.93
    
    # 获取所有sheet名称
    xl = pd.ExcelFile(file_path)
    all_sheets = xl.sheet_names
    
    # 2018年只有一个名为GrassVWC的工作表
    sheets_to_process = [sheet for sheet in all_sheets if "GrassVWC" in sheet]
    
    if not sheets_to_process:
        print(f"警告: 在 {file_path} 中未找到名为 'GrassVWC' 的工作表")
        sheets_to_process = all_sheets  # 尝试处理所有工作表
    
    print(f"将处理以下工作表: {', '.join(sheets_to_process)}")
    
    # 创建一个新的Excel文件
    with pd.ExcelWriter(save_path, engine='openpyxl') as writer:
        # 添加一个空的工作表作为占位符（避免"没有可见工作表"错误）
        pd.DataFrame().to_excel(writer, sheet_name='Placeholder', index=False)
        
        for sheet_name in sheets_to_process:
            try:
                # 跳过首行标题（中文列名）
                df = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=1)
                
                # 处理数据
                print(f"处理 2018: {sheet_name}")
                df_processed = process_sheet(df, sheet_name, lat, lon, year=2018)
                
                # 保存到Excel
                df_processed.to_excel(writer, sheet_name=sheet_name, index=False)
            except Exception as e:
                print(f"处理工作表 {sheet_name} 时出错: {str(e)}")
                try:
                    df_empty = pd.read_excel(file_path, sheet_name=sheet_name, nrows=0, skiprows=1)
                    df_empty.to_excel(writer, sheet_name=sheet_name, index=False)
                except:
                    df_empty = pd.DataFrame(columns=['Year', 'Month', 'Day'])
                    df_empty.to_excel(writer, sheet_name=sheet_name, index=False)
    
    # 删除占位符工作表，但确保至少有一个工作表
    wb = openpyxl.load_workbook(save_path)
    if 'Placeholder' in wb.sheetnames:
        if len(wb.sheetnames) > 1:
            del wb['Placeholder']
        else:
            ws = wb['Placeholder']
            ws.title = 'EmptyData'
            ws['A1'] = "无有效数据，请检查原始文件或错误日志"
    wb.save(save_path)
    
    print(f"2018年数据处理完成，保存至: {save_path}")

def main():
    # 处理2017年数据
    process_2017_data()
    
    # 处理2018年数据
    process_2018_data()

if __name__ == "__main__":
    main()

将处理以下工作表: CornVegMeasured, CornVegFitting, OatVegMeasured, OatVegFitting
处理 2017: CornVegMeasured
警告: CornVegMeasured 中删除了 5 行包含空值的行
处理 2017: CornVegFitting
处理 2017: OatVegMeasured
处理 2017: OatVegFitting
2017年数据处理完成，保存至: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML_VODFilled.xlsx
将处理以下工作表: GrassVWC
处理 2018: GrassVWC
2018年数据处理完成，保存至: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML_VODFilled.xlsx


## 整理出预测值并绘制时序图

In [10]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.gridspec as gridspec
import joblib
import os
from pathlib import Path
import warnings
from datetime import datetime
from sklearn.metrics import mean_squared_error, r2_score
from scipy.interpolate import make_interp_spline
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.titleweight'] = 'bold'

# 常量定义
MODEL_NAME = "RFR_6VOD_LAI_PFTs.pkl"

# 列名到模型特征名的映射
COLUMN_TO_FEATURE_MAPPING = {
    # VOD变量
    'ku_vod_H': 'VOD_Ku_Hpol_Asc',
    'ku_vod_V': 'VOD_Ku_Vpol_Asc',
    'x_vod_H': 'VOD_X_Hpol_Asc',
    'x_vod_V': 'VOD_X_Vpol_Asc',
    'c_vod_H': 'VOD_C_Hpol_Asc',
    'c_vod_V': 'VOD_C_Vpol_Asc',
    
    # LAI变量
    'LAI_Satellite': 'LAI',
    
    # PFT变量
    'grassman': 'Grass_man',
    'grassnat': 'Grass_nat',
    'shrubbd': 'Shrub_bd',
    'shrubbe': 'Shrub_be',
    'shrubnd': 'Shrub_nd',
    'shrubne': 'Shrub_ne',
    'treebd': 'Tree_bd',
    'treebe': 'Tree_be',
    'treend': 'Tree_nd',
    'treene': 'Tree_ne'
}

# 植被类型映射
VEGETATION_TYPES = {
    'CornVegMeasured': 'Corn (2017)',
    'OatVegMeasured': 'Oat (2017)',
    'GrassVWC': 'Grass (2018)'
}

# 实测与拟合数据映射
FITTING_MAPPING = {
    'CornVegMeasured': 'CornVegFitting',
    'OatVegMeasured': 'OatVegFitting',
    'GrassVWC': 'GrassVWC'  # 2018年没有拟合数据
}

# 实测VWC列名映射
ACTUAL_COL_MAPPING = {
    'CornVegMeasured': 'total_VWC(kg/m2)',
    'OatVegMeasured': 'total_VWC(kg/m2)',
    'GrassVWC': 'vegetation water content(kg/m2)'
}

# 实测数据样式
ACTUAL_STYLE = {
    'color': 'black',
    'marker': 'o',
    'markersize': 8,
    'markerfacecolor': 'none',
    'markeredgewidth': 1.5,
    'label': 'Measured'
}

# 预测数据样式
PREDICTED_STYLE = {
    'color': 'red',
    'linestyle': '-',
    'linewidth': 2,
    'label': 'Predicted'
}

def load_data(file_path):
    """加载Excel文件中的所有工作表"""
    print(f"加载文件: {file_path}")
    data_dict = {}
    
    # 获取所有工作表名称
    xl = pd.ExcelFile(file_path)
    sheet_names = xl.sheet_names
    
    for sheet in sheet_names:
        try:
            df = pd.read_excel(file_path, sheet_name=sheet)
            print(f"  - {sheet}: {len(df)}行")
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            data_dict[sheet] = df
        except Exception as e:
            print(f"  加载 {sheet} 时出错: {str(e)}")
            data_dict[sheet] = pd.DataFrame()
    
    return data_dict

def predict_vwc_unified_model(df):
    """
    使用统一的机器学习模型预测VWC（6VOD+LAI+10PFT）
    """
    # 加载模型
    model_path = f"models/{MODEL_NAME}"
    if not os.path.exists(model_path):
        print(f"警告: 模型文件不存在: {model_path}")
        return pd.Series(np.nan, index=df.index)
    
    try:
        model = joblib.load(model_path)
        print(f"加载统一模型: {model_path}")
        
        # 获取模型期望的特征名称
        if hasattr(model, 'feature_names_in_'):
            expected_features = list(model.feature_names_in_)
            print(f"  模型期望特征: {expected_features}")
        else:
            print("  警告: 模型没有feature_names_in_属性")
            expected_features = []
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)
    
    # 1. 优先使用地面实测数据替换卫星数据
    if 'LAI' in df.columns:
        lai_mask = df['LAI'].notna() & (df['LAI'] > 0)
        if lai_mask.any():
            df.loc[lai_mask, 'LAI_Satellite'] = df.loc[lai_mask, 'LAI']
            print(f"  使用实测LAI替换了 {lai_mask.sum()} 行数据")
    
    # 2. 检查所有必需特征是否存在
    required_features = list(COLUMN_TO_FEATURE_MAPPING.keys())
    missing_features = [feat for feat in required_features if feat not in df.columns]
    
    if missing_features:
        print(f"  缺少特征: {', '.join(missing_features)}")
        return pd.Series(np.nan, index=df.index)
    
    # 3. 时间序列插值
    if 'Date' in df.columns and not df.empty:
        # 确保按日期排序
        df = df.sort_values('Date')
        
        # 设置时间索引
        date_index = pd.DatetimeIndex(df['Date'])
        df_temp = df.set_index('Date')
        
        # 生成完整的时间序列范围
        full_range = pd.date_range(start=date_index.min(), end=date_index.max(), freq='D')
        df_full = df_temp.reindex(full_range)
        
        # 对特征列进行线性插值
        for col in required_features:
            if col in df_full.columns:
                df_full[col] = df_full[col].interpolate(method='time', limit_direction='both')
                print(f"  已完成{col}的时间序列插值")
        
        # 重置索引
        df = df_full.reset_index().rename(columns={'index': 'Date'})
    else:
        print("  无日期列或数据为空，跳过插值")
    
    # 4. 准备特征数据 - 应用列名映射
    X = pd.DataFrame()
    for data_col, model_feature in COLUMN_TO_FEATURE_MAPPING.items():
        if data_col in df.columns:
            X[model_feature] = df[data_col]
    
    # 5. 应用特征归一化
    # VOD特征归一化（除以2）
    vod_features = [model_feature for data_col, model_feature in COLUMN_TO_FEATURE_MAPPING.items() 
                   if data_col.startswith(('ku_vod', 'x_vod', 'c_vod'))]
    for vod_feature in vod_features:
        if vod_feature in X.columns:
            X[vod_feature] = X[vod_feature].clip(0, 2) / 2.0
    
    # LAI特征归一化（除以6）
    if 'LAI' in X.columns:
        X['LAI'] = X['LAI'].clip(0, 6) / 6.0
    
    # PFT特征归一化（除以100）
    pft_features = [model_feature for data_col, model_feature in COLUMN_TO_FEATURE_MAPPING.items() 
                   if data_col in ('grassman', 'grassnat', 'shrubbd', 'shrubbe', 'shrubnd', 'shrubne', 
                                  'treebd', 'treebe', 'treend', 'treene')]
    for pft_feature in pft_features:
        if pft_feature in X.columns:
            X[pft_feature] = X[pft_feature] / 100.0
    
    # 6. 移除缺失值
    initial_count = len(X)
    X = X.dropna()
    removed_count = initial_count - len(X)
    if removed_count > 0:
        print(f"  移除了 {removed_count} 行包含缺失值的数据")
    
    if X.empty:
        print("  无有效数据可用于预测")
        return pd.Series(np.nan, index=df.index)
    
    # 7. 确保特征顺序与模型期望一致
    if hasattr(model, 'feature_names_in_'):
        X = X[expected_features]
    
    # 8. 预测VWC
    try:
        y_pred = model.predict(X)
        
        # 创建完整长度的预测序列
        full_pred = pd.Series(np.nan, index=df.index)
        full_pred.loc[X.index] = y_pred
        
        return full_pred
    except Exception as e:
        print(f"  预测失败: {str(e)}")
        return pd.Series(np.nan, index=df.index)

def create_combined_plots(data_dict_2017, data_dict_2018):
    """创建组合时间序列图并保存预测结果"""
    print("创建组合时间序列图...")
    
    # 创建输出目录
    output_dir = Path("prediction_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # 创建图形
    fig = plt.figure(figsize=(15, 12))
    gs = gridspec.GridSpec(3, 1, figure=fig, hspace=0.3)
    
    # 设置全局标题
    fig.suptitle('Vegetation Water Content Time Series (6VOD+LAI+PFT Model)', 
                 fontsize=20, fontweight='bold', y=0.95)
    
    # 植被类型列表
    vegetation_types = [
        ('CornVegMeasured', data_dict_2017),  # 玉米
        ('OatVegMeasured', data_dict_2017),   # 燕麦
        ('GrassVWC', data_dict_2018)           # 草
    ]
    
    # 存储所有评估指标
    all_metrics = {}
    
    # 存储所有预测结果
    all_predictions = {}
    
    # 遍历所有植被类型
    for idx, (veg_type, data_dict) in enumerate(vegetation_types):
        ax = fig.add_subplot(gs[idx])
        
        # 获取当前植被类型的实测列名
        actual_col = ACTUAL_COL_MAPPING[veg_type]
        
        # 初始化Y轴范围
        y_min = float('inf')
        y_max = float('-inf')
        
        # 获取实测数据
        if veg_type in data_dict:
            df_measured = data_dict[veg_type].copy()
            
            # 确保日期列存在
            if 'Date' not in df_measured.columns:
                print(f"警告: {veg_type} 中没有 'Date' 列")
                continue
            
            # 按日期排序
            df_measured = df_measured.sort_values('Date')
            
            # 更新Y轴范围（实测值）
            if actual_col in df_measured.columns:
                measured_values = df_measured[actual_col].dropna()
                if not measured_values.empty:
                    y_min = min(y_min, measured_values.min())
                    y_max = max(y_max, measured_values.max())
            
            # 获取拟合数据用于预测
            fitting_sheet = FITTING_MAPPING.get(veg_type, veg_type)
            if fitting_sheet in data_dict:
                df_fitting = data_dict[fitting_sheet].copy()
                
                # 确保日期列存在
                if 'Date' not in df_fitting.columns:
                    print(f"警告: {fitting_sheet} 中没有 'Date' 列")
                    continue
                
                # 按日期排序
                df_fitting = df_fitting.sort_values('Date')
            else:
                # 2018年没有单独的拟合数据
                df_fitting = df_measured.copy()
            
            # 存储评估指标
            metrics = []
            
            # 使用统一模型预测VWC
            col_name = "Predicted_VWC"
            
            # 如果列不存在，使用模型预测
            if col_name not in df_fitting.columns:
                print(f"为 {fitting_sheet} 预测 VWC...")
                df_fitting[col_name] = predict_vwc_unified_model(df_fitting)
            
            # 只在有有效预测值的点进行绘制和评估
            if col_name in df_fitting.columns:
                # 更新Y轴范围（预测值）
                pred_values = df_fitting[col_name].dropna()
                if not pred_values.empty:
                    y_min = min(y_min, pred_values.min())
                    y_max = max(y_max, pred_values.max())
                
                # 获取有效预测数据点
                valid_mask = df_fitting[col_name].notna()
                valid_dates = df_fitting['Date'][valid_mask]
                valid_values = df_fitting[col_name][valid_mask]
                
                # 如果数据点足够多，使用样条插值生成平滑曲线
                if len(valid_dates) > 3:
                    try:
                        # 将日期转换为数值（从最小日期开始的天数）
                        date_numeric = (valid_dates - valid_dates.min()).dt.days
                        
                        # 创建样条插值对象
                        spline = make_interp_spline(date_numeric, valid_values, k=3)
                        
                        # 生成更密集的时间点
                        dense_dates = np.linspace(date_numeric.min(), date_numeric.max(), 300)
                        dense_values = spline(dense_dates)
                        
                        # 将数值日期转换回实际日期
                        dense_dates = valid_dates.min() + pd.to_timedelta(dense_dates, unit='D')
                        
                        # 绘制平滑曲线
                        ax.plot(dense_dates, dense_values,
                                color=PREDICTED_STYLE['color'],
                                linestyle=PREDICTED_STYLE['linestyle'],
                                linewidth=PREDICTED_STYLE['linewidth'],
                                label=PREDICTED_STYLE['label'])
                    except Exception as e:
                        print(f"样条插值失败: {str(e)}")
                        # 如果插值失败，使用原始数据点绘制折线
                        ax.plot(valid_dates, valid_values,
                                color=PREDICTED_STYLE['color'],
                                linestyle=PREDICTED_STYLE['linestyle'],
                                linewidth=PREDICTED_STYLE['linewidth'],
                                label=PREDICTED_STYLE['label'])
                else:
                    # 数据点太少，直接绘制折线
                    ax.plot(valid_dates, valid_values,
                            color=PREDICTED_STYLE['color'],
                            linestyle=PREDICTED_STYLE['linestyle'],
                            linewidth=PREDICTED_STYLE['linewidth'],
                            label=PREDICTED_STYLE['label'])
                
                # 找出同时有实测值和预测值的点
                common_data = pd.merge(
                    df_measured[['Date', actual_col]], 
                    df_fitting[['Date', col_name]], 
                    on='Date', 
                    how='inner'
                ).dropna(subset=[actual_col, col_name])
                
                if not common_data.empty:
                    # 更新Y轴范围（共同数据）
                    common_min = min(common_data[actual_col].min(), common_data[col_name].min())
                    common_max = max(common_data[actual_col].max(), common_data[col_name].max())
                    y_min = min(y_min, common_min)
                    y_max = max(y_max, common_max)
                    
                    # 在实测日期位置绘制实测值点（空心圆）
                    ax.plot(common_data['Date'], common_data[actual_col],
                            linestyle='',  # 无线条
                            color=ACTUAL_STYLE['color'],
                            marker=ACTUAL_STYLE['marker'],
                            markersize=ACTUAL_STYLE['markersize'],
                            markerfacecolor=ACTUAL_STYLE['markerfacecolor'],
                            markeredgewidth=ACTUAL_STYLE['markeredgewidth'],
                            label=ACTUAL_STYLE['label'])
                    
                    # 计算评估指标
                    rmse = np.sqrt(mean_squared_error(common_data[actual_col], common_data[col_name]))
                    r2 = r2_score(common_data[actual_col], common_data[col_name])
                    
                    # 添加到指标列表
                    metrics.append({
                        'rmse': rmse,
                        'r2': r2
                    })
                    
                    # 保存预测结果
                    model_key = f"{veg_type}"
                    all_predictions[model_key] = {
                        'dates': common_data['Date'].tolist(),
                        'measured': common_data[actual_col].tolist(),
                        'predicted': common_data[col_name].tolist(),
                        'rmse': rmse,
                        'r2': r2
                    }
        
        # 设置子图标题
        ax.set_title(VEGETATION_TYPES.get(veg_type, veg_type), 
                     fontsize=16, fontweight='bold')
        
        # 设置坐标轴标签
        if idx == 2:  # 最后一行
            ax.set_xlabel('Date', fontsize=12, fontweight='bold')
        ax.set_ylabel('VWC (kg/m²)', fontsize=12, fontweight='bold')
        
        # 设置X轴格式
        ax.xaxis.set_major_locator(mdates.DayLocator(interval=10))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
        
        # 添加网格
        ax.grid(True, linestyle='--', alpha=0.3)
        
        # 动态设置Y轴范围
        if y_min != float('inf') and y_max != float('-inf'):
            # 添加10%的边距
            y_range = y_max - y_min
            padding = y_range * 0.1
            
            # 确保最小值不小于0
            y_min = max(0, y_min - padding)
            y_max = y_max + padding
            
            ax.set_ylim(y_min, y_max)
        else:
            # 默认范围
            ax.set_ylim(0, 10)
        
        # 添加图例
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        ax.legend(by_label.values(), by_label.keys(), loc='best')
        
        # 添加评估指标文本框
        if metrics:
            metric_text = f"RMSE = {metrics[0]['rmse']:.3f}\nR² = {metrics[0]['r2']:.3f}"
            ax.text(0.02, 0.95, metric_text, 
                    transform=ax.transAxes,
                    fontsize=12,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        # 存储指标
        all_metrics[veg_type] = metrics
    
    # 调整布局
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    
    # 保存图像 - 添加后缀
    figures_dir = Path("figures")
    figures_dir.mkdir(parents=True, exist_ok=True)
    fig_path = figures_dir / "Combined_VWC_Time_Series_6VOD_LAI_PFT.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"组合时间序列图已保存至: {fig_path}")
    plt.close()
    
    # 保存所有预测结果到CSV文件 - 添加后缀
    for model_key, data in all_predictions.items():
        df = pd.DataFrame({
            'Date': data['dates'],
            'Measured': data['measured'],
            'Predicted': data['predicted']
        })
        csv_path = output_dir / f"{model_key}_predictions_6VOD_LAI_PFT.csv"
        df.to_csv(csv_path, index=False)
        print(f"保存预测结果至: {csv_path}")
    
    # 保存评估指标 - 添加后缀
    metrics_path = output_dir / "model_metrics_6VOD_LAI_PFT.csv"
    metrics_data = []
    for veg_type, metrics in all_metrics.items():
        if metrics:
            metrics_data.append({
                'Vegetation': veg_type,
                'RMSE': metrics[0]['rmse'],
                'R2': metrics[0]['r2']
            })
    
    metrics_df = pd.DataFrame(metrics_data)
    metrics_df.to_csv(metrics_path, index=False)
    print(f"保存模型评估指标至: {metrics_path}")

def main():
    # 2017年数据文件 - 使用ML结尾的文件
    file_2017 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML_VODFilled.xlsx"
    data_2017 = load_data(file_2017)
    
    # 2018年数据文件 - 使用ML结尾的文件
    file_2018 = r"E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML_VODFilled.xlsx"
    data_2018 = load_data(file_2018)
    
    # 创建组合时间序列图
    create_combined_plots(data_2017, data_2018)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\DuolunExp_Veg_ML_VODFilled.xlsx
  - CornVegMeasured: 8行
  - CornVegFitting: 64行
  - OatVegMeasured: 7行
  - OatVegFitting: 64行
加载文件: E:\data\VWC\test-VWC\多频多角度地基微波辐射计及地表参量观测数据集\ZhenglanqiExp_VWC_ML_VODFilled.xlsx
  - GrassVWC: 13行
创建组合时间序列图...
为 CornVegFitting 预测 VWC...
加载统一模型: models/RFR_6VOD_LAI_PFTs.pkl
  模型期望特征: ['VOD_Ku_Hpol_Asc', 'VOD_Ku_Vpol_Asc', 'VOD_X_Hpol_Asc', 'VOD_X_Vpol_Asc', 'VOD_C_Hpol_Asc', 'VOD_C_Vpol_Asc', 'LAI', 'Grass_man', 'Grass_nat', 'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne', 'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne']
  使用实测LAI替换了 63 行数据
  已完成ku_vod_H的时间序列插值
  已完成ku_vod_V的时间序列插值
  已完成x_vod_H的时间序列插值
  已完成x_vod_V的时间序列插值
  已完成c_vod_H的时间序列插值
  已完成c_vod_V的时间序列插值
  已完成LAI_Satellite的时间序列插值
  已完成grassman的时间序列插值
  已完成grassnat的时间序列插值
  已完成shrubbd的时间序列插值
  已完成shrubbe的时间序列插值
  已完成shrubnd的时间序列插值
  已完成shrubne的时间序列插值
  已完成treebd的时间序列插值
  已完成treebe的时间序列插值
  已完成treend的时间序列插值
  已完成treene的时间序列插值
为 OatVegFitting 预测 VWC...
加载统一模

## 散点图

In [8]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import warnings
from sklearn.metrics import mean_squared_error, r2_score
warnings.filterwarnings('ignore')

# 设置全局字体
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.weight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['axes.titleweight'] = 'bold'

# 植被类型标记样式
VEG_MARKERS = {
    'Corn': {'marker': 's', 'size': 80, 'color': '#1f77b4', 'label': 'Corn (2017)'},
    'Oat': {'marker': '^', 'size': 80, 'color': '#ff7f0e', 'label': 'Oat (2017)'},
    'Grass': {'marker': 'o', 'size': 80, 'color': '#2ca02c', 'label': 'Grass (2018)'}
}

def load_prediction_data(prediction_dir):
    """从CSV文件加载预测结果"""
    print(f"加载预测结果: {prediction_dir}")
    all_data = {}
    
    # 遍历所有CSV文件
    for csv_file in prediction_dir.glob("*_predictions_6VOD_LAI_PFT.csv"):
        # 解析文件名获取植被类型信息
        filename = csv_file.stem
        parts = filename.split('_')
        
        if len(parts) >= 1:  # 格式: {vegetation_type}_predictions_6VOD_LAI_PFT
            veg_type = parts[0]
            
            # 加载数据
            df = pd.read_csv(csv_file)
            
            # 确保日期是datetime类型
            if 'Date' in df.columns:
                df['Date'] = pd.to_datetime(df['Date'])
            
            # 存储数据
            all_data[veg_type] = df
    
    return all_data

def create_single_scatter_plot(prediction_dir):
    """创建单张散点图（所有植被类型）"""
    # 加载预测结果
    all_data = load_prediction_data(prediction_dir)
    
    if not all_data:
        print("警告: 没有找到预测结果文件")
        return
    
    # 创建图形（不包含标题）
    fig, ax = plt.subplots(figsize=(10, 8))
    
    # 收集所有数据点
    all_actual = []
    all_predicted = []
    
    # 存储评估指标
    metrics = {}
    
    # 处理玉米数据
    if 'CornVegMeasured' in all_data:
        df = all_data['CornVegMeasured']
        if 'Measured' in df.columns and 'Predicted' in df.columns:
            actual = df['Measured'].values
            predicted = df['Predicted'].values
            
            # 添加数据点
            ax.scatter(actual, predicted,
                      marker=VEG_MARKERS['Corn']['marker'],
                      s=VEG_MARKERS['Corn']['size'],
                      facecolor='none',
                      edgecolor=VEG_MARKERS['Corn']['color'],
                      linewidths=1.5,
                      label=VEG_MARKERS['Corn']['label'])
            
            # 计算评估指标
            rmse = np.sqrt(mean_squared_error(actual, predicted))
            r2 = r2_score(actual, predicted)
            metrics['Corn'] = {'RMSE': rmse, 'R2': r2}
            
            # 添加到总数据
            all_actual.extend(actual)
            all_predicted.extend(predicted)
    
    # 处理燕麦数据
    if 'OatVegMeasured' in all_data:
        df = all_data['OatVegMeasured']
        if 'Measured' in df.columns and 'Predicted' in df.columns:
            actual = df['Measured'].values
            predicted = df['Predicted'].values
            
            # 添加数据点
            ax.scatter(actual, predicted,
                      marker=VEG_MARKERS['Oat']['marker'],
                      s=VEG_MARKERS['Oat']['size'],
                      facecolor='none',
                      edgecolor=VEG_MARKERS['Oat']['color'],
                      linewidths=1.5,
                      label=VEG_MARKERS['Oat']['label'])
            
            # 计算评估指标
            rmse = np.sqrt(mean_squared_error(actual, predicted))
            r2 = r2_score(actual, predicted)
            metrics['Oat'] = {'RMSE': rmse, 'R2': r2}
            
            # 添加到总数据
            all_actual.extend(actual)
            all_predicted.extend(predicted)
    
    # 处理草数据
    if 'GrassVWC' in all_data:
        df = all_data['GrassVWC']
        if 'Measured' in df.columns and 'Predicted' in df.columns:
            actual = df['Measured'].values
            predicted = df['Predicted'].values
            
            # 添加数据点
            ax.scatter(actual, predicted,
                      marker=VEG_MARKERS['Grass']['marker'],
                      s=VEG_MARKERS['Grass']['size'],
                      facecolor='none',
                      edgecolor=VEG_MARKERS['Grass']['color'],
                      linewidths=1.5,
                      label=VEG_MARKERS['Grass']['label'])
            
            # 计算评估指标
            rmse = np.sqrt(mean_squared_error(actual, predicted))
            r2 = r2_score(actual, predicted)
            metrics['Grass'] = {'RMSE': rmse, 'R2': r2}
            
            # 添加到总数据
            all_actual.extend(actual)
            all_predicted.extend(predicted)
    
    # 如果没有数据点，显示错误信息
    if not all_actual:
        ax.text(0.5, 0.5, 'No Data', horizontalalignment='center', 
                verticalalignment='center', transform=ax.transAxes,
                fontsize=14, color='red')
        return
    
    # 转换为numpy数组
    all_actual = np.array(all_actual)
    all_predicted = np.array(all_predicted)
    
    # 计算整体评估指标
    rmse_total = np.sqrt(mean_squared_error(all_actual, all_predicted))
    r2_total = r2_score(all_actual, all_predicted)
    metrics['Total'] = {'RMSE': rmse_total, 'R2': r2_total}
    
    # 添加1:1参考线
    min_val = min(np.min(all_actual), np.min(all_predicted))
    max_val = max(np.max(all_actual), np.max(all_predicted))
    ax.plot([min_val, max_val], [min_val, max_val], 'k--', linewidth=1.5, label='1:1 Line')
    
    # 设置坐标轴范围
    padding = (max_val - min_val) * 0.05
    ax.set_xlim(min_val - padding, max_val + padding)
    ax.set_ylim(min_val - padding, max_val + padding)
    
    # 设置坐标轴标签
    ax.set_xlabel('Measured VWC (kg/m²)', fontsize=12)
    ax.set_ylabel('Predicted VWC (kg/m²)', fontsize=12)
    
    # 添加网格
    ax.grid(True, linestyle='--', alpha=0.3)
    
    # 创建评估指标文本
    metric_text = ""
    for veg_type in ['Corn', 'Oat', 'Grass', 'Total']:
        if veg_type in metrics:
            metric_text += f"{veg_type}: RMSE={metrics[veg_type]['RMSE']:.3f}, R²={metrics[veg_type]['R2']:.3f}\n"
    
    # 添加评估指标文本框
    ax.text(0.05, 0.95, metric_text, transform=ax.transAxes, 
           fontsize=10, verticalalignment='top',
           bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # 添加图例（放在图外底部）
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', 
              bbox_to_anchor=(0.5, 0.01), ncol=3, fontsize=10)
    
    # 调整布局（为图例留出空间）
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])  # 底部留出5%空间给图例
    
    # 保存图像
    output_dir = Path("figures")
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_path = output_dir / "Single_Scatter_Predictions_6VOD_LAI_PFT.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"散点图已保存至: {fig_path}")
    plt.close()
    
    # 打印评估指标
    print("\n模型评估指标:")
    for veg_type, metric in metrics.items():
        print(f"{veg_type}: RMSE={metric['RMSE']:.4f}, R²={metric['R2']:.4f}")

def main():
    # 设置预测结果目录
    prediction_dir = Path("prediction_results")
    
    # 创建散点图
    create_single_scatter_plot(prediction_dir)
    
    print("\n处理完成!")

if __name__ == "__main__":
    main()

加载预测结果: prediction_results
散点图已保存至: figures\Single_Scatter_Predictions_6VOD_LAI_PFT.png

模型评估指标:
Corn: RMSE=0.6176, R²=0.7379
Oat: RMSE=0.5212, R²=-0.7242
Grass: RMSE=0.6204, R²=-52.6990
Total: RMSE=0.5963, R²=0.6670

处理完成!
