In [None]:
import numpy as np
import xarray as xr
import pandas as pd
from pathlib import Path
from netCDF4 import Dataset, date2num

# Resolution tag; not referenced anywhere else in the visible part of this file.
resolution = "p05"
# Base directory that holds every input and receives every output below.
ROOT_DIR = Path("/home/xuxh22/stu01/bedrock/data").resolve()

# =========================
# INPUT
# =========================
# The three datasets stay open for the remainder of the script: `diff` and
# `sc` are read one time step at a time in the main loop, and `ds` also
# supplies the lat/lon/time coordinates for all outputs.
ds  = xr.open_dataset(ROOT_DIR / "diff/diff.nc")
ds2 = xr.open_dataset(ROOT_DIR / "SC/SC.nc")
ds3 = xr.open_dataset(ROOT_DIR / "Ssoil/Ssoil.nc")

# NOTE(review): nothing below verifies that SC and Ssoil share the grid (and,
# for SC, the time axis) of `diff`; the shape comments are the author's.
diff  = ds["diff"]    # (time, lat, lon) = (828, 3600, 7200)
sc    = ds2["SC"]     # (time, lat, lon)
ssoil = ds3["Ssoil"]  # (lat, lon)

# Dimension sizes are taken from `diff` and reused for every output array.
time_len = diff.sizes["time"]
nlat     = diff.sizes["lat"]
nlon     = diff.sizes["lon"]

# time decoding
# `years` holds the calendar year of each time step; the main loop uses it to
# detect year changes.
time_pd = pd.to_datetime(ds["time"].values)
years = time_pd.year.values
uniq_years = np.unique(years)

print("Time range:", time_pd[0], "to", time_pd[-1], "len=", time_len)
print("Years:", uniq_years[0], "->", uniq_years[-1], "N=", len(uniq_years))
print("Shape:", (time_len, nlat, nlon))

# =========================
# 2D static arrays
# =========================
# Loaded fully into memory once, as float32.
ssoil2d = ssoil.values.astype(np.float32)

# =========================
# OUTPUT: CWD (year-reset)
# =========================
out_cwd = ROOT_DIR / "CWD_yearreset.nc"

# Directory for the per-year Dr/Dbedrock output files.
# It is ROOT_DIR itself, so the yearly files land next to CWD_yearreset.nc.
OUT_YDIR = ROOT_DIR
OUT_YDIR.mkdir(parents=True, exist_ok=True)

# time encoding for netcdf
# The decoded timestamps are converted to numeric offsets in `time_unit`
# so they can be stored in a plain float64 netCDF4 variable.
time_dt   = time_pd.to_pydatetime()
time_unit = "days since 1970-01-01 00:00:00"
time_cal  = "proleptic_gregorian"
time_num  = date2num(time_dt, units=time_unit, calendar=time_cal)

# choose output chunk sizes (tune)
# One time step per chunk; the spatial tile size is capped at the grid size.
chunk_t, chunk_y, chunk_x = 1, 600, 600
chunk_y = min(chunk_y, nlat)
chunk_x = min(chunk_x, nlon)

# =========================
# State arrays (reset per year)
# =========================
cwd_i = np.zeros((nlat, nlon), dtype=np.float32)  # running CWD within current year
sr_y  = np.zeros((nlat, nlon), dtype=np.float32)  # max CWD within current year (Dr)

def finalize_and_write_year(year, Dr2d, Ssoil2d):
    """
    Write the two per-year NetCDF products for one completed year.

    Parameters
    ----------
    year : int
        Year used in the output file names.
    Dr2d : (lat, lon) array
        Yearly maximum CWD for that year.
    Ssoil2d : (lat, lon) array
        Field compared against and subtracted from Dr2d.

    Files written into OUT_YDIR (lat/lon coordinates come from the
    module-level dataset `ds`):
      Dr_{year}.nc        variable "Dr"       = Dr2d
      Dbedrock_{year}.nc  variable "Dbedrock" = Dr2d - Ssoil2d where
                          Dr2d > Ssoil2d, otherwise 0

    Cells where the comparison is False, including any cell where either
    input is NaN, receive 0.0 in Dbedrock.
    """
    dims = ("lat", "lon")
    grid = {axis: ds[axis].values for axis in dims}

    # Part of the yearly maximum that exceeds Ssoil2d; zero elsewhere.
    exceeds = Dr2d > Ssoil2d
    bedrock = np.where(exceeds, Dr2d - Ssoil2d, 0.0).astype(np.float32)

    # One file per product per year, written in the order Dr, then Dbedrock.
    written = []
    for var_name, field in (("Dr", Dr2d), ("Dbedrock", bedrock)):
        target = OUT_YDIR / f"{var_name}_{year}.nc"
        xr.Dataset({var_name: (dims, field)}, coords=grid).to_netcdf(target)
        written.append(target.name)

    print(f"[YEAR OUT] {year} -> {written[0]}, {written[1]}")

# =========================
# STREAM WRITE CWD_yearreset.nc
# =========================
# Stream the running CWD field into CWD_yearreset.nc one time step at a time.
# Alongside, track the per-year maximum (sr_y) and, whenever the year changes
# (and once more after the loop), write that year's Dr/Dbedrock files.
with Dataset(out_cwd, "w", format="NETCDF4") as nc:
    nc.createDimension("time", time_len)
    nc.createDimension("lat",  nlat)
    nc.createDimension("lon",  nlon)

    # Time axis stored as numeric offsets with units/calendar attributes.
    vtime = nc.createVariable("time", "f8", ("time",))
    vtime.units = time_unit
    vtime.calendar = time_cal
    vtime[:] = time_num

    # Coordinates are copied from the `diff` input dataset, cast to float32.
    vlon = nc.createVariable("lon", "f4", ("lon",))
    vlon.units = "degrees_east"
    vlon[:] = ds["lon"].values.astype(np.float32)

    vlat = nc.createVariable("lat", "f4", ("lat",))
    vlat.units = "degrees_north"
    vlat[:] = ds["lat"].values.astype(np.float32)

    # Compressed (zlib level 3 with shuffle), chunked float32 variable with
    # NaN declared as the fill value.
    vCWD = nc.createVariable(
        "CWD", "f4", ("time", "lat", "lon"),
        zlib=True, complevel=3, shuffle=True,
        chunksizes=(chunk_t, chunk_y, chunk_x),
        fill_value=np.float32(np.nan)
    )
    vCWD.long_name = "Cumulative Water Deficit (8-day running state, reset to 0 at each year boundary)"

    # Initialize the current year from the first time step.
    curr_year = int(years[0])

    # Process one time step at a time.
    # NOTE(review): year changes are detected by comparing consecutive entries
    # of `years`, so this presumably relies on the time axis being in
    # chronological order; a year that reappears later would overwrite its
    # earlier Dr/Dbedrock files. Confirm against the input data.
    for i in range(time_len):
        y = int(years[i])

        # Entering a new year: first write out the previous year's
        # Dr/Dbedrock, then reset the state.
        if y != curr_year:
            # finalize previous year
            finalize_and_write_year(curr_year, sr_y, ssoil2d)

            # reset for new year
            cwd_i.fill(0.0)
            sr_y.fill(0.0)
            curr_year = y

        # Progress message every 10th step.
        if i % 10 == 0:
            print(f"Processing {i}/{time_len-1}  (year={y})")

        # Read the current time step.
        diff_i = diff.isel(time=i).values.astype(np.float32)
        sc_i   = sc.isel(time=i).values.astype(np.float32)

        # Within-year accumulation (same logic as the original version, except
        # that cwd_i is reset at year boundaries): add delta where it is
        # non-negative, otherwise drop the running value back to zero.
        # NOTE(review): cells where delta_tn_i is NaN fail the `>= 0` test and
        # are therefore written as 0.0, not as the NaN fill value declared on
        # vCWD. Confirm this is the intended treatment of missing data.
        delta_tn_i = diff_i * sc_i
        cwd_i = np.where(delta_tn_i >= 0, cwd_i + delta_tn_i, 0.0).astype(np.float32)

        # Maximum for the current year so far, Dr (the within-year max CWD).
        sr_y = np.maximum(sr_y, cwd_i)

        # Write the CWD for this time step (every time step is written).
        vCWD[i, :, :] = cwd_i

    # After the loop: the last year never reaches the year-change branch,
    # so write its Dr/Dbedrock here.
    finalize_and_write_year(curr_year, sr_y, ssoil2d)

print("[OK] Wrote CWD (year-reset):", out_cwd)
print("[OK] Yearly Dr/Dbedrock in:", OUT_YDIR)
