In [2]:
import numpy as np
import xarray as xr
from pathlib import Path

# ==================== Configuration ====================
# Directory that holds every input NetCDF file.
ROOT = Path("/home/xuxh22/stu01/bedrock/data").resolve()

# Input files, one per quantity.
PATH_CWD  = ROOT / "CWD.nc"
PATH_EF   = ROOT / "dEF" / "dEF.nc"
PATH_SIFY = ROOT / "dSIF" / "dSIF.nc"

# Name of the variable to read from each file.
VAR_CWD, VAR_EF, VAR_SIFY = "CWD", "dEF", "dSIF"

# Paper parameters
# NOTE(review): N_BINS and Q are not used in the cells visible here;
# presumably a bin count and a quantile level -- confirm against the paper.
N_BINS    = 50
Q         = 0.9
TRIM_FRAC = 0.9     # fraction of a year's CWD maximum that marks the end of an event

# Quality control
# NOTE(review): MIN_POINTS_TOTAL and MIN_VALID_BINS are not used in the cells
# visible here; verify their meaning where they are consumed.
MIN_POINTS_TOTAL = 60
MIN_VALID_BINS   = 15
MIN_SEG_PTS      = 6        # minimum number of samples in an accepted event window
ZERO_EPS         = 1.0e-6   # CWD values at or below this count as zero

# ==================== Load Data ====================
# Open the three input files in the order CWD, dEF, dSIF.
ds_cwd, ds_ef, ds_sy = (
    xr.open_dataset(nc_path) for nc_path in (PATH_CWD, PATH_EF, PATH_SIFY)
)

def pick_var(ds, name):
    """Return ``ds[name]``; when ``name`` is None, return the first entry of ``ds.data_vars``.

    Parameters
    ----------
    ds : mapping-like
        Object indexable by variable name; ``data_vars`` is only read when
        ``name`` is None.
    name : hashable or None
        Key of the variable to fetch, or None to take the first data variable.
    """
    key = list(ds.data_vars)[0] if name is None else name
    return ds[key]

# Extract the three variables and put every one in (time, lat, lon) order.
cwd, ef, sify = (
    pick_var(dataset, var_name).transpose("time", "lat", "lon")
    for dataset, var_name in (
        (ds_cwd, VAR_CWD),
        (ds_ef, VAR_EF),
        (ds_sy, VAR_SIFY),
    )
)

# Align coordinates (force consistency): CWD and dSIF take over the lat/lon
# coordinates of dEF, then the exact join raises if any index still differs.
# NOTE(review): the overwrite is positional and does not compare the original
# lat/lon values -- verify that all three grids share the same ordering.
cwd = cwd.assign_coords(lon=ef["lon"], lat=ef["lat"])
sify = sify.assign_coords(lon=ef["lon"], lat=ef["lat"])
cwd, ef, sify = xr.align(cwd, ef, sify, join="exact")

# ==================== Verify ====================
# Banner, then the shape of every array and the extent of the shared coordinates.
print("=" * 60, "Data loaded successfully", "=" * 60, sep="\n")
print(f"CWD shape:  {cwd.shape}")
print(f"EF shape:   {ef.shape}")
print(f"SIFY shape: {sify.shape}")
print(f"Time range: {cwd['time'].values[0]} -> {cwd['time'].values[-1]}")
print(f"Lat range:  {float(cwd['lat'].min()):.2f} -> {float(cwd['lat'].max()):.2f}")
print(f"Lon range:  {float(cwd['lon'].min()):.2f} -> {float(cwd['lon'].max()):.2f}")
print("=" * 60)

Data loaded successfully
CWD shape:  (828, 3600, 7200)
EF shape:   (828, 3600, 7200)
SIFY shape: (828, 3600, 7200)
Time range: 2003-01-01T00:00:00.000000000 -> 2020-12-26T00:00:00.000000000
Lat range:  -89.97 -> 89.98
Lon range:  -179.97 -> 179.98


In [5]:
import numba as nb

# ==================== Build Year Slices ====================
# Calendar year of every time step, and the distinct years (np.unique sorts them).
years = cwd["time"].dt.year.values.astype(np.int32)
unique_years = np.unique(years)

# One (year, first_index, one_past_last_index) triple per year.
# NOTE(review): only the first and last matching index are kept, so this
# assumes the time steps of a year are contiguous along the time axis.
year_slices = []
for yr in unique_years:
    hits = np.where(years == yr)[0]
    if len(hits) == 0:
        continue
    year_slices.append((yr, int(hits[0]), int(hits[-1]) + 1))

year_slices_arr = np.array(year_slices, dtype=np.int32)  # shape: (n_years, 3)

print(f"Found {len(year_slices)} years: {unique_years[0]} - {unique_years[-1]}")
print(year_slices_arr)

# ==================== Event Window Detection (Optimized) ====================
@nb.njit
def find_all_events_one_pixel(cwd_ts, year_slices_arr, zero_eps, trim_frac, min_seg_pts):
    """
    Identify one drought event window per year for a single pixel.

    For every row of ``year_slices_arr`` the matching slice of ``cwd_ts`` is
    scanned for its largest finite value (the peak, ``cwdmax``).  The window
    starts right after the last finite sample before the peak that is
    ``<= zero_eps`` (or at the first sample of the year when there is none)
    and ends at the first finite sample after the peak that is below
    ``trim_frac * cwdmax`` (or at the end of the year when there is none).
    Non-finite samples never start or end a window.

    Parameters:
    -----------
    cwd_ts : array (T,)
        Full CWD time series for one pixel
    year_slices_arr : array (n_years, 3)
        [(year, start_idx, end_idx), ...]; end_idx is exclusive
    zero_eps : float
        Threshold to consider CWD as zero
    trim_frac : float
        Fraction of CWDmax to trim event
    min_seg_pts : int
        Minimum points required in event window

    Returns:
    --------
    events : float32 array (n_years, 4)
        For each year: [valid, start, end, cwdmax].  ``start`` and ``end`` are
        indices relative to the beginning of that year's slice and ``end`` is
        exclusive.  Rows of years without an accepted event are all zeros.
    """
    n_years = year_slices_arr.shape[0]
    # Every row starts as "no valid event"; only accepted windows overwrite it.
    events = np.zeros((n_years, 4), dtype=np.float32)

    for k in range(n_years):
        year_start = year_slices_arr[k, 1]
        year_end = year_slices_arr[k, 2]
        cwd_year = cwd_ts[year_start:year_end]
        T = cwd_year.size

        # An empty slice has no peak, and argmax would raise on it.
        if T == 0:
            continue

        # Step 1: locate the peak; non-finite samples are mapped to -inf so
        # they can never win the argmax.
        cwd_clean = np.where(np.isfinite(cwd_year), cwd_year, -np.inf)
        peak = np.argmax(cwd_clean)
        cwdmax = cwd_clean[peak]

        # A peak that itself counts as zero (<= zero_eps) is not an event.
        # Without this check the start search below would flag the peak as a
        # "zero" sample and report a window that begins after its own peak.
        if cwdmax <= 0.0 or cwdmax <= zero_eps or not np.isfinite(cwdmax):
            continue

        # Step 2: event start = one past the last zero sample before the peak.
        before_peak = cwd_year[:peak]
        is_zero = np.where(np.isfinite(before_peak), before_peak <= zero_eps, False)
        zero_indices = np.where(is_zero)[0]

        if zero_indices.size > 0:
            start = zero_indices[-1] + 1
        else:
            start = 0

        # Step 3: event end = first sample after the peak below the trim level.
        threshold = trim_frac * cwdmax
        after_peak = cwd_year[peak + 1:]
        is_below_threshold = np.where(
            np.isfinite(after_peak),
            after_peak < threshold,
            False
        )
        below_indices = np.where(is_below_threshold)[0]

        if below_indices.size > 0:
            end = peak + 1 + below_indices[0]
        else:
            end = T

        # Step 4: keep the window only when it is long enough.
        if end - start >= min_seg_pts:
            events[k, 0] = 1.0
            events[k, 1] = start
            events[k, 2] = end
            events[k, 3] = cwdmax

    return events


# ==================== Test on Sample Pixel ====================
def test_event_detection(i_lat=1800, i_lon=3600):
    """Run event detection on one pixel and print one summary line per year.

    Reads the module-level ``cwd``, ``year_slices_arr``, ``ZERO_EPS``,
    ``TRIM_FRAC`` and ``MIN_SEG_PTS``.

    Parameters
    ----------
    i_lat, i_lon : int
        Positions along the second and third axes of ``cwd``.  These are
        array indices, not coordinates in degrees.
    """
    cwd_pixel = cwd[:, i_lat, i_lon].values

    # Process all years at once
    events = find_all_events_one_pixel(
        cwd_pixel, year_slices_arr, ZERO_EPS, TRIM_FRAC, MIN_SEG_PTS
    )

    print(f"\nTesting pixel (lat={i_lat}, lon={i_lon}):")
    print("=" * 60)

    for k in range(len(year_slices_arr)):
        year_id = year_slices_arr[k, 0]
        valid, start, end, cwdmax = events[k]

        if valid == 1:
            # start/end are sample indices within the year.  The year-slice
            # table printed by this notebook has 46 samples per year, so the
            # window length is a count of time steps, not of days.
            window_len = int(end - start)
            print(f"Year {year_id}: CWDmax={cwdmax:6.1f} mm, "
                  f"window=[{int(start):3d}, {int(end):3d}), length={window_len:3d} time steps")
        else:
            print(f"Year {year_id}: No valid event")

    print("=" * 60)

# Run the test on the default sample pixel (i_lat=1800, i_lon=3600)
test_event_detection()

Found 18 years: 2003 - 2020
[[2003    0   46]
 [2004   46   92]
 [2005   92  138]
 [2006  138  184]
 [2007  184  230]
 [2008  230  276]
 [2009  276  322]
 [2010  322  368]
 [2011  368  414]
 [2012  414  460]
 [2013  460  506]
 [2014  506  552]
 [2015  552  598]
 [2016  598  644]
 [2017  644  690]
 [2018  690  736]
 [2019  736  782]
 [2020  782  828]]

Testing pixel (lat=1800, lon=3600):
Year 2003: No valid event
Year 2004: No valid event
Year 2005: No valid event
Year 2006: No valid event
Year 2007: No valid event
Year 2008: No valid event
Year 2009: No valid event
Year 2010: No valid event
Year 2011: No valid event
Year 2012: No valid event
Year 2013: No valid event
Year 2014: No valid event
Year 2015: No valid event
Year 2016: No valid event
Year 2017: No valid event
Year 2018: No valid event
Year 2019: No valid event
Year 2020: No valid event


In [6]:
import numba as nb
import numpy as np
import time  # 1. 导入 time 库

# Plain Python function (interpreted baseline for the timing comparison)
def slow_sum(arr):
    """Add up the items of ``arr`` one by one, starting from 0.0, and return the total."""
    acc = 0.0
    for value in arr:
        acc = acc + value
    return acc

# Numba-accelerated function (JIT): same loop as slow_sum, compiled by nb.njit
@nb.njit
def fast_sum(arr):
    """Add up the items of ``arr`` one by one, starting from 0.0, and return the total."""
    acc = 0.0
    for value in arr:
        acc = acc + value
    return acc

# Prepare the data (ten million random numbers)
print("正在生成数据...")
data = np.random.random(10000000)
print("数据生成完毕，开始测试。\n")

# Durations are measured with time.perf_counter(): it is the clock intended
# for timing short intervals, whereas time.time() follows the system wall
# clock and can jump if that clock is adjusted.

# ==========================================
# 1. Time the plain Python function (slow)
# ==========================================
start_time = time.perf_counter()  # start of the measurement
slow_sum(data)
end_time = time.perf_counter()    # end of the measurement

print(f"普通 Python 函数耗时: {end_time - start_time:.4f} 秒")


# ==========================================
# 2. Time the Numba function (first call - includes compilation time)
# ==========================================
start_time = time.perf_counter()
fast_sum(data)
end_time = time.perf_counter()

print(f"Numba (第1次运行, 含编译): {end_time - start_time:.4f} 秒")


# ==========================================
# 3. Time the Numba function (second call - run time only)
# ==========================================
start_time = time.perf_counter()
fast_sum(data)
end_time = time.perf_counter()

print(f"Numba (第2次运行, 已加速): {end_time - start_time:.4f} 秒")

正在生成数据...
数据生成完毕，开始测试。

普通 Python 函数耗时: 0.7461 秒
Numba (第1次运行, 含编译): 1.1517 秒
Numba (第2次运行, 已加速): 0.0118 秒


In [2]:
import numpy as np
arr = np.array([10, 2, 0, 5, 0, 8])

# Find every position whose value equals 0
result = np.where(arr == 0)      # returns the tuple (array([2, 4]),)
indices = np.where(arr == 0)[0]  # returns array([2, 4]), the first element of that tuple

print(result) # prints (array([2, 4]),) -- the whole tuple, not the bare index array
print(indices) # prints [2 4]

(array([2, 4]),)
[2 4]


In [5]:
arr = np.zeros((50, 10)) # 50 rows, 10 columns
print(arr.shape)    # (50, 10)
print(arr.shape[0]) # 50 (number of rows)
print(np.nan) # prints nan; NOTE(review): the column count (10) would be arr.shape[1] -- confirm which was intended

(50, 10)
50
nan
