# 使用产品以及模型（最近训练的一套）

In [None]:
# 单日数据的生成
import os
import sys
import h5py
import joblib
from osgeo import gdal
gdal.UseExceptions()
from osgeo import osr
import numpy as np
from tqdm import tqdm
from datetime import datetime, timedelta
import logging
import concurrent.futures

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_generation.log')
    ]
)
logger = logging.getLogger()

# ============================== 更新配置参数 ==============================
# 定义所有要处理的波段和极化组合
MODEL_CONFIGS = [
    # 单极化模型
    {'band': 'Ku', 'pol': 'H', 'name': 'KuH'},
    {'band': 'Ku', 'pol': 'V', 'name': 'KuV'},
    {'band': 'X', 'pol': 'H', 'name': 'XH'},
    {'band': 'X', 'pol': 'V', 'name': 'XV'},
    {'band': 'C', 'pol': 'H', 'name': 'CH'},
    {'band': 'C', 'pol': 'V', 'name': 'CV'},
    
    # 双极化模型
    {'band': 'Ku', 'pol': 'HV', 'name': 'KuHV'},
    {'band': 'X', 'pol': 'HV', 'name': 'XHV'},
    {'band': 'C', 'pol': 'HV', 'name': 'CHV'},
]

# 时间范围限定
START_DATE = datetime(2015, 8, 1)
END_DATE = datetime(2020, 12, 31)

# 路径配置
VOD_BASE_PATH = r'E:\data\VOD\mat\kuxcVOD\ASC'
LAI_BASE_PATH = r'E:\data\GLASS LAI\mat\0.1Deg\Dataset'
PFT_BASE_PATH = r'E:\data\ESACCI PFT\Resample\Data'
OUTPUT_PATH = r'E:\data\VWC\VWCMap\Daily'

# 确保输出目录存在
os.makedirs(OUTPUT_PATH, exist_ok=True)

# 加载所有模型
MODELS = {}
for config in MODEL_CONFIGS:
    band = config['band']
    pol = config['pol']
    model_path = os.path.join(os.getcwd(), f'models\\RFR_{band}_{pol}pol_Type1.pkl')
    try:
        if os.path.exists(model_path):
            MODELS[(band, pol)] = joblib.load(model_path)
            logger.info(f"成功加载 {band}波段{pol}极化模型")
        else:
            logger.warning(f"模型文件不存在 ({band},{pol}): {model_path}")
    except Exception as e:
        logger.error(f"加载模型失败 ({band},{pol}): {str(e)}")

# 定义PFT特征列表
PFT_FEATURES = [
    'Grass_man', 'Grass_nat', 
    'Shrub_bd', 'Shrub_be', 'Shrub_nd', 'Shrub_ne',
    'Tree_bd', 'Tree_be', 'Tree_nd', 'Tree_ne'
]

# ============================== 辅助函数 ==============================
def create_multiband_geotiff(bands_data, band_names, output_path, nodata=-9999.0):
    """创建多波段地理参考的TIFF文件"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols = bands_data[0].shape
        
        # 验证分辨率
        expected_res = 0.1
        if cols != 3600 or rows != 1800:
            logger.warning(f"图像尺寸异常: {rows}x{cols} (期望1800x3600)")
        
        # 创建数据集
        out_ds = driver.Create(
            output_path, 
            cols, 
            rows, 
            len(bands_data), 
            gdal.GDT_Float32,
            options=['COMPRESS=LZW', 'BIGTIFF=YES']
        )
        
        # 设置地理变换 (确保精确的0.1度分辨率)
        geotransform = (-180.0, expected_res, 0.0, 90.0, 0.0, -expected_res)
        out_ds.SetGeoTransform(geotransform)
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加分辨率元数据
        out_ds.SetMetadataItem('PIXEL_SIZE_X', str(expected_res))
        out_ds.SetMetadataItem('PIXEL_SIZE_Y', str(expected_res))
        out_ds.SetMetadataItem('UNITS', 'degrees')
        
        # 写入每个波段数据
        for i, band_data in enumerate(bands_data):
            band = out_ds.GetRasterBand(i+1)
            band.WriteArray(band_data)
            band.SetNoDataValue(nodata)
            band.SetDescription(band_names[i])
            band.SetScale(0.1)  # 设置缩放因子
            band.SetOffset(0.0)  # 设置偏移值
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        
        # 验证输出文件
        ds = gdal.Open(output_path)
        if ds:
            actual_geotransform = ds.GetGeoTransform()
            ds = None
            
            # 验证分辨率
            res_x = actual_geotransform[1]
            res_y = -actual_geotransform[5]  # 取绝对值
            
            if abs(res_x - expected_res) > 1e-6 or abs(res_y - expected_res) > 1e-6:
                logger.error(f"分辨率不匹配! 期望: {expected_res}, 实际: X={res_x}, Y={res_y}")
                return False
        
        logger.info(f"成功创建多波段GeoTIFF: {output_path} (波段: {', '.join(band_names)})")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def get_vod_file(date):
    """获取VOD文件路径"""
    sensor = 'AMSRE' if date.year <= 2011 else 'AMSR2'
    date_str = date.strftime('%Y%m%d')
    
    # 查找VOD文件
    possible_files = [
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.nc4.mat',
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}_V0.mat',
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}_V1.nc4.mat',
        f'MCCA_{sensor}_010D_CCXH_VSM_VOD_Asc_{date_str}.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(VOD_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到VOD文件: {date_str}")
    return None

def get_lai_file(year, month):
    """获取LAI文件路径"""
    month_str = str(month).zfill(2)
    
    # 尝试文件名格式
    possible_files = [
        f'{year}-{month_str}-01.tif.mat',
        f'{year}-{month_str}-01.mat',
        f'LAI_{year}{month_str}.mat',
        f'{year}_{month_str}_LAI.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(LAI_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到LAI文件: {year}-{month_str}")
    return None

def get_pft_file(year):
    """获取PFT文件路径"""
    # 尝试不同文件名
    possible_files = [
        f'{year}.mat',
        f'PFT_{year}.mat',
        f'ESACCI_PFT_{year}.mat',
        f'pft_{year}.mat'
    ]
    
    for filename in possible_files:
        file_path = os.path.join(PFT_BASE_PATH, filename)
        if os.path.exists(file_path):
            return file_path
    
    logger.warning(f"未找到PFT文件: {year}")
    return None

def get_month_centers(date):
    """获取用于插值的前后月中日期"""
    current_mid = date.replace(day=15)
    
    if date.day <= 15:
        prev_mid = (current_mid - timedelta(days=30)).replace(day=15)
        next_mid = current_mid
    else:
        prev_mid = current_mid
        next_mid = (current_mid + timedelta(days=30)).replace(day=15)
    return prev_mid, next_mid

def load_lai_matrix(file_path):
    """加载LAI矩阵数据"""
    if not file_path or not os.path.exists(file_path): 
        return None
    
    try:
        with h5py.File(file_path, 'r') as f:
            # 尝试不同数据集名称
            for key in ['lai', 'Layer', 'data']:
                if key in f:
                    lai_data = np.array(f[key][:])
                    break
            else:
                # 使用第一个数据集
                keys = list(f.keys())
                if keys:
                    lai_data = np.array(f[keys[0]][:])
                    logger.warning(f"使用默认数据集 '{keys[0]}' 作为LAI数据: {file_path}")
                else:
                    logger.warning(f"未找到数据集: {file_path}")
                    return None
            
            # 转置为(1800,3600)
            if lai_data.shape == (3600, 1800):
                lai_data = lai_data.T
            elif lai_data.shape != (1800, 3600):
                logger.error(f"不支持的LAI数据形状: {lai_data.shape}")
                return None
            
            return lai_data
    except Exception as e:
        logger.error(f"加载LAI文件失败: {file_path} - {str(e)}")
        return None

def load_pft_matrix(file_path):
    """加载并旋转PFT矩阵数据"""
    if not file_path or not os.path.exists(file_path): 
        return None
    
    try:
        pft_data = {}
        with h5py.File(file_path, 'r') as f:
            # 定义植被类型映射
            pft_mapping = {
                'water': ['water', 'WATER'],
                'bare': ['bare', 'bareland', 'BARE'],
                'snowice': ['snowice', 'snow', 'ice', 'SNOWICE'],
                'built': ['built', 'urban', 'BUILT'],
                'grassnat': ['grassnat', 'grass_natural', 'GRASSNAT'],
                'grassman': ['grassman', 'grass_managed', 'GRASSMAN'],
                'shrubbd': ['shrubbd', 'shrub_bd', 'SHRUBBD'],
                'shrubbe': ['shrubbe', 'shrub_be', 'SHRUBBE'],
                'shrubnd': ['shrubnd', 'shrub_nd', 'SHRUBND'],
                'shrubne': ['shrubne', 'shrub_ne', 'SHRUBNE'],
                'treebd': ['treebd', 'tree_bd', 'TREEBD'],
                'treebe': ['treebe', 'tree_be', 'TREEBE'],
                'treend': ['treend', 'tree_nd', 'TREEND'],
                'treene': ['treene', 'tree_ne', 'TREENE']
            }
            
            # 查找匹配的数据集
            matched_datasets = {}
            for target, aliases in pft_mapping.items():
                for alias in aliases:
                    if alias in f:
                        data = np.array(f[alias][:])
                        # 旋转和转置
                        if data.shape == (3600, 7200):
                            # 从0.05度到0.1度的降采样
                            data = np.rot90(data, k=-1)
                            # 平均聚合到0.1度
                            data = (data[:, ::2] + data[:, 1::2]) / 2.0
                            pft_data[target] = data.T
                        elif data.shape == (3600, 1800):
                            pft_data[target] = np.rot90(data, k=-1)
                        elif data.shape == (1800, 3600):
                            pft_data[target] = data
                        else:
                            logger.warning(f"未知PFT形状: {data.shape} for {target}")
                            pft_data[target] = np.zeros((1800, 3600))
                        matched_datasets[alias] = True
                        break
                else:
                    logger.warning(f"未找到 {target} 的PFT数据")
                    pft_data[target] = np.zeros((1800, 3600))
            
            # 输出找到的数据集
            if matched_datasets:
                logger.info(f"从PFT文件中加载的数据集: {', '.join(matched_datasets.keys())}")
            else:
                logger.warning("未找到匹配的PFT数据集")
                
        return pft_data
    except Exception as e:
        logger.error(f"加载PFT文件失败: {file_path} - {str(e)}")
        return None

def replace_nan(data, default=0.0):
    """替换NaN值为默认值"""
    if data is None:
        return None
    return np.where(np.isnan(data), default, data)

def create_land_mask(rows=1800, cols=3600):
    """创建简单的陆地掩膜"""
    # 创建纬度数组
    lats = np.linspace(90, -90, rows)
    
    # 创建陆地掩膜 (排除极地和海洋)
    land_mask = np.zeros((rows, cols), dtype=bool)
    for i in range(rows):
        if -60 <= lats[i] <= 80:  # 排除南极和北极
            land_mask[i, :] = True
    
    return land_mask

def prepare_feature_matrix(model_config, vod_data, sm_data, lai_data, pft_data, valid_mask):
    band = model_config['band']
    pol = model_config['pol']
    name = model_config['name']
    
    # 确保模型存在
    if (band, pol) not in MODELS:
        logger.warning(f"模型不可用: {name}")
        return None, None, None
    
    logger.info(f"为模型 {name} 准备特征矩阵")
    
    # 初始化特征数组
    feature_list = []
    feature_names = []
    
    try:
        # 1. VOD特征处理 - 保持与训练时一致的归一化
        if pol in ['H', 'V']:  # 单极化模型
            vod_key = f"{band.lower()}_vod_{pol.lower()}"
            if vod_key in vod_data:
                vod_val = vod_data[vod_key]
                vod_val = replace_nan(vod_val, 0.0)
                vod_val = np.clip(vod_val, 0, 2) / 2.0  # 保持归一化到0-1
                feature_list.append(vod_val)
                feature_names.append('VOD')
            else:
                logger.warning(f"VOD特征缺失: {vod_key}")
                return None, None, None
        
        elif pol == 'HV':  # 双极化模型
            vod_key_h = f"{band.lower()}_vod_h"
            vod_key_v = f"{band.lower()}_vod_v"
            
            if vod_key_h in vod_data and vod_key_v in vod_data:
                vod_val_h = vod_data[vod_key_h]
                vod_val_v = vod_data[vod_key_v]
                
                vod_val_h = replace_nan(vod_val_h, 0.0)
                vod_val_v = replace_nan(vod_val_v, 0.0)
                
                vod_val_h = np.clip(vod_val_h, 0, 2) / 2.0  # 保持归一化
                vod_val_v = np.clip(vod_val_v, 0, 2) / 2.0  # 保持归一化
                
                feature_list.append(vod_val_h)
                feature_list.append(vod_val_v)
                feature_names.extend(['VOD-Hpol', 'VOD-Vpol'])
            else:
                missing = []
                if vod_key_h not in vod_data: missing.append(vod_key_h)
                if vod_key_v not in vod_data: missing.append(vod_key_v)
                logger.warning(f"双极化VOD特征缺失: {', '.join(missing)}")
                return None, None, None
        
        # 2. LAI特征 - 保持与训练时一致的归一化
        if lai_data is not None:
            lai_val = np.clip(lai_data, 0, 6) / 6.0  # 保持归一化到0-1
            feature_list.append(lai_val)
            feature_names.append('LAI')
        else:
            logger.warning("LAI特征缺失")
            return None, None, None
        
        # 3. 土壤湿度特征 - 保持与训练时一致的处理
        sm_val = replace_nan(sm_data, 0.0)
        feature_list.append(sm_val)
        feature_names.append('SM')
        
        # 4. PFT特征 - 使用小写名称匹配并保持归一化
        if pft_data is not None:
            # 定义PFT特征映射（训练时使用的名称）
            pft_mapping = {
                'Grass_man': ['grassman', 'grass_man'],
                'Grass_nat': ['grassnat', 'grass_nat'],
                'Shrub_bd': ['shrubbd', 'shrub_bd'],
                'Shrub_be': ['shrubbe', 'shrub_be'],
                'Shrub_nd': ['shrubnd', 'shrub_nd'],
                'Shrub_ne': ['shrubne', 'shrub_ne'],
                'Tree_bd': ['treebd', 'tree_bd'],
                'Tree_be': ['treebe', 'tree_be'],
                'Tree_nd': ['treend', 'tree_nd'],
                'Tree_ne': ['treene', 'tree_ne']
            }
            
            for model_feature_name, aliases in pft_mapping.items():
                found = False
                for alias in aliases:
                    if alias in pft_data:
                        # 保持除以100的归一化
                        feature_list.append(pft_data[alias] / 100.0)
                        feature_names.append(model_feature_name)
                        found = True
                        break
                
                if not found:
                    logger.warning(f"PFT特征缺失: {model_feature_name}")
                    feature_list.append(np.zeros_like(sm_data))
                    feature_names.append(model_feature_name)
        else:
            logger.warning("PFT数据缺失")
            return None, None, None
        
        # 5. 创建特征矩阵
        valid_indices = np.where(valid_mask)
        num_valid = len(valid_indices[0])
        
        if num_valid == 0:
            logger.warning("无有效数据点")
            return None, None, None
        
        # 特征矩阵 (num_valid x num_features)
        num_features = len(feature_list)
        X = np.zeros((num_valid, num_features), dtype=np.float32)
        
        for i, feat_arr in enumerate(feature_list):
            X[:, i] = feat_arr[valid_indices]
        
        logger.info(f"特征矩阵完成: {X.shape} 形状, 特征: {', '.join(feature_names)}")
        return X, num_valid, feature_names
        
    except Exception as e:
        logger.error(f"准备特征矩阵失败: {str(e)}")
        return None, None, None

def predict_vwc_single_model(model, features, feature_names, rows, cols, valid_mask):
    """为单个模型预测VWC"""
    if features is None:
        return None
    
    # 初始化结果数组
    vwc_array = np.full((rows, cols), -9999.0, dtype=np.float32)
    valid_indices = np.where(valid_mask)
    num_valid = len(valid_indices[0])
    
    if num_valid == 0:
        logger.warning("无有效数据点")
        return vwc_array
    
    try:
        # 检查特征顺序是否匹配
        if hasattr(model, 'feature_names_in_'):
            model_features = model.feature_names_in_
            logger.info(f"模型期望特征顺序: {model_features}")
            
            if feature_names:
                logger.info(f"实际提供特征顺序: {feature_names}")
                
                # 检查特征顺序是否一致
                if list(feature_names) != list(model_features):
                    logger.warning("特征顺序不匹配! 尝试重新排序...")
                    try:
                        # 重新排序特征以匹配模型期望
                        sorted_indices = [feature_names.index(f) for f in model_features]
                        features = features[:, sorted_indices]
                        logger.info("特征重新排序完成")
                    except Exception as e:
                        logger.error(f"特征重新排序失败: {str(e)}")
        
        # 预测 (分批处理避免内存溢出)
        predictions = np.zeros(num_valid, dtype=np.float32)
        chunk_size = 100000
        chunks = (num_valid + chunk_size - 1) // chunk_size
        
        for chunk_idx in range(chunks):
            start = chunk_idx * chunk_size
            end = min((chunk_idx + 1) * chunk_size, num_valid)
            X_chunk = features[start:end]
            predictions[start:end] = model.predict(X_chunk)
        
        # 填充结果数组
        vwc_array[valid_indices] = predictions
        
        # 添加预测统计
        valid_predictions = predictions[~np.isnan(predictions)]
        if len(valid_predictions) > 0:
            logger.info(f"预测值统计: min={np.min(valid_predictions):.4f}, "
                        f"max={np.max(valid_predictions):.4f}, "
                        f"mean={np.mean(valid_predictions):.4f}")
        else:
            logger.warning("无有效预测值")
        
        return vwc_array
        
    except Exception as e:
        logger.error(f"预测失败: {str(e)}")
        return None
        
# ============================== 主处理流程 ==============================
def generate_vwc():
    if not MODELS:
        logger.error("无可用模型，程序终止")
        return
    
    # 全局参数
    rows, cols = 1800, 3600
    total_pixels = rows * cols
    
    # 创建陆地掩膜
    land_mask = create_land_mask(rows, cols)
    logger.info(f"陆地掩膜: 有效点={np.count_nonzero(land_mask)}({np.count_nonzero(land_mask)/total_pixels*100:.2f}%)")
    
    # 日期序列 (2015-2020)
    dates = [START_DATE + timedelta(days=i) 
             for i in range((END_DATE - START_DATE).days + 1)]
    
    # 处理每一天
    for date in tqdm(dates, desc="生成VWC影像"):
        output_filename = f'VWC-{date.strftime("%Y%m%d")}.tif'
        output_path = os.path.join(OUTPUT_PATH, output_filename)
        
        if os.path.exists(output_path):
            logger.info(f"文件已存在: {output_path}")
            continue
        
        logger.info(f"处理日期: {date.strftime('%Y-%m-%d')}")
        band_predictions = {}  # 存储每个模型的预测结果
        
        try:
            # 1. 加载公共数据
            vod_file = get_vod_file(date)
            if not vod_file or not os.path.exists(vod_file):
                logger.error(f"VOD文件未找到或不可访问: {date}")
                continue
                
            logger.info(f"加载VOD文件: {vod_file}")
            vod_data = {}
            with h5py.File(vod_file, 'r') as f:
                # 加载所有VOD数据集
                for key in f.keys():
                    if key.lower().startswith(('ku_vod', 'x_vod', 'c_vod')):
                        vod_data[key.lower()] = np.array(f[key][:]).T
                
                # 加载土壤湿度和QC
                sm_data = np.array(f['SM'][:,:]).T
                qc_data = np.array(f['QC'][:,:]).T
            
            # 处理NaN值
            sm_data = replace_nan(sm_data, 0.0)
            qc_data = replace_nan(qc_data, 1)  # QC无效值设为1
            
            # 2. 加载LAI数据
            prev_mid, next_mid = get_month_centers(date)
            lai_prev_file = get_lai_file(prev_mid.year, prev_mid.month)
            lai_next_file = get_lai_file(next_mid.year, next_mid.month)
            
            if not lai_prev_file or not os.path.exists(lai_prev_file):
                logger.error(f"前月LAI文件未找到: {prev_mid.strftime('%Y-%m')}")
                continue
            if not lai_next_file or not os.path.exists(lai_next_file):
                logger.error(f"后月LAI文件未找到: {next_mid.strftime('%Y-%m')}")
                continue
                
            lai_prev = load_lai_matrix(lai_prev_file)
            lai_next = load_lai_matrix(lai_next_file)
            
            if lai_prev is None or lai_next is None:
                logger.error("无法加载LAI数据")
                continue
                
            # LAI插值
            total_days = (next_mid - prev_mid).days
            current_offset = (date - prev_mid).days
            weight = current_offset / total_days
            lai_data = lai_prev * (1 - weight) + lai_next * weight
            lai_data = replace_nan(lai_data, 0.0)
            
            # 3. 加载PFT数据
            pft_file = get_pft_file(date.year)
            if not pft_file or not os.path.exists(pft_file):
                logger.error(f"PFT文件未找到: {date.year}")
                pft_data = None
            else:
                logger.info(f"加载PFT文件: {pft_file}")
                pft_data = load_pft_matrix(pft_file)
                
            # 4. 创建有效掩膜
            qc_mask = qc_data == 0  # QC=0 表示有效
            lai_mask = ~np.isnan(lai_data)  # 非NaN的LAI值有效
            valid_mask = np.logical_and.reduce([land_mask, qc_mask, lai_mask])
            num_valid = np.count_nonzero(valid_mask)
            logger.info(f"有效数据点数量: {num_valid} (比例: {num_valid/total_pixels*100:.2f}%)")
            
            # 5. 为每个模型准备并预测
            futures = {}
            with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                for config in MODEL_CONFIGS:
                    # 只处理存在的模型
                    if (config['band'], config['pol']) not in MODELS:
                        continue
                    
                    future = executor.submit(
                        prepare_feature_matrix,
                        config,
                        vod_data,
                        sm_data,
                        lai_data,
                        pft_data,
                        valid_mask
                    )
                    futures[future] = config
            
            # 收集结果并预测
            for future in concurrent.futures.as_completed(futures):
                config = futures[future]
                name = config['name']
                model = MODELS[(config['band'], config['pol'])]
                
                # 正确接收三个返回值
                result = future.result()
                if result is None:
                    logger.warning(f"模型 {name} 特征矩阵不可用，跳过")
                    continue
                    
                features, num_valid, feature_names = result
                
                if features is None or num_valid == 0:
                    logger.warning(f"模型 {name} 预测跳过")
                    continue
                    
                prediction = predict_vwc_single_model(
                    model, features, feature_names, rows, cols, valid_mask
                )
                if prediction is not None:
                    band_predictions[name] = prediction
                    logger.info(f"模型 {name} 预测完成")
                else:
                    logger.warning(f"模型 {name} 预测失败")
            
            # 6. 确保所有模型都处理完毕
            predicted_models = set(band_predictions.keys())
            expected_models = {config['name'] for config in MODEL_CONFIGS if (config['band'], config['pol']) in MODELS}
            missing_models = expected_models - predicted_models
            if missing_models:
                logger.warning(f"缺失模型预测: {', '.join(missing_models)}")
            
            # 7. 创建多波段TIFF
            if band_predictions:
                bands = []
                band_names = []
                for config in MODEL_CONFIGS:
                    name = config['name']
                    if name in band_predictions:
                        bands.append(band_predictions[name])
                        band_names.append(name)
                
                if bands:
                    success = create_multiband_geotiff(bands, band_names, output_path)
                    if success:
                        logger.info(f"成功保存多波段TIFF: {output_path}")
                    else:
                        logger.error(f"保存失败: {output_path}")
                else:
                    logger.warning("无有效波段数据可保存")
            else:
                logger.warning("无预测数据可保存")
                
        except Exception as e:
            logger.error(f"处理日期 {date} 错误: {str(e)}", exc_info=True)

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"当前目录: {os.getcwd()}")
    logger.info(f"输出目录: {OUTPUT_PATH}")
    logger.info(f"开始处理: {START_DATE.strftime('%Y-%m-%d')} 到 {END_DATE.strftime('%Y-%m-%d')}")
    generate_vwc()
    logger.info("VWC影像生成完成")

2025-08-05 20:00:37,960 - INFO - 成功加载 Ku波段H极化模型
2025-08-05 20:00:38,564 - INFO - 成功加载 Ku波段V极化模型
2025-08-05 20:00:40,210 - INFO - 成功加载 X波段H极化模型
2025-08-05 20:00:42,367 - INFO - 成功加载 X波段V极化模型


In [1]:
# 地物类型掩膜
import os
import numpy as np
from osgeo import gdal, osr
import logging
from tqdm import tqdm
import concurrent.futures
import multiprocessing
import time

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_mask_water_bare_snowice_optimized.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
VWC_DIR = r'E:\data\VWC\VWCMap\8Day'  # 每日VWC数据目录
MAIN_TYPE_FILE = r'E:\data\ESACCI PFT\Resample\Data\mainType.tif'  # 主地物类型文件
NODATA_VALUE = -9999.0  # 无效值
NUM_WORKERS = multiprocessing.cpu_count()  # 使用所有CPU核心

# 需要掩膜的地物类型 (水、裸地、冰雪)
MASK_TYPES = [0, 1, 2]  # 0=water, 1=bare, 2=snowice

# ============================== 辅助函数 ==============================
def load_main_type_mask():
    """加载主地物类型掩膜 (优化版本)"""
    try:
        # 使用内存映射提高读取速度
        ds = gdal.Open(MAIN_TYPE_FILE, gdal.GA_ReadOnly)
        if ds is None:
            logger.error(f"无法打开主地物类型文件: {MAIN_TYPE_FILE}")
            return None
        
        # 读取整个波段到内存
        band = ds.GetRasterBand(1)
        main_type = band.ReadAsArray()
        
        # 创建掩膜 (需要掩膜的类型为True)
        mask = np.isin(main_type, MASK_TYPES)
        
        # 获取NoData值并处理
        nodata = band.GetNoDataValue()
        if nodata is not None:
            # 排除NoData区域
            mask = np.logical_and(mask, main_type != nodata)
        
        ds = None
        logger.info(f"成功加载主地物类型掩膜: {mask.shape}")
        return mask
    except Exception as e:
        logger.error(f"加载主地物类型文件失败: {str(e)}")
        return None

def process_vwc_file(file_path, mask):
    """处理单个VWC文件 (优化版本)"""
    try:
        start_time = time.time()
        
        # 以读写模式打开文件
        ds = gdal.Open(file_path, gdal.GA_Update)
        if ds is None:
            logger.warning(f"无法打开VWC文件: {file_path}")
            return False
        
        # 获取波段数量
        num_bands = ds.RasterCount
        rows = ds.RasterYSize
        cols = ds.RasterXSize
        
        # 验证尺寸匹配
        if rows != mask.shape[0] or cols != mask.shape[1]:
            logger.error(f"尺寸不匹配: VWC文件({rows}x{cols}) vs 掩膜({mask.shape[0]}x{mask.shape[1]})")
            ds = None
            return False
        
        # 一次性读取所有波段数据
        band_data = []
        for band_idx in range(1, num_bands + 1):
            band = ds.GetRasterBand(band_idx)
            data = band.ReadAsArray()
            band_data.append(data)
        
        # 应用掩膜到所有波段
        for i in range(num_bands):
            band_data[i][mask] = NODATA_VALUE
        
        # 一次性写回所有波段
        for band_idx in range(1, num_bands + 1):
            band = ds.GetRasterBand(band_idx)
            band.WriteArray(band_data[band_idx-1])
        
        # 清理
        ds.FlushCache()
        ds = None
        
        process_time = time.time() - start_time
        logger.debug(f"处理完成: {file_path} (耗时: {process_time:.2f}s)")
        return True
    except Exception as e:
        logger.error(f"处理文件 {file_path} 失败: {str(e)}")
        return False

def process_vwc_files_parallel(mask):
    """并行处理所有VWC文件"""
    # 获取所有VWC文件
    vwc_files = []
    for filename in os.listdir(VWC_DIR):
        if filename.startswith('VWC-') and filename.endswith('.tif'):
            file_path = os.path.join(VWC_DIR, filename)
            vwc_files.append(file_path)
    
    if not vwc_files:
        logger.error("未找到任何VWC文件")
        return 0
    
    logger.info(f"找到 {len(vwc_files)} 个VWC文件，使用 {NUM_WORKERS} 个线程并行处理")
    
    # 使用线程池并行处理
    success_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        # 提交所有任务
        future_to_file = {executor.submit(process_vwc_file, file_path, mask): file_path 
                          for file_path in vwc_files}
        
        # 使用tqdm显示进度
        for future in tqdm(concurrent.futures.as_completed(future_to_file), 
                          total=len(vwc_files), desc="处理VWC文件"):
            file_path = future_to_file[future]
            try:
                result = future.result()
                if result:
                    success_count += 1
            except Exception as e:
                logger.error(f"处理文件 {file_path} 时出错: {str(e)}")
    
    return success_count

# ============================== 主处理流程 ==============================
def main():
    logger.info(f"VWC目录: {VWC_DIR}")
    logger.info(f"主地物类型文件: {MAIN_TYPE_FILE}")
    logger.info(f"需要掩膜的地物类型: {MASK_TYPES} (水、裸地、冰雪)")
    logger.info(f"使用 {NUM_WORKERS} 个线程进行并行处理")
    
    start_time = time.time()
    
    # 加载主地物类型掩膜
    mask = load_main_type_mask()
    if mask is None:
        logger.error("无法加载掩膜，程序终止")
        return
    
    # 处理所有VWC文件
    success_count = process_vwc_files_parallel(mask)
    
    total_time = time.time() - start_time
    logger.info(f"处理完成! 成功处理 {success_count}/{len(os.listdir(VWC_DIR))} 个文件")
    logger.info(f"总耗时: {total_time:.2f}秒")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    try:
        main()
        logger.info("VWC文件掩膜处理完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-06 18:49:10,194 - INFO - VWC目录: E:\data\VWC\VWCMap
2025-08-06 18:49:10,196 - INFO - 主地物类型文件: E:\data\ESACCI PFT\Resample\Data\mainType.tif
2025-08-06 18:49:10,198 - INFO - 需要掩膜的地物类型: [0, 1, 2] (水、裸地、冰雪)
2025-08-06 18:49:10,199 - INFO - 使用 16 个线程进行并行处理
2025-08-06 18:49:21,434 - INFO - 成功加载主地物类型掩膜: (1800, 3600)
2025-08-06 18:49:21,650 - INFO - 找到 851 个VWC文件，使用 16 个线程并行处理
2025-08-06 20:56:36,451 - ERROR - 处理文件 E:\data\VWC\VWCMap\VWC-20161029.tif 失败: 'NoneType' object does not support item assignment
2025-08-06 20:56:36,530 - ERROR - 处理文件 E:\data\VWC\VWCMap\VWC-20161028.tif 失败: 'NoneType' object does not support item assignment
2025-08-06 20:56:36,699 - ERROR - 处理文件 E:\data\VWC\VWCMap\VWC-20161102.tif 失败: 'NoneType' object does not support item assignment
2025-08-06 20:56:36,715 - ERROR - 处理文件 E:\data\VWC\VWCMap\VWC-20161030.tif 失败: 'NoneType' object does not support item assignment
2025-08-06 20:56:36,767 - ERROR - 处理文件 E:\data\VWC\VWCMap\VWC-20161104.tif 失败: 'NoneType' object doe

8日合成数据

In [7]:
import os
import numpy as np
from osgeo import gdal, osr
import datetime
import re
import logging
from tqdm import tqdm

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_8day_composite_corrected.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
INPUT_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
OUTPUT_DIR = r'E:\data\VWC\VWCMap\8Day'  # 8日合成输出目录
START_YEAR = 2015
END_YEAR = 2016
NODATA_VALUE = -9999.0

# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================== 辅助函数 ==============================
def get_daily_files():
    """获取所有每日文件，并按日期排序"""
    files = []
    pattern = re.compile(r'VWC-(\d{4})(\d{2})(\d{2})\.tif$')  # 匹配YYYYMMDD格式的文件名
    
    for filename in os.listdir(INPUT_DIR):
        if not filename.endswith('.tif'):
            continue
            
        match = pattern.match(filename)
        if not match:
            continue
            
        year = int(match.group(1))
        month = int(match.group(2))
        day = int(match.group(3))
        file_date = datetime.date(year, month, day)
        
        # 只处理指定年份范围内的文件
        if START_YEAR <= year <= END_YEAR:
            file_path = os.path.join(INPUT_DIR, filename)
            files.append({
                'path': file_path,
                'date': file_date,
                'year': year
            })
    
    # 按日期排序
    files.sort(key=lambda x: x['date'])
    return files

def create_composite_geotiff(data, output_path, start_date, end_date, nodata=NODATA_VALUE):
    """创建8日合成的地理参考TIFF文件"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols, bands = data.shape
        
        # 创建输出数据集
        out_ds = driver.Create(
            output_path,
            cols,
            rows,
            bands,
            gdal.GDT_Float32,
            options=['COMPRESS=LZW', 'TILED=YES']
        )
        
        # 设置地理变换
        out_ds.SetGeoTransform((-180, 0.1, 0, 90, 0, -0.1))
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加日期元数据
        out_ds.SetMetadata({
            'START_DATE': start_date.strftime('%Y-%m-%d'),
            'END_DATE': end_date.strftime('%Y-%m-%d'),
            'NUM_DAYS': str((end_date - start_date).days + 1)
        })
        
        # 写入每个波段
        for band_idx in range(bands):
            band = out_ds.GetRasterBand(band_idx+1)
            band.WriteArray(data[:, :, band_idx])
            band.SetNoDataValue(nodata)
            band.SetDescription(f"Model_{band_idx+1}")
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        logger.info(f"成功创建8日合成GeoTIFF: {output_path}")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def generate_8day_composites():
    """生成8日合成数据 (按年处理)"""
    # 获取所有每日文件
    daily_files = get_daily_files()
    if not daily_files:
        logger.error("未找到任何每日文件")
        return
    
    logger.info(f"找到 {len(daily_files)} 个每日文件")
    
    # 按年份分组
    files_by_year = {}
    for year in range(START_YEAR, END_YEAR + 1):
        files_by_year[year] = [f for f in daily_files if f['year'] == year]
    
    # 处理每一年
    for year in range(START_YEAR, END_YEAR + 1):
        year_files = files_by_year[year]
        if not year_files:
            logger.warning(f"{year}年没有每日文件")
            continue
            
        logger.info(f"处理 {year} 年 (共 {len(year_files)} 个每日文件)")
        
        # 确定该年的起始和结束日期
        start_date = datetime.date(year, 1, 1)
        end_date = datetime.date(year, 12, 31)
        
        # 创建8日周期
        current_start = start_date
        composite_groups = []
        
        # 构建8日周期 (严格限制在该年内)
        while current_start <= end_date:
            # 计算周期结束日期 (最多7天后)
            current_end = current_start + datetime.timedelta(days=7)
            
            # 如果结束日期超出该年范围，调整到12月31日
            if current_end > end_date:
                current_end = end_date
                
            # 收集当前周期内的文件
            group_files = []
            for file_info in year_files:
                if current_start <= file_info['date'] <= current_end:
                    group_files.append(file_info)
            
            if group_files:
                composite_groups.append({
                    'start_date': current_start,
                    'end_date': current_end,
                    'files': group_files
                })
            
            # 移动到下一个周期 (从结束日期的下一天开始)
            current_start = current_end + datetime.timedelta(days=1)
            
            # 如果下一个起始日期已超出该年范围，结束循环
            if current_start > end_date:
                break
        
        logger.info(f"{year}年共生成 {len(composite_groups)} 个8日周期")
        
        # 处理该年的每个8日周期
        for group in tqdm(composite_groups, desc=f"处理{year}年周期"):
            # 使用起始日期作为文件名标识
            start_str = group['start_date'].strftime('%Y%m%d')
            end_str = group['end_date'].strftime('%Y%m%d')
            
            # 文件名格式：VWC-YYYYMMDD.tif（使用起始日期）
            output_filename = f'VWC-{start_str}.tif'
            output_path = os.path.join(OUTPUT_DIR, output_filename)
            
            if os.path.exists(output_path):
                logger.info(f"文件已存在: {output_path}")
                continue
            
            num_days = len(group['files'])
            logger.info(f"处理周期: {start_str} 到 {end_str} ({num_days}天)")
            
            # 初始化数据立方体
            num_bands = None
            composite_sum = None
            valid_count = None
            rows, cols = 0, 0
            
            # 处理周期内的每个文件
            for file_info in group['files']:
                try:
                    ds = gdal.Open(file_info['path'])
                    if ds is None:
                        logger.warning(f"无法打开文件: {file_info['path']}")
                        continue
                    
                    # 获取图像尺寸和波段数（只在第一次确定）
                    if rows == 0:
                        rows = ds.RasterYSize
                        cols = ds.RasterXSize
                        num_bands = ds.RasterCount
                        composite_sum = np.zeros((rows, cols, num_bands), dtype=np.float32)
                        valid_count = np.zeros((rows, cols, num_bands), dtype=np.uint16)
                    
                    # 读取每个波段
                    for band_idx in range(num_bands):
                        band = ds.GetRasterBand(band_idx+1)
                        data = band.ReadAsArray()
                        nodata = band.GetNoDataValue()
                        
                        if nodata is None:
                            nodata = NODATA_VALUE
                        
                        # 创建有效值掩膜
                        valid_mask = (data != nodata) & (~np.isnan(data))
                        
                        # 累加有效值
                        composite_sum[:, :, band_idx] += np.where(valid_mask, data, 0)
                        valid_count[:, :, band_idx] += valid_mask.astype(np.uint16)
                    
                    ds = None
                except Exception as e:
                    logger.error(f"处理文件 {file_info['path']} 错误: {str(e)}")
            
            # 计算平均值
            if composite_sum is not None and valid_count is not None:
                # 避免除以零
                with np.errstate(divide='ignore', invalid='ignore'):
                    composite_avg = np.where(
                        valid_count > 0,
                        composite_sum / valid_count,
                        NODATA_VALUE  # 无效值
                    )
                
                # 保存合成结果
                success = create_composite_geotiff(
                    composite_avg, 
                    output_path,
                    group['start_date'],
                    group['end_date']
                )
                if success:
                    # 计算有效数据百分比
                    valid_percent = (valid_count > 0).sum() / (rows * cols * num_bands) * 100
                    logger.info(f"8日合成完成: {start_str} 到 {end_str} (有效数据: {valid_percent:.2f}%)")
                else:
                    logger.error(f"保存失败: {output_path}")
            else:
                logger.warning(f"无有效数据可合成: {start_str} 到 {end_str}")
    
    logger.info("8日合成处理完成")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"输入目录: {INPUT_DIR}")
    logger.info(f"输出目录: {OUTPUT_DIR}")
    logger.info(f"生成8日合成数据: {START_YEAR} 到 {END_YEAR}")
    
    try:
        generate_8day_composites()
        logger.info("8日合成数据生成完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-07 16:32:44,831 - INFO - 输入目录: E:\data\VWC\VWCMap\Daily
2025-08-07 16:32:44,833 - INFO - 输出目录: E:\data\VWC\VWCMap\8Day
2025-08-07 16:32:44,834 - INFO - 生成8日合成数据: 2015 到 2016
2025-08-07 16:32:44,845 - INFO - 找到 731 个每日文件
2025-08-07 16:32:44,848 - INFO - 处理 2015 年 (共 365 个每日文件)
2025-08-07 16:32:44,852 - INFO - 2015年共生成 46 个8日周期
2025-08-07 16:32:44,855 - INFO - 处理周期: 20150101 到 20150108 (8天)                            | 0/46 [00:00<?, ?it/s]
2025-08-07 16:32:58,036 - INFO - 成功创建8日合成GeoTIFF: E:\data\VWC\VWCMap\8Day\VWC-20150101.tif
2025-08-07 16:32:58,120 - INFO - 8日合成完成: 20150101 到 20150108 (有效数据: 9.68%)
2025-08-07 16:32:58,124 - INFO - 处理周期: 20150109 到 20150116 (8天)                    | 1/46 [00:13<09:57, 13.27s/it]
2025-08-07 16:33:10,478 - INFO - 成功创建8日合成GeoTIFF: E:\data\VWC\VWCMap\8Day\VWC-20150109.tif
2025-08-07 16:33:10,561 - INFO - 8日合成完成: 20150109 到 20150116 (有效数据: 9.57%)
2025-08-07 16:33:10,564 - INFO - 处理周期: 20150117 到 20150124 (8天)                    | 2/46 [00:25<09:22

In [8]:
# 月度数据合成
import os
import numpy as np
from osgeo import gdal, osr
import datetime
import re
import logging
from tqdm import tqdm

# ============================== 配置日志 ==============================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vwc_monthly_composite.log')
    ]
)
logger = logging.getLogger()

# ============================== 配置参数 ==============================
INPUT_DIR = r'E:\data\VWC\VWCMap\Daily'  # 每日VWC数据目录
OUTPUT_DIR = r'E:\data\VWC\VWCMap\Monthly'  # 月度合成输出目录

# 起始和结束年份
START_YEAR = 2015
END_YEAR = 2017

# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================== 辅助函数 ==============================
def get_daily_files_for_month(year, month):
    """获取指定年月的所有每日文件（仅TIFF格式）"""
    files = []
    # 修正：只匹配真正的TIFF文件
    pattern = re.compile(r'VWC-(\d{4})(\d{2})(\d{2})\.tif$')  # 添加$确保只匹配.tif结尾
    
    # 列出所有每日文件
    for filename in os.listdir(INPUT_DIR):
        # 跳过辅助文件（.aux.xml等）
        if not filename.endswith('.tif'):
            continue
            
        match = pattern.match(filename)
        if not match:
            continue
            
        file_year = int(match.group(1))
        file_month = int(match.group(2))
        file_day = int(match.group(3))
        
        if file_year == year and file_month == month:
            file_path = os.path.join(INPUT_DIR, filename)
            files.append({
                'path': file_path,
                'date': datetime.date(file_year, file_month, file_day)
            })
    
    # 按日期排序
    files.sort(key=lambda x: x['date'])
    return files

def create_monthly_geotiff(data, output_path, nodata=-9999.0):
    """创建地理参考的月度TIFF文件"""
    try:
        driver = gdal.GetDriverByName('GTiff')
        rows, cols, bands = data.shape
        
        # 创建数据集
        out_ds = driver.Create(
            output_path, 
            cols, 
            rows, 
            bands, 
            gdal.GDT_Float32,
            options=['COMPRESS=LZW']
        )
        
        # 设置地理变换
        out_ds.SetGeoTransform((-180, 0.1, 0, 90, 0, -0.1))
        
        # 设置坐标系 (WGS84)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(4326)
        out_ds.SetProjection(srs.ExportToWkt())
        
        # 添加日期元数据
        out_ds.SetMetadata({'PRODUCTION_DATE': datetime.date.today().isoformat()})
        
        # 写入每个波段数据
        for band_idx in range(bands):
            band = out_ds.GetRasterBand(band_idx+1)
            band.WriteArray(data[:, :, band_idx])
            band.SetNoDataValue(nodata)
            band.SetDescription(f"Model_{band_idx+1}")
        
        # 清理
        out_ds.FlushCache()
        out_ds = None
        logger.info(f"成功创建月度合成GeoTIFF: {output_path}")
        return True
    except Exception as e:
        logger.error(f"创建GeoTIFF失败: {str(e)}")
        return False

def generate_monthly_composites():
    """生成月度合成数据"""
    total_months = (END_YEAR - START_YEAR + 1) * 12
    processed = 0
    
    # 进度条初始化
    pbar = tqdm(total=total_months, desc="月度合成进度")
    
    for year in range(START_YEAR, END_YEAR + 1):
        for month in range(1, 13):
            # 获取该月的所有每日文件
            files = get_daily_files_for_month(year, month)
            
            if not files:
                logger.warning(f"在{year}年{month}月未找到任何每日文件")
                pbar.update(1)
                processed += 1
                continue
            
            # 输出文件名 - 修正：只包含年月
            output_filename = f'VWC-{year}{str(month).zfill(2)}.tif'
            output_path = os.path.join(OUTPUT_DIR, output_filename)
            
            # 检查文件是否已存在
            if os.path.exists(output_path):
                logger.info(f"月度合成已存在: {output_path}")
                pbar.update(1)
                processed += 1
                continue
            
            logger.info(f"生成{year}年{month}月合成 ({len(files)}个每日文件)")
            
            # 初始化数据立方体 (高度, 宽度, 波段) 和有效计数
            num_bands = None
            monthly_sum = None
            valid_count = None
            rows, cols = 0, 0
            
            # 用于存储每日数据的列表（用于中值合成）
            daily_arrays = {}
            
            # 处理每个文件
            for file_info in files:
                try:
                    # 打开文件 - 添加额外检查确保是TIFF文件
                    if not file_info['path'].endswith('.tif'):
                        logger.warning(f"跳过非TIFF文件: {file_info['path']}")
                        continue
                    
                    ds = gdal.Open(file_info['path'])
                    if ds is None:
                        logger.warning(f"无法打开文件: {file_info['path']}")
                        continue
                    
                    # 获取图像尺寸（只在第一次确定）
                    if rows == 0:
                        rows = ds.RasterYSize
                        cols = ds.RasterXSize
                    
                    # 获取波段数量（只在第一次确定）
                    if num_bands is None:
                        num_bands = ds.RasterCount
                        monthly_sum = np.zeros((rows, cols, num_bands), dtype=np.float32)
                        valid_count = np.zeros((rows, cols, num_bands), dtype=np.uint16)
                        # 初始化每日数据存储
                        daily_arrays = {i: [] for i in range(num_bands)}
                    
                    # 读取每个波段
                    for band_idx in range(num_bands):
                        band = ds.GetRasterBand(band_idx+1)
                        data = band.ReadAsArray()
                        nodata = band.GetNoDataValue()
                        
                        if nodata is None:
                            nodata = -9999.0
                        
                        # 创建有效值掩膜
                        valid_mask = (data != nodata) & (~np.isnan(data))
                        
                        # 累加有效值
                        monthly_sum[:, :, band_idx] += np.where(valid_mask, data, 0)
                        valid_count[:, :, band_idx] += valid_mask.astype(np.uint16)
                        
                        # 收集每日数据用于中值计算
                        daily_arrays[band_idx].append(data)
                    
                    # 关闭数据集
                    ds = None
                    
                except Exception as e:
                    logger.error(f"处理文件 {file_info['path']} 错误: {str(e)}")
            
            # 计算月度平均值 (避免除以零)
            if monthly_sum is not None and valid_count is not None:
                # 计算平均值
                with np.errstate(divide='ignore', invalid='ignore'):
                    monthly_avg = np.where(
                        valid_count > 0,
                        monthly_sum / valid_count,
                        -9999.0  # 无效值
                    )
                
                # 计算中值（更健壮的指标）
                monthly_median = np.full((rows, cols, num_bands), -9999.0, dtype=np.float32)
                for band_idx in range(num_bands):
                    if daily_arrays[band_idx]:
                        # 计算每个像素的中值
                        stack = np.stack(daily_arrays[band_idx], axis=0)
                        valid_stack = np.ma.masked_equal(stack, -9999.0)
                        monthly_median[:, :, band_idx] = np.ma.median(valid_stack, axis=0).filled(-9999.0)
                    
                    # 如果中值计算失败，使用平均值
                    if np.all(monthly_median[:, :, band_idx] == -9999.0):
                        monthly_median[:, :, band_idx] = monthly_avg[:, :, band_idx]
                
                # 保存合成结果 - 使用中值作为更健壮的指标
                success = create_monthly_geotiff(monthly_median, output_path)
                
                if success:
                    # 输出质量报告
                    valid_percent = (valid_count > 0).sum() / (rows * cols * num_bands) * 100
                    logger.info(f"月度合成完成: {year}年{month}月, 有效数据: {valid_percent:.2f}%")
                else:
                    logger.error(f"保存月度合成失败: {output_path}")
            
            # 更新进度条
            pbar.update(1)
            processed += 1
    
    # 关闭进度条
    pbar.close()
    logger.info(f"月度合成处理完成! 共处理 {processed} 个月份")

# ============================== 执行主函数 ==============================
if __name__ == "__main__":
    logger.info(f"输入目录: {INPUT_DIR}")
    logger.info(f"输出目录: {OUTPUT_DIR}")
    logger.info(f"生成月度合成数据: {START_YEAR} 到 {END_YEAR}")
    
    try:
        generate_monthly_composites()
        logger.info("月度合成数据生成完成")
    except Exception as e:
        logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)

2025-08-07 21:11:04,466 - INFO - 输入目录: E:\data\VWC\VWCMap\Daily
2025-08-07 21:11:04,468 - INFO - 输出目录: E:\data\VWC\VWCMap\Monthly
2025-08-07 21:11:04,469 - INFO - 生成月度合成数据: 2015 到 2017
2025-08-07 21:11:04,477 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201501.tif       | 0/36 [00:00<?, ?it/s]
2025-08-07 21:11:04,482 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201502.tif
2025-08-07 21:11:04,487 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201503.tif
2025-08-07 21:11:04,490 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201504.tif
2025-08-07 21:11:04,495 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201505.tif
2025-08-07 21:11:04,499 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201506.tif
2025-08-07 21:11:04,503 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201507.tif
2025-08-07 21:11:04,507 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201508.tif
2025-08-07 21:11:04,511 - INFO - 月度合成已存在: E:\data\VWC\VWCMap\Monthly\VWC-201509.tif
2025-08-07 21:11:04,516 - INF