In [3]:
import xarray as xr
import numpy as np
import numba as nb

# ---------------------------
# 1) 以 chunk 方式懒加载（不要 .values）
# ---------------------------
data_path = f'/share/home/dq076/bedrock/data/'

ds  = xr.open_dataset(f'{data_path}diff/diff.nc', chunks={"time": -1, "lat": 200, "lon": 200})
ds2 = xr.open_dataset(f'{data_path}SC/SC.nc', chunks={"time": -1, "lat": 200, "lon": 200})
ds3 = xr.open_dataset(f'{data_path}Ssoil/Ssoil.nc', chunks={"lat": 200, "lon": 200})

# 变量（按你说的：diff.nc 是 ET-P；sc 已是0/1权重；ssoil是2D）
diff  = ds["diff"].astype("float32")
sc = ds2["SC"].astype("float32")
ssoil = ds3["Ssoil"].astype("float32")

x = (diff * sc).transpose("time", "lat", "lon")  # 保证 time 在第一维，利于 scan

# ---------------------------
# 2) 用 numba 写一个沿 time 的递推
#    输入: (time, y, x) -> 输出: sr(y,x)
# ---------------------------
@nb.njit(parallel=True)
def sr_from_x(x3d):
    T, Y, X = x3d.shape
    sr = np.zeros((Y, X), dtype=x3d.dtype)
    cwd = np.zeros((Y, X), dtype=x3d.dtype)

    for t in range(T):
        xt = x3d[t]
        # 逐像元更新
        for j in nb.prange(Y):
            for i in range(X):
                v = xt[j, i]
                if v >= 0:
                    cwd[j, i] = cwd[j, i] + v
                else:
                    cwd[j, i] = 0.0
                if cwd[j, i] > sr[j, i]:
                    sr[j, i] = cwd[j, i]
    return sr

# ---------------------------
# 3) apply_ufunc：让 numba 函数在每个 dask block 上跑
# ---------------------------
sr = xr.apply_ufunc(
    sr_from_x,
    x,
    input_core_dims=[["time", "lat", "lon"]],
    output_core_dims=[["lat", "lon"]],
    dask="parallelized",
    vectorize=False,
    output_dtypes=["float32"],
)

# sbedrock
sbedrock = xr.where(sr > ssoil, sr - ssoil, 0.0).astype("float32")

# ---------------------------
# 4) 输出（压缩 + 合理 chunk）
# ---------------------------
encoding_2d = {
    "zlib": True, "complevel": 4, "dtype": "float32",
    "chunksizes": (200, 200),
}

xr.Dataset({"Sr": sr}).to_netcdf(
    f"{data_path}Sr.nc",
    encoding={"Sr": encoding_2d}
)

xr.Dataset({"Sbedrock": sbedrock}).to_netcdf(
    f"{data_path}Sbedrock.nc",
    encoding={"Sbedrock": encoding_2d}
)


  ds  = xr.open_dataset(f'{data_path}diff/diff.nc', chunks={"time": -1, "lat": 200, "lon": 200})
  ds2 = xr.open_dataset(f'{data_path}SC/SC.nc', chunks={"time": -1, "lat": 200, "lon": 200})
  ds2 = xr.open_dataset(f'{data_path}SC/SC.nc', chunks={"time": -1, "lat": 200, "lon": 200})
  ds3 = xr.open_dataset(f'{data_path}Ssoil/Ssoil.nc', chunks={"lat": 200, "lon": 200})
  ds3 = xr.open_dataset(f'{data_path}Ssoil/Ssoil.nc', chunks={"lat": 200, "lon": 200})


ValueError: cannot reindex or align along dimension 'time' because the (pandas) index has duplicate values

In [4]:
print(ds["time"].to_index().has_duplicates, ds2["time"].to_index().has_duplicates)
# 或者更直观
t1 = ds["time"].to_index()
t2 = ds2["time"].to_index()
print("diff duplicates:", t1[t1.duplicated()].unique()[:10])
print("SC   duplicates:", t2[t2.duplicated()].unique()[:10])


False True
diff duplicates: DatetimeIndex([], dtype='datetime64[ns]', name='time', freq=None)
SC   duplicates: DatetimeIndex(['2003-12-29 00:00:00', '2004-12-28 12:00:00',
               '2005-12-29 00:00:00', '2006-12-29 00:00:00',
               '2007-12-29 00:00:00', '2008-12-28 12:00:00',
               '2009-12-29 00:00:00', '2010-12-29 00:00:00',
               '2011-12-29 00:00:00', '2012-12-28 12:00:00'],
              dtype='datetime64[ns]', name='time', freq=None)


In [None]:
import xarray as xr
import numpy as np
import numba as nb

data_path = "/share/home/dq076/bedrock/data/"

ds  = xr.open_dataset(f"{data_path}diff/diff.nc", chunks={"time": -1, "lat": 200, "lon": 200})
ds2 = xr.open_dataset(f"{data_path}SC/SC.nc",     chunks={"time": -1, "lat": 200, "lon": 200})
ds3 = xr.open_dataset(f"{data_path}Ssoil/Ssoil.nc", chunks={"lat": 200, "lon": 200})

diff  = ds["diff"].astype("float32").transpose("time", "lat", "lon")
sc    = ds2["SC"].astype("float32").transpose("time", "lat", "lon")
ssoil = ds3["Ssoil"].astype("float32").transpose("lat", "lon")

# 按位置相乘（避免 time 重复导致 align）
x = xr.DataArray(
    diff.data * sc.data,
    dims=("time", "lat", "lon"),
    coords={"time": diff["time"], "lat": diff["lat"], "lon": diff["lon"]},
    name="x",
)

# 关键：写一个只沿最后一个维度 time 扫描的函数
# xarray 会把 time 放到 core dim 上，其它维度（lat,lon chunk）作为“批量维度”喂进来
@nb.njit
def sr_from_x_1d(x1d):
    """x1d: (T,) -> sr scalar"""
    sr = np.float32(0.0)
    cwd = np.float32(0.0)
    for t in range(x1d.shape[0]):
        v = x1d[t]
        if v >= 0:
            cwd = cwd + v
        else:
            cwd = 0.0
        if cwd > sr:
            sr = cwd
    return sr

# 将 sr_from_x_1d 向量化到 (lat,lon) 上：输出是 (lat,lon)
sr = xr.apply_ufunc(
    sr_from_x_1d,
    x,
    input_core_dims=[["time"]],     # ✅ 只有 time 是 core dim
    output_core_dims=[[]],          # ✅ 输出是标量（对每个像元）
    vectorize=True,                 # ✅ 对 lat/lon 自动向量化
    dask="parallelized",
    output_dtypes=["float32"],
    dask_gufunc_kwargs={"allow_rechunk": False},  # 不需要重分块
)

# 恢复成 (lat,lon) 的 DataArray
sr = sr.rename("Sr")

# sbedrock
sbedrock = xr.where(sr > ssoil, sr - ssoil, 0.0).astype("float32").rename("Sbedrock")

encoding_2d = {"zlib": True, "complevel": 4, "dtype": "float32", "chunksizes": (200, 200)}

xr.Dataset({"Sr": sr}).to_netcdf(
    f"{data_path}Sr.nc",
    engine="netcdf4", format="NETCDF4",
    encoding={"Sr": encoding_2d},
)

xr.Dataset({"Sbedrock": sbedrock}).to_netcdf(
    f"{data_path}Sbedrock.nc",
    engine="netcdf4", format="NETCDF4",
    encoding={"Sbedrock": encoding_2d},
)


  ds  = xr.open_dataset(f"{data_path}diff/diff.nc", chunks={"time": -1, "lat": 200, "lon": 200})
  ds2 = xr.open_dataset(f"{data_path}SC/SC.nc",     chunks={"time": -1, "lat": 200, "lon": 200})
  ds2 = xr.open_dataset(f"{data_path}SC/SC.nc",     chunks={"time": -1, "lat": 200, "lon": 200})
  ds3 = xr.open_dataset(f"{data_path}Ssoil/Ssoil.nc", chunks={"lat": 200, "lon": 200})
  ds3 = xr.open_dataset(f"{data_path}Ssoil/Ssoil.nc", chunks={"lat": 200, "lon": 200})


KeyboardInterrupt: 

: 

In [None]:
import xarray as xr
import numpy as np
import numba as nb

data_path = f'/share/home/dq076/bedrock/data/'

ds  = xr.open_dataset(f'{data_path}diff/diff.nc', chunks={"time": -1, "lat": 200, "lon": 200})
ds2 = xr.open_dataset(f'{data_path}SC/sc.nc', chunks={"time": -1, "lat": 200, "lon": 200})

diff  = ds["diff"].astype("float32")       # 这里假设变量名是 et = ET-P
sc = ds2["sc"].astype("float32")   # 0/1 mask

# 递推输入：x_t
x = (diff * sc).transpose("time", "lat", "lon")  # time放前面更快

# 2) numba：输出整个 time 序列的 cwd (T,Y,X)
@nb.njit(parallel=True)
def cwd_series_from_x(x3d):
    T, Y, X = x3d.shape
    out = np.empty((T, Y, X), dtype=x3d.dtype)

    cwd = np.zeros((Y, X), dtype=x3d.dtype)
    for t in range(T):
        xt = x3d[t]
        for j in nb.prange(Y):
            for i in range(X):
                v = xt[j, i]
                if v >= 0:
                    cwd[j, i] = cwd[j, i] + v
                else:
                    cwd[j, i] = 0.0
                out[t, j, i] = cwd[j, i]
    return out

# 3) apply_ufunc：让上面的 numba 在每个 dask block 上跑
cwd = xr.apply_ufunc(
    cwd_series_from_x,
    x,
    input_core_dims=[["time", "lat", "lon"]],
    output_core_dims=[["time", "lat", "lon"]],
    dask="parallelized",
    vectorize=False,
    output_dtypes=["float32"],
)

# 恢复原坐标（time/lat/lon）
cwd = cwd.assign_coords(time=ds["time"], lat=ds["lat"], lon=ds["lon"])
cwd.name = "CWD"
cwd.attrs.update({"long_name": "Cumulative Water Deficit", "units": "mm"})

# 4) 写出：强烈建议 chunk + 压缩，否则很慢很大
# time 维 chunk 一般设 1~10 比较好（方便按时刻访问），lat/lon 按块一致
encoding = {
    "CWD": {
        "zlib": True,
        "complevel": 4,
        "dtype": "float32",
        "chunksizes": (1, 200, 200),   # (time,lat,lon)
        "_FillValue": np.float32(0.0),
    }
}

xr.Dataset({"CWD": cwd}).to_netcdf(f"{data_path}CWD/CWD.nc", encoding=encoding)
print("CWD.nc has storage")
xr.Dataset({"CWD": cwd}).to_zarr(f"{data_path}CWD/CWD.zarr", mode="w")
print("CWD.zarr has storage")