In [5]:
# -*- coding: utf-8 -*-
"""
Section 3.5 - Few-Shot Fine-Tuning (Result & Evaluation)
- 复用你在 3.4 中的预处理风格与特征构造
- 自动匹配文件名（<SITE>*_Met.nc / <SITE>*_Flux.nc）
- 兼容多种经纬度字段来源（coords/variables/attrs），缺失时回退为 0 并将经度规范到 [-180, 180]
- Zero-shot：直接用 170-site 通用模型推理（PyCaret load_model）
- Few-shot：用目标站点前 10% / 20% 的时间片继续 .fit() 微调通用模型
- 与同一时间窗口的 Zero-shot 进行对比评估
- 输出：few_shot_results.csv、按站点与比例的时间序列图（fig_3_5/），R² 提升柱状图（fig_3_5/r2_gain_bar.png）
"""

import os
import re
import glob
import gc
import sys
import shutil
import time
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# ======== PyCaret ========
from pycaret.regression import load_model, predict_model

# ========================== 0) Configuration (edit as needed) ==========================
# Folders that build_site_df() globs for "<SITE>*_Met.nc" and "<SITE>*_Flux.nc" files.
MET_FOLDER = "plumber2_met_nc_files"
FLUX_FOLDER = "plumber2_nc_files"

# Target sites (file-name prefixes of the sites to evaluate)
TARGET_SITES = ["AU-ASM", "FI-Hyy", "US-UMB"]

# Few-shot fractions: in time order, the leading `frac` of each site's rows is
# used for the few-shot fit and the remaining rows form the test set.
FEW_SHOT_FRACS = [0.10, 0.20]

# Names of the saved 170-site general models (PyCaret model names; the prefix
# without ".pkl" is sufficient)
MODEL_GPP_NAME = "automl_170sites_GPP"
MODEL_NEE_NAME = "automl_170sites_NEE"

# Output directory for figures (created at import time if missing)
OUT_DIR = Path("fig_3_5")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Feature columns; per the original note these are meant to match the features
# used when the general models were trained — TODO confirm against the training code.
FEATURES_RAW = ['SWdown', 'LWdown', 'Tair', 'Qair', 'RH', 'Psurf', 'Wind',
                'CO2air', 'VPD', 'LAI', 'Ustar']
DERIVED_FEATURES = ['SW_LAI', 'RH_Tair', 'SWdown_lag1', 'Tair_lag1']
GEO_FEATURES = ['latitude', 'longitude']
FEATURES_ALL = FEATURES_RAW + DERIVED_FEATURES + GEO_FEATURES

# ========================== 1) Opening files with xarray (tolerates non-ASCII Windows paths) ==========================
# Engines that xr_open() tries, in this order, before falling back to xarray's default engine choice.
_DEF_OPEN_BACKENDS = ("netcdf4", "h5netcdf")

def _has_non_ascii(s: str) -> bool:
    """Return True when *s* contains at least one character outside 7-bit ASCII."""
    return not s.isascii()

def _ascii_cache_path(src_path: str) -> str:
    """Return a cache-file path with an ASCII-only file name for *src_path*.

    The cache directory is ``<TEMP>/nc_cache`` (``TEMP`` environment variable,
    falling back to ``C:\\temp``) and is created if it does not exist.

    The file name is ``<hash-of-full-source-path>_<sanitised-basename>``:
    - the hash prefix keeps files that share a basename but live in different
      folders from overwriting each other in the cache;
    - non-ASCII characters in the basename are replaced by ``_`` so the
      cached name itself is always ASCII.

    NOTE: the cache root is taken from ``TEMP`` as-is; if that directory
    contains non-ASCII characters the returned path still will.
    """
    import hashlib  # local import: only needed here

    cache_root = Path(os.getenv('TEMP', 'C:\\temp')) / 'nc_cache'
    cache_root.mkdir(parents=True, exist_ok=True)

    src = Path(src_path)
    digest = hashlib.sha1(str(src).encode('utf-8', 'backslashreplace')).hexdigest()[:12]
    # 'replace' turns every non-ASCII character into '?', which is not a valid
    # Windows file-name character, hence the second substitution.
    safe_name = src.name.encode('ascii', 'replace').decode('ascii').replace('?', '_')
    return str(cache_root / f"{digest}_{safe_name}")

def xr_open(path):
    """Open a NetCDF file with xarray, trying several engines in turn.

    On Windows, when the resolved path contains non-ASCII characters, the file
    is first copied (best effort) to the location given by
    ``_ascii_cache_path`` and opened from there; the copy is refreshed when it
    is missing or older than the source. If anything in that step fails the
    original path is used instead.

    Engines listed in ``_DEF_OPEN_BACKENDS`` are tried in order, followed by
    xarray's default engine selection (``engine=None``).

    Returns:
        The opened ``xarray.Dataset``; the caller is responsible for closing it.

    Raises:
        RuntimeError: if every attempt fails. The last underlying error is
            included in the message and chained as ``__cause__``.
    """
    abs_path = str(Path(path).resolve())
    read_path = abs_path
    if sys.platform.startswith('win') and _has_non_ascii(abs_path):
        try:
            cached = _ascii_cache_path(abs_path)
            stale = (not os.path.exists(cached)) or (
                os.path.getmtime(cached) < os.path.getmtime(abs_path)
            )
            if stale:
                shutil.copy2(abs_path, cached)
            read_path = cached
        except OSError:
            # Caching is only a workaround; fall back to the original path.
            read_path = abs_path

    last_err = None
    # None lets xarray pick an engine itself, as a last resort.
    for eng in (*_DEF_OPEN_BACKENDS, None):
        try:
            return xr.open_dataset(read_path, engine=eng)
        except Exception as e:
            last_err = e

    raise RuntimeError(
        f"无法打开 {abs_path}，请安装 netCDF4/h5netcdf。原始错误: {last_err}"
    ) from last_err

# ========================== 2) 经纬度提取 ==========================
def _extract_scalar_value(arr) -> float:
    """Collapse *arr* to a single float.

    The object's ``.values`` attribute is used when present, otherwise the
    object itself. A 0-d value is converted directly; anything with one or
    more dimensions is reduced with a NaN-ignoring mean. Any failure along the
    way yields NaN.
    """
    try:
        data = np.asarray(getattr(arr, "values", arr))
        return float(data) if data.ndim == 0 else float(np.nanmean(data))
    except Exception:
        return np.nan

def _lookup_coordinate(ds, names, attr_keys) -> float:
    """Return the first non-NaN value found for one coordinate, else NaN.

    Search order: ``ds.coords`` then ``ds.variables`` (each candidate in
    *names*), then ``ds.attrs`` (each key in *attr_keys*). A candidate that
    exists but yields NaN, or an attribute that cannot be converted to float,
    is skipped and the search continues.
    """
    for container in (ds.coords, ds.variables):
        for name in names:
            if name in container:
                value = _extract_scalar_value(container[name])
                if not np.isnan(value):
                    return value
    for key in attr_keys:
        if key in ds.attrs:
            try:
                value = float(ds.attrs[key])
            except (TypeError, ValueError, OverflowError):
                continue
            if not np.isnan(value):
                return value
    return np.nan


def get_lat_lon(ds: "xr.Dataset"):
    """Extract the site latitude and longitude from a dataset.

    Looks in coordinates, then data variables, then global attributes, under
    several common spellings. A coordinate that cannot be found falls back to
    0.0. Longitude is wrapped into [-180, 180] (values already inside that
    range, including exactly 180, are returned unchanged).

    Returns:
        ``(latitude, longitude)`` as Python floats.
    """
    lat = _lookup_coordinate(
        ds,
        ['lat', 'latitude', 'LAT', 'Latitude'],
        ['site_latitude', 'Latitude', 'LAT', 'lat'],
    )
    lon = _lookup_coordinate(
        ds,
        ['lon', 'longitude', 'LON', 'Longitude'],
        ['site_longitude', 'Longitude', 'LON', 'lon'],
    )

    if np.isnan(lat):
        lat = 0.0
    if np.isnan(lon):
        lon = 0.0
    # Wrap any out-of-range longitude (e.g. 0..360 convention, or < -180).
    if np.isfinite(lon) and not -180.0 <= lon <= 180.0:
        lon = (lon + 180.0) % 360.0 - 180.0
    return float(lat), float(lon)

# ========================== 3) 站点数据构建（自动匹配文件名） ==========================
def build_site_df(site_prefix: str) -> pd.DataFrame:
    """Build the feature/target table for one site.

    Finds ``<SITE>*_Met.nc`` in ``MET_FOLDER`` and ``<SITE>*_Flux.nc`` in
    ``FLUX_FOLDER`` (the alphabetically first match of each is used), joins
    them on ``time``, adds the derived and geographic features, and returns
    the rows that have every feature and both targets present.

    Args:
        site_prefix: leading part of the site's file names, e.g. ``"AU-ASM"``.

    Returns:
        DataFrame with columns ``time`` + ``FEATURES_ALL`` + ``GPP`` + ``NEE``,
        sorted by ``time`` with a fresh integer index.

    Raises:
        FileNotFoundError: if no Met or no Flux file matches the prefix.
    """
    # sorted() makes the choice deterministic when several files match.
    met_candidates = sorted(glob.glob(os.path.join(MET_FOLDER, f"{site_prefix}*_Met.nc")))
    flux_candidates = sorted(glob.glob(os.path.join(FLUX_FOLDER, f"{site_prefix}*_Flux.nc")))
    if not met_candidates:
        raise FileNotFoundError(f"[{site_prefix}] 未找到 Met 文件（{MET_FOLDER}/{site_prefix}*_Met.nc）")
    if not flux_candidates:
        raise FileNotFoundError(f"[{site_prefix}] 未找到 Flux 文件（{FLUX_FOLDER}/{site_prefix}*_Flux.nc）")

    met_path = met_candidates[0]
    flux_path = flux_candidates[0]

    try:
        # Context managers guarantee the Met dataset is closed even when
        # opening the Flux dataset fails.
        with xr_open(met_path) as met_ds, xr_open(flux_path) as flux_ds:
            met_df = met_ds.to_dataframe().reset_index()
            flux_df = flux_ds.to_dataframe().reset_index()

            # Join on time.
            # NOTE(review): merge_asof with default settings matches each Met
            # row to the latest Flux row at or before it, with no tolerance;
            # if the two time axes are not identical this can pair rows that
            # are far apart — confirm the files share the same timestamps.
            df = pd.merge_asof(
                met_df.sort_values('time'),
                flux_df.sort_values('time'),
                on='time'
            )

            # Guarantee every required column exists; a column missing from
            # the files is filled with NaN (and its rows are dropped below).
            for col in FEATURES_RAW + ['GPP', 'NEE']:
                if col not in df.columns:
                    df[col] = np.nan

            # Derived features (lags are taken before rows are dropped, so
            # they refer to the previous row of the merged table).
            df['SW_LAI'] = df['SWdown'] * df['LAI']
            df['RH_Tair'] = df['RH'] * df['Tair']
            df['SWdown_lag1'] = df['SWdown'].shift(1)
            df['Tair_lag1'] = df['Tair'].shift(1)

            # Site coordinates, taken from the Met dataset.
            lat, lon = get_lat_lon(met_ds)
            df['latitude'] = np.float32(lat)
            df['longitude'] = np.float32(lon)

            # Keep only the columns needed downstream.
            keep_cols = ['time'] + FEATURES_ALL + ['GPP', 'NEE']
            df = df[keep_cols]

            # Drop incomplete rows and order by time.
            df = df.dropna(subset=FEATURES_ALL + ['GPP', 'NEE']).sort_values('time').reset_index(drop=True)
            return df
    finally:
        gc.collect()

# ========================== 4) 评估与绘图 ==========================
def compute_metrics(y_true, y_pred):
    """Compute regression scores for a pair of observed/predicted series.

    Both inputs are flattened to 1-D first. Returns a dict with keys
    ``R2``, ``RMSE``, ``MAE`` and ``PearsonR``; the Pearson correlation is NaN
    when either series has (near-)zero standard deviation.
    """
    obs = np.asarray(y_true).ravel()
    est = np.asarray(y_pred).ravel()

    scores = {
        "R2": r2_score(obs, est),
        "RMSE": float(np.sqrt(mean_squared_error(obs, est))),
        "MAE": float(mean_absolute_error(obs, est)),
    }

    # Correlation is undefined for a constant series.
    if np.std(obs) > 1e-12 and np.std(est) > 1e-12:
        scores["PearsonR"] = float(pearsonr(obs, est)[0])
    else:
        scores["PearsonR"] = np.nan
    return scores

def plot_time_series(dates, y_true, y_zero, y_few, site, target, frac):
    """Plot observed, zero-shot and few-shot series on one axis and save it.

    The figure is written to ``OUT_DIR / "<site>_<target>_<pct>.png"`` where
    ``pct`` is *frac* expressed as a whole-number percentage.

    Args:
        dates: x-axis values for the test window.
        y_true, y_zero, y_few: observed values, zero-shot predictions and
            few-shot predictions over that window.
        site, target: used in the title and file name.
        frac: few-shot fraction (e.g. 0.10).
    """
    # round() before int(): plain int(frac*100) truncates values such as
    # 0.29*100 == 28.999... down to 28.
    pct = int(round(frac * 100))

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.plot(dates, y_true, label="Observed")
        ax.plot(dates, y_zero, label="Zero-shot", alpha=0.85)
        ax.plot(dates, y_few, label=f"Few-shot ({pct}%)", alpha=0.85)
        ax.set_title(f"{site} - {target} (Zero vs Few-shot)")
        ax.set_xlabel("Date")
        ax.set_ylabel(target)
        ax.legend()
        fig.tight_layout()
        fig.savefig(OUT_DIR / f"{site}_{target}_{pct}.png", dpi=200)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

def plot_r2_gain_bar(df_results: pd.DataFrame):
    """Save a horizontal bar chart of the R² gain (few-shot minus zero-shot).

    One bar per row of *df_results*, labelled "<Site> - <Target> (<pct>%)",
    ordered from largest gain (top) to smallest. Does nothing when
    *df_results* is empty. The figure is written to
    ``OUT_DIR / "r2_gain_bar.png"``.

    Args:
        df_results: table with columns ``Site``, ``Target``, ``Frac``,
            ``R2_few`` and ``R2_zero``.
    """
    if df_results.empty:
        return

    df = df_results.copy()
    df["R2_gain"] = df["R2_few"] - df["R2_zero"]
    # round() before the int cast so e.g. 0.29 is labelled 29%, not 28%.
    pct = (df["Frac"] * 100).round().astype(int).astype(str)
    df["Label"] = df["Site"] + " - " + df["Target"] + " (" + pct + "%)"
    # Actually apply the ordering (the sort result was previously discarded).
    df = df.sort_values("R2_gain", ascending=False)

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.barh(df["Label"], df["R2_gain"])
        ax.invert_yaxis()  # first (largest) bar at the top
        ax.set_xlabel("ΔR² (Few-shot − Zero-shot)")
        ax.set_title("R² Gain by Site / Target / Few-shot Ratio")
        fig.tight_layout()
        fig.savefig(OUT_DIR / "r2_gain_bar.png", dpi=200)
    finally:
        plt.close(fig)

# ========================== 5) Few-shot 主流程 ==========================
def load_pycaret_model_flex(name_or_path: str):
    """Load a saved PyCaret model given its name with or without ".pkl".

    A trailing ".pkl" (matched case-insensitively) is stripped before the
    name is handed to ``load_model``; e.g. both 'automl_170sites_GPP' and
    'automl_170sites_GPP.pkl' are accepted.
    """
    has_pkl_suffix = name_or_path.lower().endswith(".pkl")
    if has_pkl_suffix:
        return load_model(name_or_path[:-4])
    return load_model(name_or_path)

def run_few_shot():
    """Run the zero-shot vs few-shot comparison for every site/target/fraction.

    For each site in ``TARGET_SITES``, each target (GPP, NEE) and each
    fraction in ``FEW_SHOT_FRACS``:
      1. split the site's rows chronologically into a leading few-shot part
         and a trailing test part;
      2. predict the test part with the saved general model (zero-shot);
      3. reload the saved model, call ``.fit()`` on the few-shot part and
         predict the same test part (few-shot);
      4. record metrics for both and save a time-series comparison plot.

    Writes ``few_shot_results.csv`` to the working directory and the figures
    to ``OUT_DIR``.

    NOTE(review): calling ``.fit()`` on a scikit-learn-style pipeline
    normally re-trains it from scratch on the data passed in rather than
    continuing from the saved weights. If genuine warm-start fine-tuning is
    intended this needs an estimator-specific mechanism (for LightGBM, the
    ``init_model`` fit argument) — confirm which behaviour is wanted.
    """
    t0 = time.time()
    results = []

    # Saved-model name per target; also defines the targets that are evaluated.
    model_names = {"GPP": MODEL_GPP_NAME, "NEE": MODEL_NEE_NAME}

    # Load the 170-site general models once; these copies are never re-fitted.
    model_map = {target: load_pycaret_model_flex(name) for target, name in model_names.items()}

    for site in TARGET_SITES:
        df_site = build_site_df(site)
        dates = pd.to_datetime(df_site["time"])
        for target in model_names:
            if target not in df_site.columns:
                print(f"[Skip] {site}: 缺少 {target}")
                continue

            # Zero-shot is evaluated on the same test window as few-shot so
            # the two are directly comparable.
            for frac in FEW_SHOT_FRACS:
                # round() avoids truncation such as int(0.29*100) == 28.
                pct = int(round(frac * 100))

                n_few = max(1, int(len(df_site) * frac))
                few_df = df_site.iloc[:n_few].copy()
                test_df = df_site.iloc[n_few:].copy()
                if len(test_df) < 5:
                    print(f"[Skip] {site}-{target} ({pct}%): 测试样本太少 ({len(test_df)})")
                    continue

                X_test = test_df[FEATURES_ALL]
                y_test = test_df[target].values

                # Zero-shot prediction with the untouched general model.
                base_model = model_map[target]
                y_zero = predict_model(base_model, data=X_test)["prediction_label"].values

                # Reload a fresh copy from disk so fitting never accumulates
                # across sites or fractions.
                model_ft = load_pycaret_model_flex(model_names[target])
                model_ft = model_ft.fit(few_df[FEATURES_ALL], few_df[target])

                # Few-shot prediction.
                # NOTE(review): zero-shot goes through predict_model() while
                # few-shot uses the pipeline's own predict(); verify both
                # apply identical pre-/post-processing.
                y_few = model_ft.predict(X_test)

                m_zero = compute_metrics(y_test, y_zero)
                m_few = compute_metrics(y_test, y_few)

                results.append({
                    "Site": site, "Target": target, "Frac": frac,
                    "n_few": int(n_few), "n_test": int(len(test_df)),
                    "R2_zero": m_zero["R2"], "RMSE_zero": m_zero["RMSE"], "MAE_zero": m_zero["MAE"], "Pearson_zero": m_zero["PearsonR"],
                    "R2_few":  m_few["R2"],  "RMSE_few":  m_few["RMSE"],  "MAE_few":  m_few["MAE"],  "Pearson_few":  m_few["PearsonR"]
                })

                # Time-series comparison over the test window.
                plot_time_series(dates.iloc[n_few:], y_test, y_zero, y_few, site, target, frac)

                # Release per-iteration objects before the next fit.
                del model_ft, y_zero, y_few, m_zero, m_few
                gc.collect()

    df_res = pd.DataFrame(results)
    df_res.to_csv("few_shot_results.csv", index=False)
    plot_r2_gain_bar(df_res)

    print(f"✅ Done. Results -> few_shot_results.csv, Figures -> {OUT_DIR.resolve()}")
    print(f"Elapsed: {time.time()-t0:.1f}s")

# ========================== 6) Entry point ==========================
if __name__ == "__main__":
    # Run the whole zero-shot vs few-shot evaluation when executed directly.
    run_few_shot()


Transformation Pipeline and Model Successfully Loaded
Transformation Pipeline and Model Successfully Loaded
Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000966 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2560
[LightGBM] [Info] Number of data points in the train set: 12273, number of used features: 11
[LightGBM] [Info] Start training from score 1.603273


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001646 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2823
[LightGBM] [Info] Number of data points in the train set: 24547, number of used features: 12
[LightGBM] [Info] Start training from score 1.154399


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000919 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2560
[LightGBM] [Info] Number of data points in the train set: 12273, number of used features: 11
[LightGBM] [Info] Start training from score -0.337456


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001493 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2823
[LightGBM] [Info] Number of data points in the train set: 24547, number of used features: 12
[LightGBM] [Info] Start training from score -0.012602


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001820 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2563
[LightGBM] [Info] Number of data points in the train set: 32505, number of used features: 11
[LightGBM] [Info] Start training from score 3.147782


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001874 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2826
[LightGBM] [Info] Number of data points in the train set: 65011, number of used features: 12
[LightGBM] [Info] Start training from score 3.064891


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001640 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2563
[LightGBM] [Info] Number of data points in the train set: 32505, number of used features: 11
[LightGBM] [Info] Start training from score -0.750681


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002206 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2826
[LightGBM] [Info] Number of data points in the train set: 65011, number of used features: 12
[LightGBM] [Info] Start training from score -0.697114


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001154 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3107
[LightGBM] [Info] Number of data points in the train set: 13149, number of used features: 13
[LightGBM] [Info] Start training from score 2.714658


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002034 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3125
[LightGBM] [Info] Number of data points in the train set: 26299, number of used features: 13
[LightGBM] [Info] Start training from score 2.936363


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001024 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3107
[LightGBM] [Info] Number of data points in the train set: 13149, number of used features: 13
[LightGBM] [Info] Start training from score -0.453426


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001369 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3125
[LightGBM] [Info] Number of data points in the train set: 26299, number of used features: 13
[LightGBM] [Info] Start training from score -0.490124


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo

✅ Done. Results -> few_shot_results.csv, Figures -> /root/autodl-tmp/dataset/fig_3_5
Elapsed: 33.6s


In [None]:
# -*- coding: utf-8 -*-
"""
Section 3.5 - Few-Shot Fine-Tuning (Result & Evaluation)
改动要点：
1) 仍然按站点与目标变量做 zero-shot 与 few-shot(10%、20%) 微调与评估；
2) 图像合并：同一站点/目标，将 10% 与 20% few-shot 放在**同一张图**里，
   使用两条曲线对比（在**相同测试时间窗**：从 max(10%,20%) 的切分点开始）。
3) 仍保存 few_shot_results.csv（各比例分别基于各自测试窗计算的指标），
   另生成合并图：{site}_{target}_FS_10_20.png，以及总体 ΔR² 柱状图。
"""

import os
import re
import glob
import gc
import sys
import shutil
import time
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

from pycaret.regression import load_model, predict_model

# ========================== Configuration ==========================
# Folders searched for the per-site Met and Flux NetCDF files.
MET_FOLDER = "plumber2_met_nc_files"
FLUX_FOLDER = "plumber2_nc_files"

# Site-name prefixes to evaluate and the few-shot fractions to try.
TARGET_SITES = ["AU-ASM", "FI-Hyy", "US-UMB"]
FEW_SHOT_FRACS = [0.10, 0.20]

MODEL_GPP_NAME = "automl_170sites_GPP"  # the prefix without ".pkl" may be given
MODEL_NEE_NAME = "automl_170sites_NEE"

# Output directory for figures (created at import time if missing).
OUT_DIR = Path("fig_3_5")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Feature columns: raw variables, derived features, and site coordinates.
FEATURES_RAW = ['SWdown', 'LWdown', 'Tair', 'Qair', 'RH', 'Psurf', 'Wind',
                'CO2air', 'VPD', 'LAI', 'Ustar']
DERIVED_FEATURES = ['SW_LAI', 'RH_Tair', 'SWdown_lag1', 'Tair_lag1']
GEO_FEATURES = ['latitude', 'longitude']
FEATURES_ALL = FEATURES_RAW + DERIVED_FEATURES + GEO_FEATURES

# ========================== Opening files with xarray (tolerates non-ASCII Windows paths) ==========================
# Engines that xr_open() tries, in this order, before falling back to xarray's default engine choice.
_DEF_OPEN_BACKENDS = ("netcdf4", "h5netcdf")

def _has_non_ascii(s: str) -> bool:
    """Tell whether any character of *s* lies outside the ASCII range."""
    return any(ord(ch) > 0x7F for ch in s)

def _ascii_cache_path(src_path: str) -> str:
    """Return a cache-file path with an ASCII-only file name for *src_path*.

    The cache directory is ``<TEMP>/nc_cache`` (``TEMP`` environment variable,
    falling back to ``C:\\temp``) and is created if it does not exist.

    The file name is ``<hash-of-full-source-path>_<sanitised-basename>``:
    - the hash prefix keeps files that share a basename but live in different
      folders from overwriting each other in the cache;
    - non-ASCII characters in the basename are replaced by ``_`` so the
      cached name itself is always ASCII.

    NOTE: the cache root is taken from ``TEMP`` as-is; if that directory
    contains non-ASCII characters the returned path still will.
    """
    import hashlib  # local import: only needed here

    cache_root = Path(os.getenv('TEMP', 'C:\\temp')) / 'nc_cache'
    cache_root.mkdir(parents=True, exist_ok=True)

    src = Path(src_path)
    digest = hashlib.sha1(str(src).encode('utf-8', 'backslashreplace')).hexdigest()[:12]
    # 'replace' turns every non-ASCII character into '?', which is not a valid
    # Windows file-name character, hence the second substitution.
    safe_name = src.name.encode('ascii', 'replace').decode('ascii').replace('?', '_')
    return str(cache_root / f"{digest}_{safe_name}")

def xr_open(path):
    """Open a NetCDF file with xarray, trying several engines in turn.

    On Windows, when the resolved path contains non-ASCII characters, the file
    is first copied (best effort) to the location given by
    ``_ascii_cache_path`` and opened from there; the copy is refreshed when it
    is missing or older than the source. If anything in that step fails the
    original path is used instead.

    Engines listed in ``_DEF_OPEN_BACKENDS`` are tried in order, followed by
    xarray's default engine selection (``engine=None``).

    Returns:
        The opened ``xarray.Dataset``; the caller is responsible for closing it.

    Raises:
        RuntimeError: if every attempt fails. The last underlying error is
            included in the message and chained as ``__cause__``.
    """
    abs_path = str(Path(path).resolve())
    read_path = abs_path
    if sys.platform.startswith('win') and _has_non_ascii(abs_path):
        try:
            cached = _ascii_cache_path(abs_path)
            stale = (not os.path.exists(cached)) or (
                os.path.getmtime(cached) < os.path.getmtime(abs_path)
            )
            if stale:
                shutil.copy2(abs_path, cached)
            read_path = cached
        except OSError:
            # Caching is only a workaround; fall back to the original path.
            read_path = abs_path

    last_err = None
    # None lets xarray pick an engine itself, as a last resort.
    for eng in (*_DEF_OPEN_BACKENDS, None):
        try:
            return xr.open_dataset(read_path, engine=eng)
        except Exception as e:
            last_err = e

    raise RuntimeError(
        f"无法打开 {abs_path}，请安装 netCDF4/h5netcdf。原始错误: {last_err}"
    ) from last_err

# ========================== 经纬度提取 ==========================
def _extract_scalar_value(arr) -> float:
    """Reduce *arr* (or its ``.values`` attribute, when it has one) to a float.

    Scalars (0-d) are converted as-is; arrays are averaged ignoring NaNs.
    NaN is returned if the conversion raises for any reason.
    """
    try:
        payload = np.asarray(getattr(arr, "values", arr))
        if payload.ndim > 0:
            return float(np.nanmean(payload))
        return float(payload)
    except Exception:
        return np.nan

def _lookup_coordinate(ds, names, attr_names) -> float:
    """Return the first finite value found for one coordinate, else NaN.

    Search order: ``ds.coords``, then ``ds.variables`` (both by *names*),
    then ``ds.attrs`` (by *attr_names*). Entries that are missing,
    non-numeric or non-finite are skipped so a later candidate can be used.
    """
    for container in (ds.coords, ds.variables):
        for name in names:
            if name not in container:
                continue
            value = _extract_scalar_value(container[name])
            if np.isfinite(value):
                return value

    for key in attr_names:
        if key not in ds.attrs:
            continue
        try:
            value = float(ds.attrs[key])
        except (TypeError, ValueError):
            continue
        if np.isfinite(value):
            return value

    return np.nan


def get_lat_lon(ds: "xr.Dataset"):
    """Extract the site latitude and longitude from a dataset.

    Candidates are looked up in coordinates, then data variables, then global
    attributes, under several common spellings. A coordinate that cannot be
    found falls back to ``0.0``. Longitude is wrapped into ``[-180, 180]``.

    Args:
        ds: Dataset-like object exposing ``coords``, ``variables`` and
            ``attrs`` mappings.

    Returns:
        ``(latitude, longitude)`` as a tuple of Python floats.
    """
    lat = _lookup_coordinate(
        ds,
        names=['lat', 'latitude', 'LAT', 'Latitude'],
        attr_names=['site_latitude', 'Latitude', 'LAT', 'lat'],
    )
    lon = _lookup_coordinate(
        ds,
        names=['lon', 'longitude', 'LON', 'Longitude'],
        attr_names=['site_longitude', 'Longitude', 'LON', 'lon'],
    )

    if np.isnan(lat):
        lat = 0.0
    if np.isnan(lon):
        lon = 0.0

    # Common case (0..360 convention): a single subtraction keeps the exact
    # value the previous implementation produced.
    if lon > 180.0:
        lon -= 360.0
    # Anything still out of range (e.g. -190 or 725) is wrapped with modulo.
    if not -180.0 <= lon <= 180.0:
        lon = ((lon + 180.0) % 360.0) - 180.0
    return float(lat), float(lon)

# ========================== 数据构建（自动匹配文件名） ==========================
def build_site_df(site_prefix: str) -> pd.DataFrame:
    """Build the modelling table for one site from its Met and Flux files.

    Files are located with the patterns ``<MET_FOLDER>/<site_prefix>*_Met.nc``
    and ``<FLUX_FOLDER>/<site_prefix>*_Flux.nc``. When several files match,
    the alphabetically first one is used so the choice is reproducible.

    The two datasets are flattened to data frames, joined on ``time`` with
    ``pd.merge_asof``, extended with the derived and geographic features, and
    reduced to ``time`` + ``FEATURES_ALL`` + ``GPP``/``NEE``. Rows with a
    missing value in any feature or target are dropped.

    Args:
        site_prefix: Leading part of the site's file names (e.g. a site code).

    Returns:
        A time-sorted frame with a fresh integer index.

    Raises:
        FileNotFoundError: If no Met file or no Flux file matches.
    """
    # glob order is filesystem-dependent; sort so "[0]" is deterministic.
    met_candidates = sorted(glob.glob(os.path.join(MET_FOLDER, f"{site_prefix}*_Met.nc")))
    flux_candidates = sorted(glob.glob(os.path.join(FLUX_FOLDER, f"{site_prefix}*_Flux.nc")))
    if not met_candidates:
        raise FileNotFoundError(f"[{site_prefix}] 未找到 Met 文件（{MET_FOLDER}/{site_prefix}*_Met.nc）")
    if not flux_candidates:
        raise FileNotFoundError(f"[{site_prefix}] 未找到 Flux 文件（{FLUX_FOLDER}/{site_prefix}*_Flux.nc）")

    met_path, flux_path = met_candidates[0], flux_candidates[0]

    # Both opens happen inside the try so that the Met dataset is still closed
    # when opening the Flux dataset fails.
    met_ds = None
    flux_ds = None
    try:
        met_ds = xr_open(met_path)
        flux_ds = xr_open(flux_path)

        met_df = met_ds.to_dataframe().reset_index()
        flux_df = flux_ds.to_dataframe().reset_index()

        # NOTE(review): merge_asof with default settings pairs each Met row
        # with the latest Flux row at or before it, with no tolerance —
        # confirm the two files share the same time axis.
        df = pd.merge_asof(
            met_df.sort_values('time'), flux_df.sort_values('time'), on='time'
        )

        # Columns absent from the files are added as all-NaN, so the dropna
        # below removes every row for such a site.
        for col in FEATURES_RAW + ['GPP', 'NEE']:
            if col not in df.columns:
                df[col] = np.nan

        df['SW_LAI'] = df['SWdown'] * df['LAI']
        df['RH_Tair'] = df['RH'] * df['Tair']
        df['SWdown_lag1'] = df['SWdown'].shift(1)
        df['Tair_lag1'] = df['Tair'].shift(1)

        lat, lon = get_lat_lon(met_ds)
        df['latitude'] = np.float32(lat)
        df['longitude'] = np.float32(lon)

        keep_cols = ['time'] + FEATURES_ALL + ['GPP', 'NEE']
        df = df[keep_cols]
        df = (
            df.dropna(subset=FEATURES_ALL + ['GPP', 'NEE'])
            .sort_values('time')
            .reset_index(drop=True)
        )
        return df
    finally:
        for ds in (met_ds, flux_ds):
            if ds is None:
                continue
            try:
                ds.close()
            except Exception:
                # Closing is best-effort; a failure here must not mask the
                # result or the original exception.
                pass
        gc.collect()
# ========================== 指标与绘图 ==========================
def compute_metrics(y_true, y_pred):
    """Score predictions against observations.

    Both inputs are flattened to 1-D arrays first.

    Returns:
        A dict with keys ``"R2"``, ``"RMSE"``, ``"MAE"`` and ``"PearsonR"``.
        ``"PearsonR"`` is NaN when the standard deviation of either series is
        not above 1e-12 (correlation is undefined for a constant series).
    """
    observed = np.asarray(y_true).ravel()
    predicted = np.asarray(y_pred).ravel()

    metrics = {
        "R2": r2_score(observed, predicted),
        "RMSE": float(np.sqrt(mean_squared_error(observed, predicted))),
        "MAE": float(mean_absolute_error(observed, predicted)),
    }

    both_vary = np.std(observed) > 1e-12 and np.std(predicted) > 1e-12
    if both_vary:
        metrics["PearsonR"] = float(pearsonr(observed, predicted)[0])
    else:
        metrics["PearsonR"] = np.nan
    return metrics

def plot_r2_gain_bar(df_results: pd.DataFrame):
    """Save a horizontal bar chart of the R² gain of few-shot over zero-shot.

    One bar is drawn per row of *df_results*, labelled
    ``"<Site> - <Target> (<percent>%)"`` and ordered with the largest gain at
    the top. The figure is written to ``OUT_DIR / "r2_gain_bar.png"``.
    Nothing is drawn or written when *df_results* is empty.

    Args:
        df_results: Frame with the columns ``Site``, ``Target``, ``Frac``,
            ``R2_few`` and ``R2_zero``. The input is not modified.
    """
    if df_results.empty:
        return

    df = df_results.copy()
    df["R2_gain"] = df["R2_few"] - df["R2_zero"]
    # Round before casting: a plain int cast truncates, e.g. 0.29 * 100 is
    # 28.999999999999996 and would be labelled "28%".
    percent = (df["Frac"] * 100).round().astype(int).astype(str)
    df["Label"] = df["Site"] + " - " + df["Target"] + " (" + percent + "%)"

    # Sort once and plot the sorted rows directly; re-selecting by label would
    # duplicate rows if two rows ever shared a label.
    ranked = df.sort_values("R2_gain", ascending=False)

    plt.figure(figsize=(10, 5))
    plt.barh(ranked["Label"].tolist(), ranked["R2_gain"].to_numpy())
    # barh draws the first entry at the bottom; flip so the best gain is on top.
    plt.gca().invert_yaxis()
    plt.xlabel("ΔR² (Few-shot − Zero-shot)")
    plt.title("R² Gain by Site / Target / Few-shot Ratio")
    plt.tight_layout()
    plt.savefig(OUT_DIR / "r2_gain_bar.png", dpi=200)
    plt.close()

def plot_combined(site, target, dates_common, y_true_common, y_zero_common,
                  y_few_dict):
    """Plot observed, zero-shot and all few-shot series on one shared window.

    The figure is saved as ``OUT_DIR / "<site>_<target>_FS_10_20.png"``.

    Args:
        site: Site identifier, used in the title and the file name.
        target: Name of the predicted variable, used as the y-axis label,
            in the title and in the file name.
        dates_common: X values (time stamps) of the shared test window.
        y_true_common: Observed values on that window.
        y_zero_common: Zero-shot predictions on that window.
        y_few_dict: Mapping ``{fraction: predictions_on_common_window}``,
            e.g. ``{0.10: ..., 0.20: ...}``. Curves are drawn in ascending
            order of fraction.
    """
    plt.figure(figsize=(12, 4.5))
    plt.plot(dates_common, y_true_common, label="Observed")
    plt.plot(dates_common, y_zero_common, label="Zero-shot", alpha=0.9)

    for frac in sorted(y_few_dict):
        # round() rather than int(): int(0.29 * 100) would give 28.
        percent = int(round(frac * 100))
        plt.plot(dates_common, y_few_dict[frac],
                 label=f"Few-shot ({percent}%)", alpha=0.9)

    plt.title(f"{site} - {target} (Zero vs Few-shot)")
    plt.xlabel("Date")
    plt.ylabel(target)
    plt.legend()
    plt.tight_layout()
    plt.savefig(OUT_DIR / f"{site}_{target}_FS_10_20.png", dpi=200)
    plt.close()

# ========================== Few-shot 主流程 ==========================
def load_pycaret_model_flex(name_or_path: str):
    """Load a saved PyCaret model given its name with or without ``.pkl``.

    A trailing ``.pkl`` (matched case-insensitively) is stripped before the
    name is handed to ``load_model``; any other name is passed through as is.
    """
    suffix = ".pkl"
    if name_or_path.lower().endswith(suffix):
        return load_model(name_or_path[:-len(suffix)])
    return load_model(name_or_path)

def run_few_shot():
    """Run the zero-shot vs few-shot comparison for every target site.

    For each site in ``TARGET_SITES`` and each target (GPP, NEE):

    1. For every fraction in ``FEW_SHOT_FRACS`` the first ``frac`` of the
       time-ordered rows is used to fit a freshly loaded copy of the general
       model; the remaining rows form the test window. Zero-shot and few-shot
       metrics on that window are appended to the results.
    2. All fitted models are then evaluated on one shared window (starting
       after the largest few-shot slice) and drawn with ``plot_combined``.

    Side effects: writes ``few_shot_results.csv`` to the working directory,
    writes figures into ``OUT_DIR`` and prints progress to stdout.
    """
    t0 = time.time()
    results = []

    model_names = {"GPP": MODEL_GPP_NAME, "NEE": MODEL_NEE_NAME}
    # Base models are loaded once and used only for zero-shot prediction.
    base_models = {
        target: load_pycaret_model_flex(name) for target, name in model_names.items()
    }

    for site in TARGET_SITES:
        df_site = build_site_df(site)
        dates = pd.to_datetime(df_site["time"])

        for target in model_names:
            if target not in df_site.columns:
                continue

            # Fit one model per fraction and score it on its own test window.
            frac_to_model = {}
            frac_to_test_idx = {}
            for frac in FEW_SHOT_FRACS:
                # round() rather than int(): int(0.29 * 100) would give 28.
                percent = int(round(frac * 100))
                n_few = max(1, int(len(df_site) * frac))
                few_df = df_site.iloc[:n_few]
                test_df = df_site.iloc[n_few:]
                if len(test_df) < 5:
                    print(f"[Skip] {site}-{target} ({percent}%): 测试样本太少 ({len(test_df)})")
                    continue

                # Zero-shot prediction on this fraction's test window.
                y_zero = predict_model(
                    base_models[target], data=test_df[FEATURES_ALL]
                )["prediction_label"].values

                # Reload from disk so fits for different fractions never
                # build on one another.
                # NOTE(review): .fit() on the loaded object presumably refits
                # it on the few-shot rows only instead of warm-starting from
                # the general model — confirm this is the intended
                # "fine-tuning".
                model_ft = load_pycaret_model_flex(model_names[target])
                model_ft = model_ft.fit(few_df[FEATURES_ALL], few_df[target])
                y_few = model_ft.predict(test_df[FEATURES_ALL])

                m_zero = compute_metrics(test_df[target].values, y_zero)
                m_few = compute_metrics(test_df[target].values, y_few)
                results.append({
                    "Site": site, "Target": target, "Frac": frac,
                    "n_few": int(n_few), "n_test": int(len(test_df)),
                    "R2_zero": m_zero["R2"], "RMSE_zero": m_zero["RMSE"],
                    "MAE_zero": m_zero["MAE"], "Pearson_zero": m_zero["PearsonR"],
                    "R2_few": m_few["R2"], "RMSE_few": m_few["RMSE"],
                    "MAE_few": m_few["MAE"], "Pearson_few": m_few["PearsonR"],
                })

                # Kept for the combined plot below.
                frac_to_model[frac] = model_ft
                frac_to_test_idx[frac] = n_few

            # Combined plot: shared test window starting after the largest
            # few-shot slice, so no plotted model has seen any plotted row.
            if not frac_to_model:
                continue
            start_idx = max(frac_to_test_idx.values())
            test_df_common = df_site.iloc[start_idx:]
            if len(test_df_common) < 5:
                continue

            y_true_common = test_df_common[target].values
            y_zero_common = predict_model(
                base_models[target], data=test_df_common[FEATURES_ALL]
            )["prediction_label"].values

            y_few_dict = {
                frac: model_ft.predict(test_df_common[FEATURES_ALL])
                for frac, model_ft in frac_to_model.items()
            }

            plot_combined(
                site, target,
                dates_common=dates.iloc[start_idx:],
                y_true_common=y_true_common,
                y_zero_common=y_zero_common,
                y_few_dict=y_few_dict
            )

            # Drop the references held by the dict so the fitted models can
            # actually be collected (deleting a loop variable frees nothing).
            frac_to_model.clear()
            gc.collect()

    # Summary table and overview chart.
    df_res = pd.DataFrame(results)
    df_res.to_csv("few_shot_results.csv", index=False)
    plot_r2_gain_bar(df_res)

    print("✅ Done. Results -> few_shot_results.csv")
    print(f"✅ Figures -> {OUT_DIR.resolve()}  (含合并图 *_FS_10_20.png 与 r2_gain_bar.png)")
    print(f"Elapsed: {time.time()-t0:.1f}s")

# Entry point: run the whole few-shot evaluation when this code is executed
# directly (its module name is "__main__") rather than imported.
if __name__ == "__main__":
    run_few_shot()


Transformation Pipeline and Model Successfully Loaded
Transformation Pipeline and Model Successfully Loaded
Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000992 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2560
[LightGBM] [Info] Number of data points in the train set: 12273, number of used features: 11
[LightGBM] [Info] Start training from score 1.603273
Transformation Pipeline and Model Successfully Loaded
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001285 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 2823
[LightGBM] [Info] Number of data points in the train set: 24547, number of used features: 12
[LightGBM] [Info] Start training from score 1.154399


findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the following families were found: Arial, Liberation Sans, Bitstream Vera Sans, sans-serif
findfont: Generic family 'sans-serif' not found because none of the fo