In [None]:
# -*- coding: utf-8 -*-
import os
import glob
import time
import gc
import sys
import shutil
from pathlib import Path
import warnings

import pandas as pd
import numpy as np
import xarray as xr

from flaml import AutoML
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
import joblib

warnings.filterwarnings("ignore")

# ========================== 1) Global Parameters ==========================
# Input directories: one "*_Met.nc" file per site in met_folder and the
# matching "*_Flux.nc" file in flux_folder.
met_folder = "plumber2_met_nc_files"
flux_folder = "plumber2_nc_files"

# Columns taken directly from the merged Met/Flux table.
features_raw = [
    'SWdown', 'LWdown', 'Tair', 'Qair', 'RH', 'Psurf', 'Wind',
    'CO2air', 'VPD', 'LAI', 'Ustar',
]
# Columns computed during preprocessing (interactions and one-row lags).
derived_features = ['SW_LAI', 'RH_Tair', 'SWdown_lag1', 'Tair_lag1']
# Per-site constants attached to every row.
geo_features = ['latitude', 'longitude']
# Full model input, in the column order handed to the learner.
features_all = [*features_raw, *derived_features, *geo_features]

# Batch experiments: number of sites used in each run.
site_limits = [5, 10, 20, 50, 170]

# Time budget (seconds) for each scale, keyed by the site count above;
# default_time_budget is used for any site count missing from the map.
time_budget_map = dict(zip(site_limits, (240, 420, 900, 1500, 2400)))
default_time_budget = 600

RANDOM_SEED = 42
N_SPLITS = 5
# The last 10% of each site's records is reserved as the hold-out test set.
TEST_FRAC_PER_SITE = 0.10

# ========================== 2) xarray open function compatible with Windows non-ASCII paths ==========================
# xarray engines that xr_open tries, in this order, until one succeeds.
_def_open_backends = ("netcdf4", "h5netcdf")

def _has_non_ascii(s: str) -> bool:
    """Return True when *s* contains at least one character outside ASCII."""
    return not s.isascii()

def _ascii_cache_path(src_path: str) -> str:
    """Return the cache location used for an ASCII-named copy of *src_path*.

    The cache directory ``<TEMP>/nc_cache`` is created if necessary
    (``TEMP`` falls back to ``C:\\temp`` when the variable is unset).

    The file name is built from the ASCII characters of the source stem plus
    a short hash of the full source path, so that:

    * a source file whose own name contains non-ASCII characters still gets
      an ASCII cache name, and
    * two source files with the same name in different folders do not
      overwrite each other in the cache.

    Only the file name is made ASCII here; if the ``TEMP`` directory itself
    contains non-ASCII characters the returned path still will.
    """
    import hashlib  # local import: only this helper needs it

    cache_root = Path(os.getenv('TEMP', 'C:\\temp')) / 'nc_cache'
    cache_root.mkdir(parents=True, exist_ok=True)

    src = Path(src_path)
    # 'surrogatepass' keeps hashing total even for paths with lone surrogates.
    digest = hashlib.sha1(
        str(src).encode('utf-8', 'surrogatepass')
    ).hexdigest()[:12]
    safe_stem = src.stem.encode('ascii', 'ignore').decode('ascii') or 'file'
    safe_suffix = src.suffix.encode('ascii', 'ignore').decode('ascii')
    return str(cache_root / f"{safe_stem}_{digest}{safe_suffix}")

def xr_open(path):
    """Open a NetCDF file with xarray, working around non-ASCII Windows paths.

    On Windows, when the resolved path contains non-ASCII characters, the
    file is first copied into the ASCII cache location given by
    ``_ascii_cache_path`` (re-copied only when the cache is missing or older
    than the source) and opened from there. If that copy fails, the original
    path is used instead.

    Each engine in ``_def_open_backends`` is tried in order; the first
    dataset that opens is returned.

    Raises:
        RuntimeError: if no engine could open the file. The message lists
            the error from every engine and the last one is chained as the
            cause.
    """
    abs_path = str(Path(path).resolve())
    read_path = abs_path

    if sys.platform.startswith('win') and _has_non_ascii(abs_path):
        cached = _ascii_cache_path(abs_path)
        try:
            cache_is_stale = (not os.path.exists(cached)) or (
                os.path.getmtime(cached) < os.path.getmtime(abs_path)
            )
            if cache_is_stale:
                shutil.copy2(abs_path, cached)
            read_path = cached
        except OSError:
            # Best effort: fall back to the original path if the copy fails.
            read_path = abs_path

    failures = []
    last_err = None
    for eng in _def_open_backends:
        try:
            return xr.open_dataset(read_path, engine=eng)
        except Exception as e:
            # Remember every engine's failure; reporting only the last one
            # can hide the real cause behind "engine not installed".
            failures.append(f"{eng}: {e}")
            last_err = e

    raise RuntimeError(
        f"Failed to open {abs_path}. Please install netCDF4 or h5netcdf, "
        f"or move the data to an ASCII path. Original error: {'; '.join(failures)}"
    ) from last_err

# ========================== 3) Latitude and Longitude Extraction ==========================
def _extract_scalar_value(arr) -> float:
    """Collapse ``arr.values`` to a single float.

    A zero-dimensional value is converted directly; anything with more
    dimensions is reduced with a NaN-ignoring mean. Any failure (including
    *arr* having no ``values`` attribute) yields NaN instead of raising.
    """
    try:
        data = arr.values
        scalar = data if np.ndim(data) == 0 else np.nanmean(data)
        return float(scalar)
    except Exception:
        return np.nan

def _coord_from_data(ds, names):
    """Return the scalar value of the first name in *names* found in the dataset, else NaN."""
    for name in names:
        if name in ds.coords:
            return _extract_scalar_value(ds.coords[name])
        if name in ds.variables:
            return _extract_scalar_value(ds.variables[name])
    return np.nan


def _coord_from_attrs(ds, keys):
    """Return the first global attribute in *keys* that converts to float, else NaN."""
    for key in keys:
        if key in ds.attrs:
            try:
                return float(ds.attrs[key])
            except (TypeError, ValueError):
                # Unparseable attribute (e.g. "n/a" or None): try the next key.
                continue
    return np.nan


def get_lat_lon(ds: "xr.Dataset"):
    """Extract a site's latitude and longitude from a dataset.

    Lookup order for each coordinate:

    1. the first matching name among the dataset's coordinates/variables
       (the search stops at the first name present, even if its value is NaN);
    2. if that gave NaN, the first matching global attribute that can be
       converted to ``float``;
    3. if still NaN, ``0.0``.

    Longitudes above 180 are shifted by -360.

    Args:
        ds: object exposing ``coords``, ``variables`` and ``attrs`` mappings.

    Returns:
        ``(latitude, longitude)`` as Python floats. Note that a result of
        ``(0.0, 0.0)`` may mean "not found" rather than a real location.
    """
    lat = _coord_from_data(ds, ['lat', 'latitude', 'LAT', 'Latitude'])
    lon = _coord_from_data(ds, ['lon', 'longitude', 'LON', 'Longitude'])

    if np.isnan(lat):
        lat = _coord_from_attrs(ds, ['site_latitude', 'Latitude', 'LAT', 'lat'])
    if np.isnan(lon):
        lon = _coord_from_attrs(ds, ['site_longitude', 'Longitude', 'LON', 'lon'])

    # Final fallback keeps downstream code working with plain floats.
    if np.isnan(lat):
        lat = 0.0
    if np.isnan(lon):
        lon = 0.0
    if lon > 180:
        lon -= 360.0
    return float(lat), float(lon)

# ========================== 4) Single-site Preprocessing ==========================
def preprocess_site(met_path, flux_path, lat, lon, site_name):
    """Load one site's Met and Flux files and build its modelling table.

    Steps: open both files, convert them to DataFrames, as-of merge them on
    ``time``, keep ``time`` + ``features_raw`` + ``GPP``/``NEE``, drop rows
    with missing values, add the derived and geographic features, and cast
    every feature/target column to float32.

    Args:
        met_path: path of the site's ``*_Met.nc`` file.
        flux_path: path of the site's ``*_Flux.nc`` file.
        lat, lon: site coordinates, stored on every row.
        site_name: label written to the ``site`` column.

    Returns:
        The per-site DataFrame sorted by ``time`` (the ``time`` and ``site``
        columns are kept for the later per-site split), or ``None`` if
        anything failed; failures are printed, not raised.
    """
    met_ds = flux_ds = None
    try:
        met_ds = xr_open(met_path)
        flux_ds = xr_open(flux_path)

        met_df = met_ds.to_dataframe().reset_index()
        flux_df = flux_ds.to_dataframe().reset_index()

        df = pd.merge_asof(
            met_df.sort_values('time'),
            flux_df.sort_values('time'),
            on='time'
        )
        # The source tables are no longer needed once merged.
        del met_df, flux_df

        keep_cols = ['time'] + features_raw + ['GPP', 'NEE']
        df = df[keep_cols].dropna()

        # Derived features
        df['SW_LAI'] = df['SWdown'] * df['LAI']
        df['RH_Tair'] = df['RH'] * df['Tair']
        # NOTE(review): the lags are taken after dropna(), so "lag1" is the
        # previous *retained* row, which may be more than one time step back.
        df['SWdown_lag1'] = df['SWdown'].shift(1)
        df['Tair_lag1'] = df['Tair'].shift(1)

        # Site metadata
        df['latitude'] = np.float32(lat)
        df['longitude'] = np.float32(lon)
        df['site'] = site_name

        # Second dropna removes the first row, whose lag values are NaN.
        df = df.dropna().sort_values('time')
        df = df.astype({c: 'float32' for c in [*features_all, 'GPP', 'NEE']})
        return df
    except KeyError as e:
        print(f"❌ Preprocessing failed (missing variables): {met_path}\nMissing column: {e}")
        return None
    except Exception as e:
        print(f"❌ Preprocessing failed: {met_path}\nError: {e}")
        return None
    finally:
        # Close the files on every path, including failures, so a bad site
        # does not leave NetCDF handles open for the rest of the run.
        for ds in (met_ds, flux_ds):
            if ds is None:
                continue
            try:
                ds.close()
            except Exception as e:
                print(f"⚠️ Failed to close dataset for {site_name}: {e}")
        gc.collect()

# ========================== 5) Hold out the tail of each site as the test set ==========================
def split_train_test_by_site(df_all: pd.DataFrame, test_frac: float = TEST_FRAC_PER_SITE):
    """Split every site chronologically into a leading train part and a trailing test part.

    For each group of the ``site`` column the rows are sorted by ``time``;
    the last ``round(n * test_frac)`` rows (always at least one) go to the
    test set and the rows before them go to the training set.

    Returns:
        ``(train_df, test_df)``, each the concatenation of the per-site
        parts with a fresh 0..N-1 index.
    """
    head_chunks = []
    tail_chunks = []
    for _, site_rows in df_all.groupby('site'):
        ordered = site_rows.sort_values('time')
        total = len(ordered)
        holdout = max(1, int(round(total * test_frac)))
        # Position where the hold-out tail starts; 0 means the whole site
        # is held out and the training part is empty.
        cut = max(total - holdout, 0)
        head_chunks.append(ordered.iloc[:cut])
        tail_chunks.append(ordered.iloc[cut:])

    train_df = pd.concat(head_chunks, axis=0).reset_index(drop=True)
    test_df = pd.concat(tail_chunks, axis=0).reset_index(drop=True)
    return train_df, test_df

# ========================== 6) Training (FLAML AutoML) and Saving ==========================
def _available_estimators():
    """Return the FLAML regression learners that can be used in this environment."""
    import importlib  # local import: only this helper needs it

    names = []
    for module_name, flaml_name in (('xgboost', 'xgboost'), ('lightgbm', 'lgbm')):
        try:
            importlib.import_module(module_name)
        except Exception:
            # Optional dependency missing or failing to load: leave it out.
            continue
        names.append(flaml_name)
    # Always-available fallbacks. 'lrl1' (L1 logistic regression) is
    # documented by FLAML as a classification learner, so it is not offered
    # for this regression task.
    names += ['rf', 'extra_tree']
    return names


def train_and_save_flaml(df_all: pd.DataFrame, target: str, limit: int):
    """Fit a FLAML regressor for *target*, score it on the hold-out set and save it.

    The data is split per site with ``split_train_test_by_site``; the model
    is tuned on the training part with site-grouped cross-validation and
    evaluated once on the held-out tails.

    Args:
        df_all: concatenated per-site tables containing ``features_all``,
            the target column, ``site`` and ``time``.
        target: ``'NEE'`` or ``'GPP'``.
        limit: number of sites of this experiment; selects the time budget
            and is used in the log/model file names.

    Returns:
        A dict with the experiment settings, the best estimator/config, the
        cross-validated R2 and the hold-out R2/RMSE/MAE/Pearson correlation.

    Raises:
        ValueError: if *target* is not supported, or if fewer than two sites
            remain in the training split (grouped CV needs at least two).

    Side effects:
        Writes ``flaml_{limit}sites_{target}.log`` and
        ``flaml_{limit}sites_{target}.pkl`` in the working directory.
    """
    if target not in ('NEE', 'GPP'):
        raise ValueError(f"target must be 'NEE' or 'GPP', got {target!r}")

    # 1) Split by-site tail into hold-out test set
    train_df, test_df = split_train_test_by_site(df_all, TEST_FRAC_PER_SITE)
    print(f"🧪 Split done: train={len(train_df)}, test={len(test_df)} (by-site tail {int(TEST_FRAC_PER_SITE*100)}%)")

    X_train = train_df[features_all].copy()
    y_train = train_df[target].astype('float32').values
    X_test = test_df[features_all].copy()
    y_test = test_df[target].astype('float32').values
    groups = train_df['site'].values  # grouped CV: folds never split a site

    # Grouped K-fold cannot use more folds than there are groups; sites can
    # be dropped during loading, so clamp instead of failing inside FLAML.
    n_sites = int(train_df['site'].nunique())
    if n_sites < 2:
        raise ValueError(
            f"Grouped cross-validation needs at least 2 sites in the training data, got {n_sites}."
        )
    n_splits = min(N_SPLITS, n_sites)

    # 2) Estimator list (depends on environment availability)
    estimator_list = _available_estimators()

    time_budget = time_budget_map.get(limit, default_time_budget)

    # 3) FLAML settings (do not pass fit_kwargs_by_estimator anymore)
    automl = AutoML()
    automl_settings = {
        "task": "regression",
        "metric": "r2",
        "estimator_list": estimator_list,
        "log_file_name": f"flaml_{limit}sites_{target}.log",
        "eval_method": "cv",
        "n_splits": n_splits,
        "split_type": "group",
        "groups": groups,
        "time_budget": time_budget,
        "seed": RANDOM_SEED,
        "verbose": 0,
    }

    print(f"⚙️  FLAML fit: target={target}, sites={limit}, budget={time_budget}s, estimators={estimator_list}")
    automl.fit(X_train=X_train, y_train=y_train, **automl_settings)

    # 4) Evaluation on hold-out set
    y_pred = automl.predict(X_test)
    r2 = r2_score(y_test, y_pred)
    # sqrt(MSE) instead of mean_squared_error(..., squared=False): the
    # `squared` argument was deprecated and later removed from scikit-learn.
    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = mean_absolute_error(y_test, y_pred)
    try:
        rho = pearsonr(y_test, y_pred)[0]
    except Exception:
        # Correlation is a secondary metric; report NaN rather than abort.
        rho = np.nan

    # 5) Save model
    model_path = f'flaml_{limit}sites_{target}.pkl'
    joblib.dump(automl, model_path)
    print(f"✅ Saved: {model_path}")

    return {
        "Sites": limit,
        "Target": target,
        "BestEstimator": automl.best_estimator,
        "BestConfig": str(automl.best_config),
        # FLAML minimises a loss; with metric "r2" it is converted back here
        # as 1 - loss.
        "BestR2_CV": (None if automl.best_loss is None else 1 - automl.best_loss),
        "R2_Holdout": r2,
        "RMSE_Holdout": rmse,
        "MAE_Holdout": mae,
        "Rho_Holdout": rho,
        "TimeBudgetSec": time_budget
    }

# ========================== 7) Main Workflow ==========================
def _load_site_frame(met_path):
    """Load and preprocess the site behind *met_path*; return its table or None if unusable."""
    # Strip the "_Met.nc" suffix only at the end of the name.
    site_prefix = os.path.basename(met_path)[:-len("_Met.nc")]
    flux_path = os.path.join(flux_folder, site_prefix + "_Flux.nc")
    if not os.path.exists(flux_path):
        print(f"⚠️ Missing Flux file, skipping {site_prefix}")
        return None

    try:
        ds = xr_open(met_path)
        try:
            lat, lon = get_lat_lon(ds)
        finally:
            # Close even when coordinate extraction fails.
            ds.close()
    except Exception as e:
        # Coordinates are optional inputs: keep the site, but say so.
        lat, lon = 0.0, 0.0
        print(f"⚠️ Could not read coordinates for {site_prefix}, using (0.0, 0.0): {e}")

    df_site = preprocess_site(met_path, flux_path, lat, lon, site_prefix)
    if df_site is None or len(df_site) == 0:
        return None
    return df_site


def main():
    """Run the batch experiments for every site count in ``site_limits``.

    For each limit the first *limit* Met files (sorted by name) are loaded
    and preprocessed, one FLAML model is trained per target (NEE, GPP), and
    the metrics are collected. The summary CSV is rewritten after every
    completed limit so finished results survive a failure in a later run.

    Raises:
        RuntimeError: if no ``*_Met.nc`` file is found in ``met_folder``.
    """
    print(f"Current working directory: {os.getcwd()}")
    met_files_all = sorted(glob.glob(os.path.join(met_folder, "*_Met.nc")))
    print(f"Total {len(met_files_all)} matched Met files")
    if not met_files_all:
        raise RuntimeError("No *_Met.nc files found, please check met_folder path.")

    summary_path = "automl_flaml_results_summary.csv"
    summary_rows = []

    for limit in site_limits:
        t0 = time.time()
        print("\n" + "=" * 80)
        print(f"🚀 Processing first {limit} sites ...")

        # Sites are reloaded for every limit on purpose, so the reported
        # wall time of each experiment includes its own data preparation.
        all_data = []
        for met_path in met_files_all[:limit]:
            df_site = _load_site_frame(met_path)
            if df_site is not None:
                all_data.append(df_site)
        used_sites = len(all_data)

        if not all_data:
            print(f"❌ limit={limit} Failed to load any site data, skipped.")
            continue

        df_all = pd.concat(all_data, axis=0).reset_index(drop=True)
        print(f"✅ Data prepared: sites {used_sites}, total records {len(df_all)}")

        res_nee = train_and_save_flaml(df_all, 'NEE', limit)
        res_gpp = train_and_save_flaml(df_all, 'GPP', limit)

        elapsed = round(time.time() - t0, 2)
        for r in (res_nee, res_gpp):
            r["WallTimeSec_Total"] = elapsed
            summary_rows.append(r)

        # Persist progress now: later limits run much longer, and a crash
        # there should not discard the results gathered so far.
        pd.DataFrame(summary_rows).to_csv(summary_path, index=False)

        # Free this experiment's data before the next, larger one is loaded.
        del all_data, df_all, res_nee, res_gpp
        gc.collect()

    if summary_rows:
        df_sum = pd.DataFrame(summary_rows)
        df_sum.to_csv(summary_path, index=False)
        print("\n📊 Generated automl_flaml_results_summary.csv and saved all models.")
    else:
        print("\n⚠️ No results produced, please check data and environment.")

# Run the full batch workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()


当前工作目录: /root/autodl-tmp/dataset
总计匹配到 170 个 Met 文件

🚀 正在处理前 5 个站点 ...
✅ 数据准备完成：站点数 5，总记录 561019
🧪 Split done: train=504916, test=56103 (by-site tail 10%)
⚙️  FLAML fit: target=NEE, sites=5, budget=240s, estimators=['xgboost', 'lgbm', 'rf', 'extra_tree', 'lrl1']
