In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Full-run script: compute S0 with lightweight QC outputs (6 variables) for both dEF and dSIF.
Based on your original cal_S0.py structure (tile + multiprocess + assemble).
Outputs per dataset (SdEF / SdSIF):
  - S0      (float32) : output S0 if linear accepted; NaN if invalid or segmented-preferred
  - p_one   (float32) : one-sided p-value for slope b<0 (Normal approx)
  - bic_lin (float32) : BIC of linear model
  - bic_seg (float32) : BIC of 1-change-point piecewise model (NaN if not fitted)
  - cp      (int16)   : best change-point index in valid-bin sequence (-1 if no segmented fit; also reported when flag=1)
  - flag    (uint8)   : 0 invalid; 1 linear accepted (S0 valid); 2 segmented better (flattening/change point)

flag meaning:
  0 = invalid / filtered
  1 = linear accepted => S0 output
  2 = segmented clearly better => classify change point / flattening => S0 not output
"""

import os
import math
import json
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
from concurrent.futures import ProcessPoolExecutor, as_completed

# ==========
# Thread oversubscription control (important)
# ==========
# Tiles are fanned out over worker processes further down, so every in-process
# thread pool is pinned to a single thread. setdefault() never overrides a value
# the caller has already exported, so these remain tunable from the shell.
# NOTE(review): numpy is imported above this point; whether a BLAS/OpenMP runtime
# that is already loaded still honours these values depends on the library —
# export them in the shell if that matters.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
# The hot kernel is a numba prange loop; numba sizes its own thread pool from
# NUMBA_NUM_THREADS (read when numba is imported, which happens later in this file),
# not from the OMP/MKL variables above.
os.environ.setdefault("NUMBA_NUM_THREADS", "1")
# HDF5's switch for advisory file locking is HDF5_USE_FILE_LOCKING.
os.environ.setdefault("HDF5_USE_FILE_LOCKING", "FALSE")

# ======================
# PATH CONFIG
# ======================
# Everything is resolved relative to the current working directory:
# inputs live under ../data, run products under ../data/run.
NOTEBOOK_DIR = Path().resolve()
ROOT_DIR = NOTEBOOK_DIR.joinpath("..", "data").resolve()
RUN_DIR = ROOT_DIR / "run"

# Input NetCDF files
CWD_PATH = RUN_DIR / "CWD.nc"
dEF_PATH = ROOT_DIR / "ET" / "dEF.nc"
dSIF_PATH = ROOT_DIR / "SIF" / "dSIF.nc"

# Output locations: assembled global products, and per-tile files beneath them
OUT_GLOBAL_DIR = RUN_DIR / "veg_activity"
OUT_TILE_DIR = OUT_GLOBAL_DIR / "tiles"

OUT_GLOBAL_DIR.mkdir(parents=True, exist_ok=True)
OUT_TILE_DIR.mkdir(parents=True, exist_ok=True)

# ======================
# VARIABLE NAMES
# ======================
# Names of the data variables read from the three input files.
VAR_CWD, VAR_dEF, VAR_dSIF = "CWD", "dEF", "dSIF"

# ======================
# TILE CONFIG
# ======================
# Tile edge lengths in grid cells (rows, columns).
TILE_Y, TILE_X = 300, 300

# ======================
# METHOD PARAMS
# ======================
NBINS = 50                # number of CWD bins per pixel
Q = 0.90                  # quantile taken inside each bin

MIN_EVENT_STEPS = 1       # 8-day data, 1 step = 8 days
MINVALIDBINS = 8          # minimum populated bins required for a fit
MINPOINTSTOTAL = 30       # minimum pooled samples per pixel
RUNAWAY_YEARS = 5         # consecutive years without reset => discard

# Missing value handling for dEF/dSIF (-9999 in your data)
FILL_X = -9999.0

# ===== QC hyper-parameters (tune if needed) =====
ALPHA = 0.05              # one-sided p-value threshold for b<0
BMIN = 1e-8               # minimum |slope| to avoid exploding S0
MIN_SEG_POINTS = 5        # min bins per segment for piecewise
DELTA_BIC = 2.0           # segmented is "clearly better" if bic_seg + DELTA_BIC < bic_lin

# ======================
# NetCDF encoding
# ======================
# One compression recipe per on-disk dtype, used for the 2-D tile variables.
ENC_F32_2D = {"zlib": True, "complevel": 4, "dtype": "float32"}
ENC_I16_2D = {"zlib": True, "complevel": 4, "dtype": "int16"}
ENC_U8_2D = {"zlib": True, "complevel": 4, "dtype": "uint8"}

# ======================
# Load coordinate template (single process)
# ======================
ds_cwd  = xr.open_dataset(CWD_PATH,  engine="netcdf4")
ds_def  = xr.open_dataset(dEF_PATH,  engine="netcdf4")
ds_dsif = xr.open_dataset(dSIF_PATH, engine="netcdf4")

# CWD provides the reference axes; .values materialises them as numpy arrays.
time = ds_cwd["time"].values
lat  = ds_cwd["lat"].values
lon  = ds_cwd["lon"].values

# Tiles are cut from all three files with the same integer index ranges, so every
# axis of dEF and dSIF must have the same length as the CWD axis. Explicit raises
# are used instead of assert so the checks also run under `python -O`.
for _label, _ds in (("dEF", ds_def), ("dSIF", ds_dsif)):
    for _dim in ("time", "lat", "lon"):
        if _ds[_dim].size != ds_cwd[_dim].size:
            raise ValueError(
                f"{_label} '{_dim}' has length {_ds[_dim].size}, "
                f"but CWD '{_dim}' has length {ds_cwd[_dim].size}"
            )

print("time:", time[0], "->", time[-1], "len=", len(time))
print("lat:", lat[0], lat[-1], "n=", len(lat))
print("lon:", lon[0], lon[-1], "n=", len(lon))

# ======================
# Build year slices
# ======================
def build_year_slices(time_index: np.ndarray) -> np.ndarray:
    """
    Return the index span of every calendar year present in ``time_index``.

    Parameters
    ----------
    time_index : array of timestamps (anything ``pd.to_datetime`` accepts).

    Returns
    -------
    int32 array of shape (n_years, 2); row i is ``[start, end)`` where ``start``
    is the first index whose year equals the i-th distinct year (ascending) and
    ``end`` is one past the last such index. An empty input yields shape (0, 2).
    """
    years = pd.to_datetime(time_index).year.values
    slices = []
    for y in np.unique(years):
        # every y comes from `years`, so idx always has at least one element
        idx = np.flatnonzero(years == y)
        slices.append((int(idx[0]), int(idx[-1]) + 1))
    # reshape keeps the documented 2-column layout even when there are no years
    return np.array(slices, dtype=np.int32).reshape(-1, 2)

# Per-year [start, end) index pairs along the time axis.
YEAR_SLICES = build_year_slices(time)
print("N years:", len(YEAR_SLICES), "first/last:", YEAR_SLICES[0], YEAR_SLICES[-1])

# Global grid size: NY latitude points, NX longitude points.
NY, NX = len(lat), len(lon)

def make_tiles(ny, nx, tile_y, tile_x):
    """
    Split an (ny, nx) grid into rectangular tiles.

    Returns a list of (y0, y1, x0, x1) half-open index bounds, ordered row band
    by row band and left to right within a band. Tiles on the far edges are
    clipped to the grid size.
    """
    return [
        (row0, min(ny, row0 + tile_y), col0, min(nx, col0 + tile_x))
        for row0 in range(0, ny, tile_y)
        for col0 in range(0, nx, tile_x)
    ]

# Work list: one (y0, y1, x0, x1) entry per tile; worker ids index into it.
TILES = make_tiles(ny=NY, nx=NX, tile_y=TILE_Y, tile_x=TILE_X)
print("Total tiles:", len(TILES), "Example tile:", TILES[0])

# ======================
# Numba core (QC version)
# ======================
import numba as nb

@nb.njit(cache=True)
def _event_window_one_year(cwd_year, min_event_steps):
    """
    Locate the largest CWD event inside one year of data.

    Returns (ok, start, end_exclusive, cwdmax); ok is 0 (other fields zeroed)
    when the year has no finite positive CWD or the window is shorter than
    ``min_event_steps``.

    The window starts one step after the last finite CWD<=0 at or before the
    peak (so the reset point itself is excluded), or at index 0 if there is no
    such reset. It ends at the first finite value at/after the peak that drops
    below 0.9*cwdmax, or at the end of the year.
    """
    n_steps = cwd_year.size

    # Peak = first occurrence of the largest finite value.
    peak = -1
    cwdmax = -1e30
    for k in range(n_steps):
        val = cwd_year[k]
        if np.isfinite(val) and val > cwdmax:
            cwdmax = val
            peak = k
    if not (peak >= 0 and np.isfinite(cwdmax) and cwdmax > 0.0):
        return 0, 0, 0, 0.0

    # Walk backwards from the peak to the most recent reset (CWD<=0).
    start = 0
    back = peak
    while back >= 0:
        val = cwd_year[back]
        if np.isfinite(val) and val <= 0.0:
            start = back + 1
            break
        back -= 1
    if start >= n_steps:
        return 0, 0, 0, 0.0

    # Walk forwards from the peak until CWD falls under 90% of the maximum.
    thr = 0.9 * cwdmax
    end = n_steps
    for k in range(peak, n_steps):
        val = cwd_year[k]
        if np.isfinite(val) and val < thr:
            end = k
            break

    if end - start < min_event_steps:
        return 0, 0, 0, 0.0

    return 1, start, end, cwdmax


@nb.njit(cache=True)
def _count_runaway_years(cwd, year_slices):
    """
    Length of the longest run of consecutive years that contain no reset.

    A year counts as "reset" when at least one finite CWD value in its
    [start, end) slice is <= 0.
    """
    longest = 0
    streak = 0
    for row in range(year_slices.shape[0]):
        lo = year_slices[row, 0]
        hi = year_slices[row, 1]

        reset_seen = False
        for t in range(lo, hi):
            val = cwd[t]
            if np.isfinite(val) and val <= 0.0:
                reset_seen = True
                break

        if reset_seen:
            streak = 0
            continue

        streak += 1
        if streak > longest:
            longest = streak
    return longest


@nb.njit(cache=True)
def _norm_cdf(z):
    """Standard normal CDF, computed as Phi(z) = 0.5 * erfc(-z / sqrt(2))."""
    sqrt2 = 1.4142135623730951
    return 0.5 * math.erfc(-z / sqrt2)


@nb.njit(cache=True)
def _ols_fit(x, y):
    """
    Ordinary least squares for y = a + b*x.

    Returns (ok, a, b, rss, se_b, t_b, p_one) where p_one is the one-sided
    p-value for b < 0 using a Normal approximation of the t statistic.
    ok is 0 when fewer than 3 points are given, when x is degenerate, or when
    any intermediate quantity is non-finite; in the late failure cases a, b and
    rss are still returned.
    """
    n = x.size
    if n < 3:
        return 0, 0.0, 0.0, 1e30, 0.0, 0.0, 1.0

    # Raw moments
    sum_x = 0.0
    sum_y = 0.0
    sum_xx = 0.0
    sum_xy = 0.0
    for k in range(n):
        sum_x += x[k]
        sum_y += y[k]
        sum_xx += x[k] * x[k]
        sum_xy += x[k] * y[k]

    denom = n * sum_xx - sum_x * sum_x
    if denom == 0.0 or (not np.isfinite(denom)):
        return 0, 0.0, 0.0, 1e30, 0.0, 0.0, 1.0

    slope = (n * sum_xy - sum_x * sum_y) / denom
    intercept = (sum_y - slope * sum_x) / n
    if not (np.isfinite(intercept) and np.isfinite(slope)):
        return 0, 0.0, 0.0, 1e30, 0.0, 0.0, 1.0

    # Residual sum of squares and centred sum of squares of x, in one pass.
    mean_x = sum_x / n
    rss = 0.0
    sxx_centered = 0.0
    for k in range(n):
        resid = y[k] - (intercept + slope * x[k])
        rss += resid * resid
        dev = x[k] - mean_x
        sxx_centered += dev * dev
    if (not np.isfinite(rss)) or rss <= 0.0:
        rss = 1e-30  # floor keeps later log()/division well defined

    if sxx_centered <= 0.0 or (not np.isfinite(sxx_centered)):
        return 0, intercept, slope, rss, 0.0, 0.0, 1.0

    resid_var = rss / (n - 2)
    if (not np.isfinite(resid_var)) or resid_var <= 0.0:
        return 0, intercept, slope, rss, 0.0, 0.0, 1.0

    se_slope = math.sqrt(resid_var / sxx_centered)
    if (not np.isfinite(se_slope)) or se_slope <= 0.0:
        return 0, intercept, slope, rss, 0.0, 0.0, 1.0

    t_slope = slope / se_slope
    p_one = _norm_cdf(t_slope)  # lower-tail probability => small when slope is clearly negative
    if not np.isfinite(p_one):
        p_one = 1.0
    return 1, intercept, slope, rss, se_slope, t_slope, p_one


@nb.njit(cache=True)
def _bic_from_rss(rss, n, k):
    """
    BIC of a least-squares fit: n*log(rss/n) + k*log(n).

    rss is the residual sum of squares (values <= 0 are floored to 1e-30),
    n the number of points and k the number of fitted parameters.
    """
    safe_rss = 1e-30 if rss <= 0.0 else rss
    fit_term = n * math.log(safe_rss / n)
    penalty = k * math.log(n)
    return fit_term + penalty


@nb.njit(cache=True)
def _best_piecewise_1cp_bic(x, y, min_seg_points):
    """
    Brute-force search for the best single change point of a two-segment
    linear model (independent OLS fit on each side).

    Returns (ok, cp, rss_total, bic_seg); cp is the index of the first point of
    the right-hand segment. The BIC counts k=5 parameters (a1, b1, a2, b2, cp).
    ok is 0 with (-1, 1e30, 1e30) when there are fewer than 2*min_seg_points
    points or no split produced two valid fits.
    """
    n = x.size
    if n < 2 * min_seg_points:
        return 0, -1, 1e30, 1e30

    lowest_rss = 1e30
    chosen_cp = -1

    # Every split leaves at least min_seg_points on each side.
    for split in range(min_seg_points, n - min_seg_points + 1):
        ok_left, _, _, rss_left, _, _, _ = _ols_fit(x[:split], y[:split])
        if ok_left != 0:
            ok_right, _, _, rss_right, _, _, _ = _ols_fit(x[split:], y[split:])
            if ok_right != 0:
                total = rss_left + rss_right
                if total < lowest_rss:
                    lowest_rss = total
                    chosen_cp = split

    if chosen_cp < 0:
        return 0, -1, 1e30, 1e30

    return 1, chosen_cp, lowest_rss, _bic_from_rss(lowest_rss, n, 5)


@nb.njit(cache=True)
def _linfit_s0_qc(x, y, alpha, bmin):
    """
    Fit y = a + b*x and derive S0 = -a/b, subject to quality control.

    The fit is accepted only when a > 0, b < 0, |b| >= bmin, the one-sided
    p-value for b<0 is below alpha, and the resulting S0 is finite and positive.

    Returns (ok, s0, p_one, bic_lin). If the OLS fit itself fails all three
    values are NaN; if only the QC fails, s0 is NaN but p_one and bic_lin are
    still reported.
    """
    ok, a, b, rss, se_b, t_b, p_one = _ols_fit(x, y)
    if ok == 0:
        return 0, np.nan, np.nan, np.nan

    p_out = np.float32(p_one)
    bic_out = np.float32(_bic_from_rss(rss, x.size, 2))

    rejected = (
        (not np.isfinite(a)) or a <= 0.0
        or (not np.isfinite(b)) or b >= 0.0
        or abs(b) < bmin
        or (not np.isfinite(p_one)) or (p_one >= alpha)
    )
    if not rejected:
        s0 = -a / b  # x-intercept of the fitted line
        if np.isfinite(s0) and s0 > 0.0:
            return 1, np.float32(s0), p_out, bic_out

    return 0, np.nan, p_out, bic_out


@nb.njit(cache=True)
def _s0_one_pixel_qc(
    cwd, X, year_slices, nbins, q,
    min_event_steps, min_valid_bins, min_points_total, runaway_years,
    alpha, bmin, min_seg_points, delta_bic
):
    """
    Compute S0 and its QC diagnostics for a single pixel.

    Parameters:
      cwd, X        : 1-D time series of CWD and of the response variable
      year_slices   : (n_years, 2) [start, end) indices into the time axis
      nbins, q      : number of CWD bins and the quantile taken per bin
      min_event_steps, min_valid_bins, min_points_total, runaway_years,
      alpha, bmin, min_seg_points, delta_bic : thresholds forwarded to the
                      event detection, linear QC and segmented-model test

    Returns:
      s0(float32), p_one(float32), bic_lin(float32), bic_seg(float32), cp(int16), flag(uint8)

    flag is 0 when the pixel is rejected, 1 when the linear fit is accepted
    (s0 is set) and 2 when the 1-change-point model beats the linear one by more
    than delta_bic (s0 is NaN). For flag 0, p_one and bic_lin are still filled
    whenever the OLS fit itself succeeded.
    """
    # runaway check
    if _count_runaway_years(cwd, year_slices) >= runaway_years:
        return np.float32(np.nan), np.float32(np.nan), np.float32(np.nan), np.float32(np.nan), np.int16(-1), np.uint8(0)

    T = cwd.size
    nY = year_slices.shape[0]

    # Pooled (CWD, X) samples from the largest event of every year; at most T samples.
    pool_cwd = np.empty(T, dtype=np.float32)
    pool_x   = np.empty(T, dtype=np.float32)
    n_pool = 0
    gmax = 0.0  # largest pooled CWD, used to scale the bins

    # per-year largest event
    for yi in range(nY):
        s = year_slices[yi, 0]
        e = year_slices[yi, 1]
        ok, st, ed, cwdmax = _event_window_one_year(cwd[s:e], min_event_steps)
        if ok == 0:
            continue

        # pooling: keep only CWD>0 to avoid reset points
        for tt in range(s + st, s + ed):
            c = cwd[tt]
            x = X[tt]
            if np.isfinite(c) and np.isfinite(x) and c > 0.0:
                pool_cwd[n_pool] = c
                pool_x[n_pool] = x
                n_pool += 1
                if c > gmax:
                    gmax = c

    if n_pool < min_points_total or (not np.isfinite(gmax)) or gmax <= 0.0:
        return np.float32(np.nan), np.float32(np.nan), np.float32(np.nan), np.float32(np.nan), np.int16(-1), np.uint8(0)

    # binning: equal-width bins over (0, gmax]; the sample equal to gmax is
    # clamped into the last bin
    counts = np.zeros(nbins, dtype=np.int32)
    bin_id = np.empty(n_pool, dtype=np.int32)

    for i in range(n_pool):
        b = int(pool_cwd[i] / gmax * nbins)
        if b < 0:
            b = 0
        elif b >= nbins:
            b = nbins - 1
        bin_id[i] = b
        counts[b] += 1

    # start offset of every bin inside the grouped array (prefix sum of counts)
    offsets = np.empty(nbins, dtype=np.int32)
    pos = 0
    for b in range(nbins):
        offsets[b] = pos
        pos += counts[b]

    # scatter X values so that each bin occupies one contiguous slice of xgrp
    xgrp = np.empty(n_pool, dtype=np.float32)
    cursor = offsets.copy()
    for i in range(n_pool):
        b = bin_id[i]
        p = cursor[b]
        xgrp[p] = pool_x[i]
        cursor[b] += 1

    xb = np.empty(nbins, dtype=np.float32)   # bin centres (CWD)
    yqv = np.empty(nbins, dtype=np.float32)  # q-quantile of X in each bin
    n_valid = 0

    for b in range(nbins):
        n = counts[b]
        if n <= 0:
            continue
        st = offsets[b]
        ed = st + n

        # in-place sort of this bin's slice
        xgrp[st:ed].sort()

        # index of the q-quantile: ceil(q*n)-1, clamped to [0, n-1]
        k = int(math.ceil(q * n) - 1)
        if k < 0:
            k = 0
        if k >= n:
            k = n - 1

        xq = xgrp[st + k]
        if not np.isfinite(xq):
            continue

        center = (b + 0.5) / nbins * gmax
        xb[n_valid] = center
        yqv[n_valid] = xq
        n_valid += 1

    if n_valid < min_valid_bins:
        return np.float32(np.nan), np.float32(np.nan), np.float32(np.nan), np.float32(np.nan), np.int16(-1), np.uint8(0)

    x_use = xb[:n_valid]
    y_use = yqv[:n_valid]

    # (1) linear QC
    ok, s0, p_one, bic_lin = _linfit_s0_qc(x_use, y_use, alpha, bmin)
    if ok == 0:
        # _linfit_s0_qc already hands back p_one / bic_lin of the same OLS fit
        # (NaN for both when the fit itself failed), so no second fit is needed.
        return np.float32(np.nan), np.float32(p_one), np.float32(bic_lin), np.float32(np.nan), np.int16(-1), np.uint8(0)

    # (2) segmented 1-cp BIC test
    ok2, cp, rss_seg, bic_seg = _best_piecewise_1cp_bic(x_use, y_use, min_seg_points)
    if ok2 == 1:
        if bic_seg + delta_bic < bic_lin:
            # segmented is clearly better => change point / flattening
            return np.float32(np.nan), np.float32(p_one), np.float32(bic_lin), np.float32(bic_seg), np.int16(cp), np.uint8(2)
        else:
            # linear kept; the best candidate change point is still reported
            return np.float32(s0), np.float32(p_one), np.float32(bic_lin), np.float32(bic_seg), np.int16(cp), np.uint8(1)

    # no segmented fit
    return np.float32(s0), np.float32(p_one), np.float32(bic_lin), np.float32(np.nan), np.int16(-1), np.uint8(1)


@nb.njit(parallel=True, cache=True)
def s0_tile_qc(
    cwd3d, X3d, year_slices, nbins, q,
    min_event_steps, min_valid_bins, min_points_total, runaway_years,
    alpha, bmin, min_seg_points, delta_bic
):
    """
    Run the per-pixel S0/QC computation over a whole tile.

    cwd3d and X3d are indexed as [time, row, col]. Pixels are visited through a
    flattened nb.prange loop, each one handled by _s0_one_pixel_qc.

    Returns 6 arrays of shape (rows, cols):
      s0, p_one, bic_lin, bic_seg (float32), cp (int16), flag (uint8)
    """
    n_rows = cwd3d.shape[1]
    n_cols = cwd3d.shape[2]

    s0_map = np.empty((n_rows, n_cols), dtype=np.float32)
    p_map = np.empty((n_rows, n_cols), dtype=np.float32)
    bic_lin_map = np.empty((n_rows, n_cols), dtype=np.float32)
    bic_seg_map = np.empty((n_rows, n_cols), dtype=np.float32)
    cp_map = np.empty((n_rows, n_cols), dtype=np.int16)
    flag_map = np.empty((n_rows, n_cols), dtype=np.uint8)

    for flat in nb.prange(n_rows * n_cols):
        # recover (row, col) from the flattened pixel index
        row = flat // n_cols
        col = flat % n_cols
        result = _s0_one_pixel_qc(
            cwd3d[:, row, col], X3d[:, row, col],
            year_slices, nbins, q,
            min_event_steps, min_valid_bins, min_points_total, runaway_years,
            alpha, bmin, min_seg_points, delta_bic
        )
        s0_map[row, col] = result[0]
        p_map[row, col] = result[1]
        bic_lin_map[row, col] = result[2]
        bic_seg_map[row, col] = result[3]
        cp_map[row, col] = result[4]
        flag_map[row, col] = result[5]

    return s0_map, p_map, bic_lin_map, bic_seg_map, cp_map, flag_map


# ======================
# Tile IO helpers
# ======================
def write_tile_qc(out_path, lat_tile, lon_tile, prefix,
                  s0, p_one, bic_lin, bic_seg, cp, flag):
    """
    Write one tile file containing 6 variables.

    Parameters:
      out_path           : destination .nc file (str or Path)
      lat_tile, lon_tile : 1-D coordinate arrays of the tile
      prefix             : product name stored in the ``product`` attribute
      s0, p_one, bic_lin, bic_seg, cp, flag : 2-D (lat, lon) result arrays

    The method parameters used for the run are stored as string attributes.
    Returns ``out_path`` unchanged.

    The file is first written next to its destination with a ``.tmp`` suffix and
    then renamed, so a tile that exists under its final name is always complete.
    This matters because the worker checkpoint only tests for file existence.
    """
    ds = xr.Dataset(
        data_vars=dict(
            S0=(("lat", "lon"), s0.astype(np.float32)),
            p_one=(("lat", "lon"), p_one.astype(np.float32)),
            bic_lin=(("lat", "lon"), bic_lin.astype(np.float32)),
            bic_seg=(("lat", "lon"), bic_seg.astype(np.float32)),
            cp=(("lat", "lon"), cp.astype(np.int16)),
            flag=(("lat", "lon"), flag.astype(np.uint8)),
        ),
        coords=dict(
            lat=("lat", lat_tile),
            lon=("lon", lon_tile),
        ),
        attrs=dict(
            product=prefix,
            NBINS=str(NBINS),
            quantile=str(Q),
            alpha=str(ALPHA),
            bmin=str(BMIN),
            min_seg_points=str(MIN_SEG_POINTS),
            delta_bic=str(DELTA_BIC),
            runaway_years=str(RUNAWAY_YEARS),
        )
    )

    enc = {
        "S0": ENC_F32_2D,
        "p_one": ENC_F32_2D,
        "bic_lin": ENC_F32_2D,
        "bic_seg": ENC_F32_2D,
        "cp": ENC_I16_2D,
        "flag": ENC_U8_2D,
    }

    final_path = Path(out_path)
    # "<name>.nc.tmp" does not end in ".nc", so the tile glob used during
    # assembly never picks up a half-written file.
    tmp_path = final_path.with_name(final_path.name + ".tmp")
    try:
        ds.to_netcdf(tmp_path, engine="netcdf4", format="NETCDF4", encoding=enc)
        os.replace(tmp_path, final_path)  # atomic rename within one directory
    finally:
        ds.close()
        # only present here if the write or the rename failed
        if tmp_path.exists():
            tmp_path.unlink()
    return out_path


def assemble_global_qc(pattern_prefix, out_path):
    """
    Assemble global from tile QC files:
      - expects tiles named like: {prefix}_y####-####_x####-####.nc

    All matching tiles in OUT_TILE_DIR are combined by coordinates and written
    to ``out_path`` as one compressed NetCDF4 file. Returns ``out_path``.

    Raises RuntimeError when no tile file matches ``pattern_prefix``.
    """
    files = sorted(OUT_TILE_DIR.glob(f"{pattern_prefix}_y*-*_x*-*.nc"))
    if not files:
        raise RuntimeError(f"No tile files found for {pattern_prefix}")

    # Context manager releases the per-tile file handles once the global is written.
    with xr.open_mfdataset(files, combine="by_coords", engine="netcdf4") as ds:
        # global encoding: 600x600 chunks, clipped so a chunk never exceeds the
        # grid (netCDF rejects chunks larger than a fixed dimension)
        chunks = (min(600, ds.sizes["lat"]), min(600, ds.sizes["lon"]))
        enc = {
            name: dict(zlib=True, complevel=4, dtype=dtype, chunksizes=chunks)
            for name, dtype in (
                ("S0", "float32"),
                ("p_one", "float32"),
                ("bic_lin", "float32"),
                ("bic_seg", "float32"),
                ("cp", "int16"),
                ("flag", "uint8"),
            )
        }
        ds.to_netcdf(out_path, engine="netcdf4", format="NETCDF4", encoding=enc)
    return out_path


# ======================
# Worker
# ======================
def worker_one_tile(tile_id: int):
    """
    Process one tile: read CWD/dEF/dSIF for the tile, run the QC kernel for both
    response variables and write the two tile files.

    Returns (tile_id, "SKIP") when both tile files already exist, otherwise
    (tile_id, "OK") after both files have been written.
    """
    y0, y1, x0, x1 = TILES[tile_id]
    tile_tag = f"y{y0:04d}-{y1:04d}_x{x0:04d}-{x1:04d}"

    f_def = OUT_TILE_DIR / f"SdEF_QC_{tile_tag}.nc"
    f_sif = OUT_TILE_DIR / f"SdSIF_QC_{tile_tag}.nc"

    # checkpoint
    if f_def.exists() and f_sif.exists():
        return tile_id, "SKIP"

    # Each process opens datasets independently (important). The with-block
    # closes them again as soon as the tile is in memory, so a worker that
    # handles many tiles does not accumulate open file handles.
    tile_sel = dict(lat=slice(y0, y1), lon=slice(x0, x1))
    with xr.open_dataset(CWD_PATH, engine="netcdf4") as ds_c, \
         xr.open_dataset(dEF_PATH, engine="netcdf4") as ds_e, \
         xr.open_dataset(dSIF_PATH, engine="netcdf4") as ds_s:
        cwd  = ds_c[VAR_CWD].isel(**tile_sel).transpose("time", "lat", "lon").load().data.astype(np.float32)
        dEF  = ds_e[VAR_dEF].isel(**tile_sel).transpose("time", "lat", "lon").load()
        dSIF = ds_s[VAR_dSIF].isel(**tile_sel).transpose("time", "lat", "lon").load()

    # fill -> NaN
    dEF  = dEF.where(dEF != FILL_X).data.astype(np.float32)
    dSIF = dSIF.where(dSIF != FILL_X).data.astype(np.float32)

    # compute QC outputs (6 vars)
    s0_def, p_def, bl_def, bs_def, cp_def, fg_def = s0_tile_qc(
        cwd, dEF, YEAR_SLICES, NBINS, Q,
        MIN_EVENT_STEPS, MINVALIDBINS, MINPOINTSTOTAL, RUNAWAY_YEARS,
        ALPHA, BMIN, MIN_SEG_POINTS, DELTA_BIC
    )
    s0_sif, p_sif, bl_sif, bs_sif, cp_sif, fg_sif = s0_tile_qc(
        cwd, dSIF, YEAR_SLICES, NBINS, Q,
        MIN_EVENT_STEPS, MINVALIDBINS, MINPOINTSTOTAL, RUNAWAY_YEARS,
        ALPHA, BMIN, MIN_SEG_POINTS, DELTA_BIC
    )

    lat_tile = lat[y0:y1]
    lon_tile = lon[x0:x1]

    write_tile_qc(f_def, lat_tile, lon_tile, "SdEF_QC", s0_def, p_def, bl_def, bs_def, cp_def, fg_def)
    write_tile_qc(f_sif, lat_tile, lon_tile, "SdSIF_QC", s0_sif, p_sif, bl_sif, bs_sif, cp_sif, fg_sif)

    return tile_id, "OK"


# ======================
# Main
# ======================
if __name__ == "__main__":
    # Optional: quick compile on first tile to pay numba compile cost once
    y0, y1, x0, x1 = TILES[0]
    print("[WARMUP] compiling numba on first tile...")
    # datasets are closed as soon as the warm-up tile has been loaded
    with xr.open_dataset(CWD_PATH, engine="netcdf4") as ds_c, \
         xr.open_dataset(dEF_PATH, engine="netcdf4") as ds_e:
        cwd0 = ds_c[VAR_CWD].isel(lat=slice(y0, y1), lon=slice(x0, x1)).transpose("time", "lat", "lon").load().data.astype(np.float32)
        x0da = ds_e[VAR_dEF].isel(lat=slice(y0, y1), lon=slice(x0, x1)).transpose("time", "lat", "lon").load()
    x0arr = x0da.where(x0da != FILL_X).data.astype(np.float32)

    _ = s0_tile_qc(
        cwd0, x0arr, YEAR_SLICES, NBINS, Q,
        MIN_EVENT_STEPS, MINVALIDBINS, MINPOINTSTOTAL, RUNAWAY_YEARS,
        ALPHA, BMIN, MIN_SEG_POINTS, DELTA_BIC
    )
    print("[WARMUP] done.")

    # ===== Parallel run =====
    N_WORKERS = 24
    done = 0   # tiles that have finished, successfully or not
    fail = 0   # tiles that raised or reported an unexpected status

    print(f"[RUN] Start tiles: n={len(TILES)} workers={N_WORKERS}")
    with ProcessPoolExecutor(max_workers=N_WORKERS) as ex:
        futs = {ex.submit(worker_one_tile, tid): tid for tid in range(len(TILES))}
        for fut in as_completed(futs):
            tid = futs[fut]
            try:
                _, status = fut.result()
            except Exception as e:
                # a failing tile must not stop the remaining tiles
                fail += 1
                print(f"[FAIL] tile {tid}: {repr(e)}")
            else:
                if status not in ("OK", "SKIP"):
                    fail += 1
            done += 1
            if done % 10 == 0 or done == len(TILES):
                print(f"[PROGRESS] {done}/{len(TILES)} done, fail={fail}")

    print("[DONE] tiles complete. fail=", fail)
    if fail > 0:
        # Assembling with missing tiles would leave holes in the global product,
        # so stop here; finished tiles are skipped by the checkpoint on re-run.
        print("[WARN] Some tiles failed; fix failures before assembling globals.")
    else:
        # ===== Assemble globals =====
        out_def = assemble_global_qc("SdEF_QC", OUT_GLOBAL_DIR / "SdEF_QC.nc")
        out_sif = assemble_global_qc("SdSIF_QC", OUT_GLOBAL_DIR / "SdSIF_QC.nc")
        print("[WROTE] global:", out_def, out_sif)

        # ===== quick QC summary =====
        def qc_report(path):
            """Summarise S0 percentiles and flag fractions of one global file."""
            with xr.open_dataset(path) as ds:
                s0 = ds["S0"].values
                fg = ds["flag"].values
            rep = {
                "file": str(path),
                "S0": {
                    "nan_frac": float(np.isnan(s0).mean()),
                    "p01": float(np.nanpercentile(s0, 1)),
                    "p50": float(np.nanpercentile(s0, 50)),
                    "p99": float(np.nanpercentile(s0, 99)),
                    "min": float(np.nanmin(s0)),
                    "max": float(np.nanmax(s0)),
                },
                "flag": {
                    "frac_0_invalid": float((fg == 0).mean()),
                    "frac_1_linear": float((fg == 1).mean()),
                    "frac_2_segmented": float((fg == 2).mean()),
                }
            }
            return rep

        qc = {
            "SdEF_QC": qc_report(out_def),
            "SdSIF_QC": qc_report(out_sif),
        }
        print(json.dumps(qc, indent=2, ensure_ascii=False))


time: 2003-01-01T00:00:00.000000000 -> 2020-12-26T00:00:00.000000000 len= 828
lat: -89.975 89.975 n= 3600
lon: -179.975 179.975 n= 7200
N years: 18 first/last: [ 0 46] [782 828]
Total tiles: 288 Example tile: (0, 300, 0, 300)
[WARMUP] compiling numba on first tile...
[WARMUP] done.
[RUN] Start tiles: n=288 workers=24
[PROGRESS] 10/288 done, fail=0
[PROGRESS] 20/288 done, fail=0
[PROGRESS] 30/288 done, fail=0
[PROGRESS] 40/288 done, fail=0
[PROGRESS] 50/288 done, fail=0
[PROGRESS] 60/288 done, fail=0
[PROGRESS] 70/288 done, fail=0
[PROGRESS] 80/288 done, fail=0
[PROGRESS] 90/288 done, fail=0
[PROGRESS] 100/288 done, fail=0
[PROGRESS] 110/288 done, fail=0
[PROGRESS] 120/288 done, fail=0
[PROGRESS] 130/288 done, fail=0
[PROGRESS] 140/288 done, fail=0
[PROGRESS] 150/288 done, fail=0
[PROGRESS] 160/288 done, fail=0
[PROGRESS] 170/288 done, fail=0
[PROGRESS] 180/288 done, fail=0
[PROGRESS] 190/288 done, fail=0
[PROGRESS] 200/288 done, fail=0
[PROGRESS] 210/288 done, fail=0
[PROGRESS] 220/288