In [None]:
import xarray as xr
import numpy as np
import numba as nb

# -----------------------------
# Input / output locations
# -----------------------------
PATH_CWD  = "/share/home/dq076/bedrock/data/CWD/CWD.nc"
PATH_dEF  = "/share/home/dq076/bedrock/data/dEF/dEF.nc"
PATH_dSIF = "/share/home/dq076/bedrock/data/dSIF/dSIF.nc"

OUT_SdEF  = "/share/home/dq076/bedrock/data/S0/SdEF.nc4"
OUT_SdSIF = "/share/home/dq076/bedrock/data/S0/SdSIF.nc4"

# -----------------------------
# Tunables
# -----------------------------
CHUNK = {"time": -1, "lat": 200, "lon": 200}  # whole time axis per chunk, 200x200 tiles
N_BINS = 15            # number of CWD bins used by the pooled regression
Q = 0.90               # quantile of X taken inside each CWD bin
MIN_VALID_BINS = 8     # minimum number of usable bins for a fit
ZERO_EPS = 1e-6        # CWD <= ZERO_EPS counts as zero
MIN_POINTS_TOTAL = 30  # minimum pooled event length (time steps) per pixel

# -----------------------------
# Lazy, chunked read of the three inputs
# -----------------------------
ds_cwd, ds_def, ds_dsif = (
    xr.open_dataset(path, chunks=CHUNK)
    for path in (PATH_CWD, PATH_dEF, PATH_dSIF)
)

cwd, dEF, dSIF = (
    dataset[var_name].astype("float32")
    for dataset, var_name in ((ds_cwd, "CWD"), (ds_def, "dEF"), (ds_dsif, "dSIF"))
)

# join="exact" makes xarray raise instead of silently reindexing when the
# coordinates of the three arrays differ.
cwd, dEF, dSIF = xr.align(cwd, dEF, dSIF, join="exact")

# Put every array in (time, lat, lon) order.
cwd, dEF, dSIF = (
    field.transpose("time", "lat", "lon") for field in (cwd, dEF, dSIF)
)

# One (year, start, end) row per calendar year, with [start, end) running from
# the first to one past the last time index of that year. This is only a true
# per-year slice if each year's samples are contiguous on the time axis.
years = cwd["time"].dt.year.values.astype(np.int32)
uniq_years = np.unique(years)

year_slices = []
for yr in uniq_years:
    hits = np.flatnonzero(years == yr)
    if hits.size == 0:
        continue
    year_slices.append((yr, int(hits[0]), int(hits[-1]) + 1))
year_slices_arr = np.array(year_slices, dtype=np.int32)  # (Ny, 3)


# =====================================================================
# Numba kernels (UPDATED to pooled-across-years binning, per your summary)
# =====================================================================

@nb.njit
def _linfit_s0(x, y):
    """
    Least-squares line y = a + b*x over the finite (x, y) pairs.

    Returns the x-intercept S0 = -a/b. Returns NaN when fewer than
    MIN_VALID_BINS pairs are finite, when the normal-equation denominator is
    exactly zero, when the slope is not strictly negative, or when S0 is not
    a finite positive number.
    """
    count = 0
    sum_x = 0.0
    sum_y = 0.0
    sum_xx = 0.0
    sum_xy = 0.0
    for k in range(x.size):
        xk = x[k]
        yk = y[k]
        # Skip any pair with a NaN/inf member.
        if not (np.isfinite(xk) and np.isfinite(yk)):
            continue
        count += 1
        sum_x += xk
        sum_y += yk
        sum_xx += xk * xk
        sum_xy += xk * yk

    if count < MIN_VALID_BINS:
        return np.nan

    det = count * sum_xx - sum_x * sum_x
    if det == 0.0:
        return np.nan

    slope = (count * sum_xy - sum_x * sum_y) / det
    intercept = (sum_y - slope * sum_x) / count

    if not (np.isfinite(intercept) and np.isfinite(slope)):
        return np.nan
    # Only a decreasing relation (X falls as CWD grows) is accepted.
    if slope >= 0.0:
        return np.nan

    root = -intercept / slope
    if np.isfinite(root) and root > 0.0:
        return root
    return np.nan


@nb.njit
def _event_window_one_year(cwd_y):
    """
    Locate the main CWD event inside a one-year CWD series of length T.

    Returns (ok, start, end, cwdmax); the window is the half-open range
    [start, end). ok is 1 on success, otherwise (0, 0, 0, 0.0) is returned.

    Window definition:
      - peak  = first index of the largest finite CWD value
      - start = index right after the last finite CWD <= ZERO_EPS at or
                before the peak (0 if there is none)
      - end   = first index after the peak whose finite CWD is < 0.9*cwdmax
                (T if there is none)
    A year is rejected when T < 6, when the peak is not > 0, or when the
    window is shorter than 6 steps.
    """
    n_steps = cwd_y.size
    if n_steps < 6:
        return 0, 0, 0, 0.0

    # Largest finite value and where it first occurs.
    peak_idx = 0
    peak_val = -1.0
    for i in range(n_steps):
        val = cwd_y[i]
        if np.isfinite(val) and val > peak_val:
            peak_val = val
            peak_idx = i
    if not np.isfinite(peak_val) or peak_val <= 0.0:
        return 0, 0, 0, 0.0

    # Walk backwards from the peak to the most recent (near-)zero CWD.
    first = 0
    i = peak_idx
    while i >= 0:
        val = cwd_y[i]
        if np.isfinite(val) and val <= ZERO_EPS:
            first = i + 1
            break
        i -= 1

    # Walk forwards from the peak until CWD falls below 90% of the peak.
    cutoff = 0.9 * peak_val
    stop = n_steps
    i = peak_idx + 1
    while i < n_steps:
        val = cwd_y[i]
        if np.isfinite(val) and val < cutoff:
            stop = i
            break
        i += 1

    if stop - first < 6:
        return 0, 0, 0, 0.0

    return 1, first, stop, peak_val


@nb.njit
def _s0_pooled_one_pixel(cwd_ts, x_ts, year_slices_arr):
    """
    Estimate S0 for ONE pixel from CWD events pooled across all years.

    Method:
      - For each year: pick the max-CWD event window [start, end) using CWD only
      - Pool all (CWD, X) points from all selected yearly events
      - Split [0, global_max] into N_BINS equal-width bins, where global_max
        is the largest yearly CWD peak among the selected events
      - For each non-empty bin: take the Q quantile of X (lower order
        statistic, no interpolation)
      - Regress the per-bin quantile on the bin-center CWD => S0 = -a/b

    Parameters:
      cwd_ts          : 1-D CWD time series of the pixel
      x_ts            : 1-D series of the response variable, indexed like cwd_ts
      year_slices_arr : (Ny, 3) integer array of (year, start, end) rows;
                        columns 1 and 2 give the half-open index range of
                        each year inside cwd_ts / x_ts

    Returns:
      S0 as a float, or NaN when no year yields a usable event, when the
      pooled event length is below MIN_POINTS_TOTAL, when fewer than
      MIN_VALID_BINS bins hold data, or when the linear fit is rejected.
    """
    Ny = year_slices_arr.shape[0]
    T_all = cwd_ts.size

    # ---------------------------------------------------------
    # Pass 1: find global max CWD across ALL yearly max events
    # ---------------------------------------------------------
    global_max = -1.0
    total_points = 0

    for k in range(Ny):
        st = year_slices_arr[k, 1]
        ed = year_slices_arr[k, 2]
        ok, s, e, cwdmax = _event_window_one_year(cwd_ts[st:ed])
        if ok == 0:
            continue
        # update global max from this year max
        if cwdmax > global_max:
            global_max = cwdmax
        # Window length; non-finite samples inside the window are still
        # counted here and only filtered out in pass 2.
        total_points += (e - s)

    if (not np.isfinite(global_max)) or global_max <= 0.0:
        return np.nan
    if total_points < MIN_POINTS_TOTAL:
        return np.nan

    # ---------------------------------------------------------
    # Pass 2: bin pooled points (store X values per bin)
    #   bin range: [0, global_max], equal-width N_BINS
    # ---------------------------------------------------------
    # bin_vals[bin, idx] holds the X values of a bin. A bin can never hold
    # more than T_all values because the yearly windows are sub-ranges of
    # the series.
    bin_vals = np.empty((N_BINS, T_all), dtype=np.float32)
    bin_cnt  = np.zeros(N_BINS, dtype=np.int32)

    for k in range(Ny):
        st = year_slices_arr[k, 1]
        ed = year_slices_arr[k, 2]
        cwd_y = cwd_ts[st:ed]
        x_y   = x_ts[st:ed]

        ok, s, e, _ = _event_window_one_year(cwd_y)
        if ok == 0:
            continue

        for t in range(s, e):
            c = cwd_y[t]
            v = x_y[t]
            if not (np.isfinite(c) and np.isfinite(v)):
                continue
            if c < 0.0:
                continue

            # Bin k covers [k, k + 1) * global_max / N_BINS, which is the
            # partition the bin centers below are computed for. Scaling by
            # N_BINS - 1 instead would make the bins wider than the centers
            # assume and bias the fitted intercept.
            # c == global_max evaluates to N_BINS and is clamped into the
            # last bin.
            bi = int(np.floor((c / global_max) * N_BINS))
            if bi < 0:
                bi = 0
            elif bi >= N_BINS:
                bi = N_BINS - 1

            kk = bin_cnt[bi]
            bin_vals[bi, kk] = v
            bin_cnt[bi] = kk + 1

    # ---------------------------------------------------------
    # Quantile per bin -> regression arrays
    #   CWD coordinate for each bin = bin center (NOT the CWD of the chosen X point)
    # ---------------------------------------------------------
    x_fit = np.empty(N_BINS, dtype=np.float32)
    y_fit = np.empty(N_BINS, dtype=np.float32)

    valid_bins = 0
    for bi in range(N_BINS):
        cnt = bin_cnt[bi]
        x_fit[bi] = (bi + 0.5) / N_BINS * global_max  # bin center

        if cnt <= 0:
            # Empty bin: NaN makes the fit skip it.
            y_fit[bi] = np.nan
            continue

        # Sort the filled part of the row in place, then pick the order
        # statistic at floor(Q * (cnt - 1)).
        arr = bin_vals[bi, :cnt]
        arr.sort()
        qi = int(np.floor(Q * (cnt - 1)))
        y_fit[bi] = arr[qi]
        if np.isfinite(y_fit[bi]):
            valid_bins += 1

    if valid_bins < MIN_VALID_BINS:
        return np.nan

    return _linfit_s0(x_fit, y_fit)


@nb.njit(parallel=True)
def _sd_map_pooled(cwd3d, x3d, year_slices_arr):
    """
    Run the pooled per-pixel S0 estimate over a whole block.

    cwd3d, x3d: arrays indexed as [time, row, col]
    year_slices_arr: passed through unchanged to the per-pixel routine
    Returns a float32 (rows, cols) array; every cell is written by the
    per-pixel routine, which yields NaN when no estimate is possible.
    """
    n_rows = cwd3d.shape[1]
    n_cols = cwd3d.shape[2]
    result = np.full((n_rows, n_cols), np.nan, dtype=np.float32)

    # Rows are distributed over numba threads; columns run serially per row.
    for row in nb.prange(n_rows):
        for col in range(n_cols):
            cwd_series = cwd3d[:, row, col]
            x_series = x3d[:, row, col]
            result[row, col] = _s0_pooled_one_pixel(cwd_series, x_series, year_slices_arr)

    return result


def _sd_block_time_last(cwd_block, x_block, year_slices_arr):
    """Adapt a (..., time) block from apply_ufunc to the (time, Y, X) kernel."""
    # apply_ufunc moves the core dimension ("time") to the LAST axis, while
    # _sd_map_pooled indexes its inputs as [time, row, col]. moveaxis returns
    # a view, so no data is copied here.
    return _sd_map_pooled(
        np.moveaxis(cwd_block, -1, 0),
        np.moveaxis(x_block, -1, 0),
        year_slices_arr,
    )


def compute_Sd(cwd_da, x_da, out_name):
    """
    Build the lazy S0 map for one response variable.

    Parameters:
      cwd_da   : CWD DataArray with dims time, lat, lon
      x_da     : response DataArray aligned with cwd_da
      out_name : name given to the returned DataArray

    Returns:
      DataArray over the non-time dims (lat, lon) of dtype float32, carrying
      the lat/lon coordinates of cwd_da.

    Only "time" is declared as a core dimension. With dask="parallelized",
    xarray raises ValueError when a core dimension spans more than one chunk,
    so declaring lat/lon as core dims fails for inputs that are tiled along
    lat/lon. As loop dimensions, each lat/lon tile is processed on its own.
    The time axis must still be a single chunk.

    NOTE(review): _sd_map_pooled is compiled with parallel=True. If the dask
    scheduler calls it from several threads at once, confirm that numba's
    threading layer is thread-safe for concurrent use.
    """
    Sd = xr.apply_ufunc(
        _sd_block_time_last,
        cwd_da,
        x_da,
        kwargs={"year_slices_arr": year_slices_arr},
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[[]],
        dask="parallelized",
        vectorize=False,
        output_dtypes=["float32"],
    )
    Sd.name = out_name
    Sd = Sd.assign_coords(lat=cwd_da["lat"], lon=cwd_da["lon"])
    return Sd


# Lazy S0 maps for the two response variables.
SdEF  = compute_Sd(cwd, dEF,  "SdEF")
SdSIF = compute_Sd(cwd, dSIF, "SdSIF")


# -----------------------------
# Write netCDF-4 (NC4)
# -----------------------------
# Shared per-variable encoding: deflate level 4, float32 on disk,
# 200x200 storage chunks, NaN as the fill value.
enc2d = dict(
    zlib=True,
    complevel=4,
    dtype="float32",
    chunksizes=(200, 200),
    _FillValue=np.float32(np.nan),
)

# Each map goes to its own file; writing is what triggers the computation.
for _var_name, _field, _out_path in (
    ("SdEF", SdEF, OUT_SdEF),
    ("SdSIF", SdSIF, OUT_SdSIF),
):
    xr.Dataset({_var_name: _field}).to_netcdf(
        _out_path, engine="netcdf4", format="NETCDF4", encoding={_var_name: enc2d}
    )
