In [None]:
import os
import math
import json
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
from concurrent.futures import ProcessPoolExecutor, as_completed

# ==========
# 线程超卖控制（重要）
# ==========
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# 如果你的文件系统锁会出问题，可尝试：
os.environ.setdefault("HDF5_USE_PATH_LOCKING", "FALSE")


'FALSE'

In [None]:
# ======================
# PATH CONFIG
# ======================
NOTEBOOK_DIR = Path().resolve()
ROOT_DIR = (NOTEBOOK_DIR / "../data").resolve()
RUN_DIR = ROOT_DIR / "run"
CWD_PATH  = RUN_DIR / "CWD.nc"
dEF_PATH  = ROOT_DIR / "ET/dEF.nc"
dSIF_PATH = ROOT_DIR / "SIF/dSIF.nc"

OUT_GLOBAL_DIR = RUN_DIR / "veg_activity"
OUT_GLOBAL_DIR.mkdir(parents=True, exist_ok=True)

OUT_TILE_DIR = RUN_DIR / "veg_activity/tiles"
OUT_TILE_DIR.mkdir(parents=True, exist_ok=True)

# ======================
# VARIABLE NAMES
# ======================
VAR_CWD  = "CWD"
VAR_dEF  = "dEF"
VAR_dSIF = "dSIF"

# ======================
# TILE CONFIG
# ======================
TILE_Y = 300   # lat block
TILE_X = 300   # lon block

# ======================
# METHOD PARAMS
# ======================
NBINS = 50
Q = 0.90

MIN_EVENT_STEPS = 1       # >=8天; 你的数据是8-day，所以1 step = 8天
MINVALIDBINS = 8
MINPOINTSTOTAL = 30

RUNAWAY_YEARS = 5         # 连续>=5年无CWD<=0重置 -> discard pixel

# 缺测值处理（你的 dEF/dSIF 是 -9999）
FILL_X = -9999.0

# 输出压缩
ENC_2D = dict(zlib=True, complevel=4, dtype="float32")


In [3]:
ds_cwd  = xr.open_dataset(CWD_PATH,  engine="netcdf4")
ds_def  = xr.open_dataset(dEF_PATH,  engine="netcdf4")
ds_dsif = xr.open_dataset(dSIF_PATH, engine="netcdf4")

print(ds_cwd)
print(ds_def)
print(ds_dsif)

# 取 time/lat/lon，确保一致
time = ds_cwd["time"].values
lat  = ds_cwd["lat"].values
lon  = ds_cwd["lon"].values

assert ds_def["time"].size == ds_cwd["time"].size
assert ds_dsif["time"].size == ds_cwd["time"].size
assert ds_def["lat"].size  == ds_cwd["lat"].size
assert ds_def["lon"].size  == ds_cwd["lon"].size

print("time:", time[0], "->", time[-1], "len=", len(time))
print("lat:", lat[0], lat[-1], "n=", len(lat))
print("lon:", lon[0], lon[-1], "n=", len(lon))


<xarray.Dataset>
Dimensions:  (time: 828, lat: 3600, lon: 7200)
Coordinates:
  * time     (time) datetime64[ns] 2003-01-01 2003-01-09 ... 2020-12-26
  * lat      (lat) float32 -89.97 -89.93 -89.88 -89.82 ... 89.88 89.93 89.97
  * lon      (lon) float32 -180.0 -179.9 -179.9 -179.8 ... 179.9 179.9 180.0
Data variables:
    CWD      (time, lat, lon) float32 ...
<xarray.Dataset>
Dimensions:  (time: 828, lon: 7200, lat: 3600)
Coordinates:
  * time     (time) datetime64[ns] 2003-01-01 2003-01-09 ... 2020-12-26
  * lon      (lon) float64 -180.0 -179.9 -179.9 -179.8 ... 179.9 179.9 180.0
  * lat      (lat) float64 -89.97 -89.92 -89.88 -89.83 ... 89.88 89.93 89.98
Data variables:
    crs      |S1 ...
    dEF      (time, lat, lon) float32 ...
Attributes:
    CDI:                 Climate Data Interface version 2.5.0 (https://mpimet...
    Conventions:         CF-1.5
    GDAL_AREA_OR_POINT:  Area
    GDAL:                GDAL 3.10.3, released 2025/04/01
    history:             Fri Jan 09 21:33:58

In [4]:
def build_year_slices(time_index: np.ndarray) -> np.ndarray:
    """
    return int32 array shape (n_years, 2) with [start, end) indices
    """
    t = pd.to_datetime(time_index)
    years = t.year.values
    uniq_years = np.unique(years)
    slices = []
    for y in uniq_years:
        idx = np.where(years == y)[0]
        if idx.size == 0:
            continue
        slices.append((int(idx[0]), int(idx[-1] + 1)))
    return np.array(slices, dtype=np.int32)

YEAR_SLICES = build_year_slices(time)
print("N years:", YEAR_SLICES.shape[0], "first/last:", YEAR_SLICES[0], YEAR_SLICES[-1])


N years: 18 first/last: [ 0 46] [782 828]


In [7]:
NY = len(lat)
NX = len(lon)

def make_tiles(ny, nx, tile_y, tile_x):
    tiles = []
    for y0 in range(0, ny, tile_y):
        y1 = min(ny, y0 + tile_y)
        for x0 in range(0, nx, tile_x):
            x1 = min(nx, x0 + tile_x)
            tiles.append((y0, y1, x0, x1))
    return tiles

TILES = make_tiles(NY, NX, TILE_Y, TILE_X)
print("Total tiles:", len(TILES))
print("Example tile:", TILES[0])


Total tiles: 288
Example tile: (0, 300, 0, 300)


In [8]:
def load_tile(y0, y1, x0, x1):
    cwd  = ds_cwd[VAR_CWD].isel(lat=slice(y0,y1), lon=slice(x0,x1)).transpose("time","lat","lon").load()
    dEF  = ds_def[VAR_dEF].isel(lat=slice(y0,y1), lon=slice(x0,x1)).transpose("time","lat","lon").load()
    dSIF = ds_dsif[VAR_dSIF].isel(lat=slice(y0,y1), lon=slice(x0,x1)).transpose("time","lat","lon").load()

    # 缺测：-9999 -> NaN
    dEF  = dEF.where(dEF != FILL_X)
    dSIF = dSIF.where(dSIF != FILL_X)

    # 转 float32（减内存/加速）
    return (cwd.data.astype(np.float32),
            dEF.data.astype(np.float32),
            dSIF.data.astype(np.float32))

# 试一个 tile
y0,y1,x0,x1 = TILES[0]
cwd_t, def_t, dsif_t = load_tile(y0,y1,x0,x1)
print(cwd_t.shape, def_t.shape, dsif_t.shape)
print("CWD finite ratio:", np.isfinite(cwd_t).mean())
print("dEF finite ratio:", np.isfinite(def_t).mean())
print("dSIF finite ratio:", np.isfinite(dsif_t).mean())


(828, 300, 300) (828, 300, 300) (828, 300, 300)
CWD finite ratio: 1.0
dEF finite ratio: 0.0
dSIF finite ratio: 3.6500268384326356e-06


In [9]:
import numba as nb

@nb.njit(cache=True)
def _event_window_one_year(cwd_year, min_event_steps):
    """
    cwd_year: 1D float32, length Ty
    return (ok, start, end_exclusive, cwdmax)
    """
    Ty = cwd_year.size
    # 找 peak（最大值）
    peak = -1
    cwdmax = -1e30
    for i in range(Ty):
        v = cwd_year[i]
        if np.isfinite(v) and v > cwdmax:
            cwdmax = v
            peak = i
    if peak < 0 or not np.isfinite(cwdmax) or cwdmax <= 0.0:
        return 0, 0, 0, 0.0

    # start: 从 peak 向前找最后一个 CWD<=0
    start = 0
    for i in range(peak, -1, -1):
        v = cwd_year[i]
        if np.isfinite(v) and v <= 0.0:
            start = i
            break

    # end: 从 peak 向后找第一个 CWD < 0.9*CWDmax
    thr = 0.9 * cwdmax
    end = Ty
    for i in range(peak, Ty):
        v = cwd_year[i]
        if np.isfinite(v) and v < thr:
            end = i
            break

    if end - start < min_event_steps:
        return 0, 0, 0, 0.0

    return 1, start, end, cwdmax


@nb.njit(cache=True)
def _count_runaway_years(cwd, year_slices):
    """
    统计“连续多少年该年内从未出现 CWD<=0”（即没有重置）。
    返回 max consecutive years.
    """
    nY = year_slices.shape[0]
    max_consec = 0
    consec = 0
    for yi in range(nY):
        s = year_slices[yi,0]
        e = year_slices[yi,1]
        has_reset = False
        for t in range(s, e):
            v = cwd[t]
            if np.isfinite(v) and v <= 0.0:
                has_reset = True
                break
        if has_reset:
            consec = 0
        else:
            consec += 1
            if consec > max_consec:
                max_consec = consec
    return max_consec


@nb.njit(cache=True)
def _linfit_s0(x, y):
    """
    y = a + b*x  (OLS)
    return (ok, a, b, s0)
    """
    n = x.size
    sx = 0.0
    sy = 0.0
    sxx = 0.0
    sxy = 0.0
    for i in range(n):
        sx += x[i]
        sy += y[i]
        sxx += x[i]*x[i]
        sxy += x[i]*y[i]

    denom = n*sxx - sx*sx
    if denom == 0.0:
        return 0, 0.0, 0.0, np.nan

    b = (n*sxy - sx*sy)/denom
    a = (sy - b*sx)/n

    if not np.isfinite(a) or not np.isfinite(b):
        return 0, 0.0, 0.0, np.nan
    if b >= 0.0:
        return 0, a, b, np.nan

    s0 = -a/b
    if (not np.isfinite(s0)) or (s0 <= 0.0):
        return 0, a, b, np.nan

    return 1, a, b, s0


@nb.njit(cache=True)
def _s0_one_pixel(cwd, X, year_slices, nbins, q,
                  min_event_steps, min_valid_bins, min_points_total, runaway_years):
    """
    cwd, X: 1D length T
    return s0 (float32) or nan
    """
    # runaway check
    if _count_runaway_years(cwd, year_slices) >= runaway_years:
        return np.nan

    T = cwd.size
    nY = year_slices.shape[0]

    # pooling arrays (最多 T 点)
    pool_cwd = np.empty(T, dtype=np.float32)
    pool_x   = np.empty(T, dtype=np.float32)
    n_pool = 0
    gmax = 0.0

    # per-year largest event (在每年 slice 内找最大值窗口)
    for yi in range(nY):
        s = year_slices[yi,0]
        e = year_slices[yi,1]
        ok, st, ed, cwdmax = _event_window_one_year(cwd[s:e], min_event_steps)
        if ok == 0:
            continue

        # 把该窗口内点加入 pool
        for tt in range(s+st, s+ed):
            c = cwd[tt]
            x = X[tt]
            if np.isfinite(c) and np.isfinite(x) and c >= 0.0:
                pool_cwd[n_pool] = c
                pool_x[n_pool] = x
                n_pool += 1
                if c > gmax:
                    gmax = c

    if n_pool < min_points_total or (not np.isfinite(gmax)) or gmax <= 0.0:
        return np.nan

    # ===== 分箱：先 count，再按 bin regroup（O(n_pool)）
    counts = np.zeros(nbins, dtype=np.int32)
    bin_id = np.empty(n_pool, dtype=np.int32)

    for i in range(n_pool):
        b = int(pool_cwd[i] / gmax * nbins)
        if b < 0:
            b = 0
        elif b >= nbins:
            b = nbins - 1
        bin_id[i] = b
        counts[b] += 1

    # prefix sum -> offsets
    offsets = np.empty(nbins, dtype=np.int32)
    pos = 0
    for b in range(nbins):
        offsets[b] = pos
        pos += counts[b]

    # regroup X by bin in one array
    xgrp = np.empty(n_pool, dtype=np.float32)
    cursor = offsets.copy()
    for i in range(n_pool):
        b = bin_id[i]
        p = cursor[b]
        xgrp[p] = pool_x[i]
        cursor[b] += 1

    # 对每个 bin：sort in-place -> 取 q 分位
    # 同时构建回归点 (xbincenter, Xq)
    xb = np.empty(nbins, dtype=np.float32)
    yq = np.empty(nbins, dtype=np.float32)
    n_valid = 0

    for b in range(nbins):
        n = counts[b]
        if n <= 0:
            continue
        st = offsets[b]
        ed = st + n

        # in-place sort
        xgrp[st:ed].sort()

        # q-quantile index（nearest-rank风格）
        k = int(math.ceil(q * n) - 1)
        if k < 0:
            k = 0
        if k >= n:
            k = n - 1

        xq = xgrp[st + k]
        if not np.isfinite(xq):
            continue

        # bin center
        center = (b + 0.5) / nbins * gmax
        xb[n_valid] = center
        yq[n_valid] = xq
        n_valid += 1

    if n_valid < min_valid_bins:
        return np.nan

    ok, a, b, s0 = _linfit_s0(xb[:n_valid], yq[:n_valid])
    if ok == 0:
        return np.nan
    return np.float32(s0)


@nb.njit(parallel=True, cache=True)
def s0_tile(cwd3d, X3d, year_slices, nbins, q,
            min_event_steps, min_valid_bins, min_points_total, runaway_years):
    """
    cwd3d, X3d: (T, Y, X)
    return s0_2d: (Y, X)
    """
    T, Y, X = cwd3d.shape
    out = np.empty((Y, X), dtype=np.float32)
    n_pix = Y * X

    for p in nb.prange(n_pix):
        j = p // X
        i = p - j * X
        out[j, i] = _s0_one_pixel(cwd3d[:, j, i], X3d[:, j, i],
                                  year_slices, nbins, q,
                                  min_event_steps, min_valid_bins, min_points_total, runaway_years)
    return out


In [10]:
# 选一个 tile
y0,y1,x0,x1 = TILES[0]
cwd_t, def_t, dsif_t = load_tile(y0,y1,x0,x1)

# 运行 numba（第一次会编译，稍慢；后面很快）
s0_def  = s0_tile(cwd_t, def_t, YEAR_SLICES, NBINS, Q,
                  MIN_EVENT_STEPS, MINVALIDBINS, MINPOINTSTOTAL, RUNAWAY_YEARS)

s0_dsif = s0_tile(cwd_t, dsif_t, YEAR_SLICES, NBINS, Q,
                  MIN_EVENT_STEPS, MINVALIDBINS, MINPOINTSTOTAL, RUNAWAY_YEARS)

print("SdEF tile:", np.nanmin(s0_def), np.nanmax(s0_def), "nan%", np.isnan(s0_def).mean())
print("SdSIF tile:", np.nanmin(s0_dsif), np.nanmax(s0_dsif), "nan%", np.isnan(s0_dsif).mean())


SdEF tile: nan nan nan% 1.0
SdSIF tile: nan nan nan% 1.0


  print("SdEF tile:", np.nanmin(s0_def), np.nanmax(s0_def), "nan%", np.isnan(s0_def).mean())
  print("SdSIF tile:", np.nanmin(s0_dsif), np.nanmax(s0_dsif), "nan%", np.isnan(s0_dsif).mean())


In [None]:
def write_tile(y0,y1,x0,x1, s0_def, s0_dsif):
    lat_tile = lat[y0:y1]
    lon_tile = lon[x0:x1]

    da_def = xr.DataArray(s0_def, dims=("lat","lon"),
                          coords={"lat": lat_tile, "lon": lon_tile},
                          name="SdEF").astype("float32")
    da_sif = xr.DataArray(s0_dsif, dims=("lat","lon"),
                          coords={"lat": lat_tile, "lon": lon_tile},
                          name="SdSIF").astype("float32")

    tile_tag = f"y{y0:04d}-{y1:04d}_x{x0:04d}-{x1:04d}"
    f_def = OUT_TILE_DIR / f"SdEF_{tile_tag}.nc"
    f_sif = OUT_TILE_DIR / f"SdSIF_{tile_tag}.nc"

    xr.Dataset({"SdEF": da_def}).to_netcdf(f_def, engine="netcdf4", format="NETCDF4",
                                          encoding={"SdEF": ENC_2D})
    xr.Dataset({"SdSIF": da_sif}).to_netcdf(f_sif, engine="netcdf4", format="NETCDF4",
                                            encoding={"SdSIF": ENC_2D})

    return f_def, f_sif

f1, f2 = write_tile(y0,y1,x0,x1, s0_def, s0_dsif)
print("Wrote:", f1, f2)


Wrote: /tera04/zhwei/xionghui/bedrock/data/out_tiles/SdEF_y0000-0300_x0000-0300.nc4 /tera04/zhwei/xionghui/bedrock/data/out_tiles/SdSIF_y0000-0300_x0000-0300.nc4


In [None]:
def worker_one_tile(tile_id):
    y0,y1,x0,x1 = TILES[tile_id]
    tile_tag = f"y{y0:04d}-{y1:04d}_x{x0:04d}-{x1:04d}"
    f_def = OUT_TILE_DIR / f"SdEF_{tile_tag}.nc"
    f_sif = OUT_TILE_DIR / f"SdSIF_{tile_tag}.nc"

    # checkpoint: 已存在则跳过
    if f_def.exists() and f_sif.exists():
        return tile_id, "SKIP"

    # 每个进程独立打开（关键）
    ds_c = xr.open_dataset(CWD_PATH, engine="netcdf4")
    ds_e = xr.open_dataset(dEF_PATH, engine="netcdf4")
    ds_s = xr.open_dataset(dSIF_PATH, engine="netcdf4")

    cwd  = ds_c[VAR_CWD].isel(lat=slice(y0,y1), lon=slice(x0,x1)).transpose("time","lat","lon").load().data.astype(np.float32)
    dEF  = ds_e[VAR_dEF].isel(lat=slice(y0,y1), lon=slice(x0,x1)).transpose("time","lat","lon").load()
    dSIF = ds_s[VAR_dSIF].isel(lat=slice(y0,y1), lon=slice(x0,x1)).transpose("time","lat","lon").load()

    dEF  = dEF.where(dEF != FILL_X).data.astype(np.float32)
    dSIF = dSIF.where(dSIF != FILL_X).data.astype(np.float32)

    # compute
    s0_def  = s0_tile(cwd, dEF,  YEAR_SLICES, NBINS, Q,
                      MIN_EVENT_STEPS, MINVALIDBINS, MINPOINTSTOTAL, RUNAWAY_YEARS)
    s0_dsif = s0_tile(cwd, dSIF, YEAR_SLICES, NBINS, Q,
                      MIN_EVENT_STEPS, MINVALIDBINS, MINPOINTSTOTAL, RUNAWAY_YEARS)

    # write
    lat_tile = lat[y0:y1]
    lon_tile = lon[x0:x1]

    da_def = xr.DataArray(s0_def, dims=("lat","lon"),
                          coords={"lat": lat_tile, "lon": lon_tile},
                          name="SdEF").astype("float32")
    da_sif = xr.DataArray(s0_dsif, dims=("lat","lon"),
                          coords={"lat": lat_tile, "lon": lon_tile},
                          name="SdSIF").astype("float32")

    xr.Dataset({"SdEF": da_def}).to_netcdf(f_def, engine="netcdf4", format="NETCDF4",
                                           encoding={"SdEF": ENC_2D})
    xr.Dataset({"SdSIF": da_sif}).to_netcdf(f_sif, engine="netcdf4", format="NETCDF4",
                                            encoding={"SdSIF": ENC_2D})
    return tile_id, "OK"


# ====== 并行执行 ======
N_WORKERS = 24  # 你说至少24进程

done = 0
fail = 0

with ProcessPoolExecutor(max_workers=N_WORKERS) as ex:
    futs = {ex.submit(worker_one_tile, tid): tid for tid in range(len(TILES))}
    for fut in as_completed(futs):
        tid = futs[fut]
        try:
            tid2, status = fut.result()
            done += 1
            if status != "OK" and status != "SKIP":
                fail += 1
            if done % 10 == 0:
                print(f"[PROGRESS] {done}/{len(TILES)} done, fail={fail}")
        except Exception as e:
            fail += 1
            print(f"[FAIL] tile {tid}:", repr(e))

print("All done. fail=", fail)


[PROGRESS] 10/288 done, fail=0
[PROGRESS] 20/288 done, fail=0
[PROGRESS] 30/288 done, fail=0
[PROGRESS] 40/288 done, fail=0
[PROGRESS] 50/288 done, fail=0
[PROGRESS] 60/288 done, fail=0
[PROGRESS] 70/288 done, fail=0
[PROGRESS] 80/288 done, fail=0
[PROGRESS] 90/288 done, fail=0
[PROGRESS] 100/288 done, fail=0
[PROGRESS] 110/288 done, fail=0
[PROGRESS] 120/288 done, fail=0
[PROGRESS] 130/288 done, fail=0
[PROGRESS] 140/288 done, fail=0
[PROGRESS] 150/288 done, fail=0
[PROGRESS] 160/288 done, fail=0
[PROGRESS] 170/288 done, fail=0
[PROGRESS] 180/288 done, fail=0
[PROGRESS] 190/288 done, fail=0
[PROGRESS] 200/288 done, fail=0
[PROGRESS] 210/288 done, fail=0
[PROGRESS] 220/288 done, fail=0
[PROGRESS] 230/288 done, fail=0
[PROGRESS] 240/288 done, fail=0
[PROGRESS] 250/288 done, fail=0
[PROGRESS] 260/288 done, fail=0
[PROGRESS] 270/288 done, fail=0
[PROGRESS] 280/288 done, fail=0
All done. fail= 0


In [10]:
def assemble_global(varname, pattern_prefix, out_path):
    files = sorted(OUT_TILE_DIR.glob(f"{pattern_prefix}_y*-*_x*-*.nc"))
    if len(files) == 0:
        raise RuntimeError("No tile files found for " + pattern_prefix)

    # open_mfdataset 自动按坐标拼接（要求 tile 坐标不重叠且完整）
    ds = xr.open_mfdataset(files, combine="by_coords", engine="netcdf4")
    da = ds[varname].astype("float32")

    # 输出 chunk 可自行调
    enc = {varname: dict(zlib=True, complevel=4, dtype="float32", chunksizes=(600,600))}
    xr.Dataset({varname: da}).to_netcdf(out_path, engine="netcdf4", format="NETCDF4", encoding=enc)
    return out_path

out_def = assemble_global("SdEF", "SdEF", OUT_GLOBAL_DIR / "SdEF.nc")
out_sif = assemble_global("SdSIF", "SdSIF", OUT_GLOBAL_DIR / "SdSIF.nc")
print("Wrote global:", out_def, out_sif)


Wrote global: /tera04/zhwei/xionghui/bedrock/data/veg_activity/SdEF.nc /tera04/zhwei/xionghui/bedrock/data/veg_activity/SdSIF.nc


In [11]:
ds1 = xr.open_dataset(out_def)
ds2 = xr.open_dataset(out_sif)

SdEF  = ds1["SdEF"].values
SdSIF = ds2["SdSIF"].values

qc = {
    "SdEF": {
        "nan_frac": float(np.isnan(SdEF).mean()),
        "p01": float(np.nanpercentile(SdEF, 1)),
        "p50": float(np.nanpercentile(SdEF, 50)),
        "p99": float(np.nanpercentile(SdEF, 99)),
        "min": float(np.nanmin(SdEF)),
        "max": float(np.nanmax(SdEF)),
    },
    "SdSIF": {
        "nan_frac": float(np.isnan(SdSIF).mean()),
        "p01": float(np.nanpercentile(SdSIF, 1)),
        "p50": float(np.nanpercentile(SdSIF, 50)),
        "p99": float(np.nanpercentile(SdSIF, 99)),
        "min": float(np.nanmin(SdSIF)),
        "max": float(np.nanmax(SdSIF)),
    }
}
print(json.dumps(qc, indent=2))


{
  "SdEF": {
    "nan_frac": 0.8817354938271605,
    "p01": 25.294102001190186,
    "p50": 324.04734802246094,
    "p99": 24972.2095703125,
    "min": 0.004365758970379829,
    "max": 293556928.0
  },
  "SdSIF": {
    "nan_frac": 0.8932793981481482,
    "p01": 11.527358474731447,
    "p50": 348.12701416015625,
    "p99": 19929.036035156165,
    "min": 0.0007142720278352499,
    "max": 477670848.0
  }
}
