In [1]:
import os
import warnings

import cf_xarray as cfxr
import pandas as pd
import xarray as xr

def load_from_nc(load_path: str, var_name: str = 'data') -> xr.DataArray:
    """
    Restore a DataArray from a NetCDF file, decoding a compressed 'layer'
    multi-index when one is present.

    Parameters
    ----------
    load_path : str
        Path of the NetCDF file; passed straight to ``xr.open_dataset``.
    var_name : str, default 'data'
        Name of the data variable to return. If the dataset has no variable
        with this name, the first data variable is returned instead.

    Returns
    -------
    xr.DataArray
        The selected data variable. The dataset is not closed by this function.

    Raises
    ------
    IndexError
        If the file contains no data variables at all.

    Warns
    -----
    RuntimeWarning
        If decoding the multi-index raises ``ValueError`` or ``KeyError``.
        The undecoded dataset is used in that case.
    """
    ds = xr.open_dataset(load_path)

    # Best-effort decoding: only attempted when some coordinate carries a
    # 'compress' attribute. A failure is reported instead of being swallowed,
    # so the caller knows the returned array still has the encoded layout.
    try:
        has_compress = any('compress' in ds[name].attrs for name in ds.coords)
        if has_compress:
            ds = cfxr.decode_compress_to_multi_index(ds, 'layer')
    except (ValueError, KeyError) as exc:
        warnings.warn(
            f"Could not decode 'layer' multi-index in {load_path!r}; "
            f"returning undecoded data ({type(exc).__name__}: {exc})",
            RuntimeWarning,
            stacklevel=2,
        )

    # Fall back to the first data variable when the requested one is absent.
    if var_name not in ds.data_vars:
        available = list(ds.data_vars)
        if not available:
            # Same exception type as the original `list(...)[0]`, clearer text.
            raise IndexError(f"No data variables found in {load_path!r}")
        var_name = available[0]

    return ds[var_name]

In [2]:
# Absolute path of the NetCDF file inspected by the rest of this notebook.
load_path = r"F:\Users\s222552331\Work\LUTO2_XH\luto-2.0\output\20260223_Paper2_Results_test\carbon_price\0_base_data\Run_25_GHG_off_BIO_off_CUT_10\2050\xr_biodiversity_GBF2_priority_ag_2050.nc"

# Fail fast with an explicit message before handing the path to xarray.
# NOTE(review): the recorded run of this cell ended in xarray's "did not find a
# match in any of xarray's currently installed IO backends" error; a zero-byte
# or truncated file is one plausible cause -- confirm against the actual file.
if not os.path.isfile(load_path):
    raise FileNotFoundError(f"NetCDF file not found: {load_path}")
if os.path.getsize(load_path) == 0:
    raise ValueError(f"NetCDF file is empty (0 bytes): {load_path}")

xr_arr = xr.open_dataset(load_path)
print(xr_arr)

ValueError: did not find a match in any of xarray's currently installed IO backends ['netcdf4', 'h5netcdf', 'scipy', 'rasterio']. Consider explicitly selecting one of the installed engines via the ``engine`` parameter, or installing additional IO dependencies, see:
https://docs.xarray.dev/en/stable/getting-started-guide/installing.html
https://docs.xarray.dev/en/stable/user-guide/io.html

In [None]:
# Sum the dataset opened in the previous cell; no dimension is given, so the
# reduction runs over all dimensions. Needs `xr_arr` to exist, i.e. the
# previous cell must have opened the file successfully.
xr_arr.sum()

In [None]:
# Display the raw values of the 'From-land-use' coordinate of `xr_arr`.
xr_arr.coords['From-land-use'].values

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
from typing import Union, Iterable, Optional


def _ensure_layer_multiindex_strict(
    da: xr.DataArray,
    layer_dim: str = "layer",
    level_order: Optional[Iterable[str]] = None,
) -> xr.DataArray:
    """
    Strictly ensure that the `layer_dim` index of `da` is a MultiIndex.

    Rules:
    - If `da` has no `layer_dim` dimension, it is returned unchanged.
    - If `layer_dim` is already a MultiIndex, `da` is returned unchanged.
    - Otherwise, coordinates whose dims are exactly `(layer_dim,)` (other than
      `layer_dim` itself) are treated as level variables and the index is
      rebuilt with `set_index({layer_dim: [...]})`.
    - If no such level variables exist, a ValueError is raised because the
      MultiIndex cannot be reconstructed.

    Parameters
    ----------
    da : xr.DataArray
        Array to check.
    layer_dim : str, default "layer"
        Name of the stacked dimension.
    level_order : iterable of str, optional
        Preferred order of the levels. Names listed here come first (in the
        given order); remaining level variables keep their coordinate order.

    Returns
    -------
    xr.DataArray
        `da` itself, or the result of `set_index` on it.

    Raises
    ------
    ValueError
        If `layer_dim` is not a MultiIndex and no level variables are found.
    """
    if layer_dim not in da.dims:
        return da

    idx = da.indexes.get(layer_dim, None)
    if isinstance(idx, pd.MultiIndex):
        return da

    layer_level_vars = [
        c for c in da.coords
        if c != layer_dim and da.coords[c].dims == (layer_dim,)
    ]
    if not layer_level_vars:
        # Report what coordinates actually exist: the candidate list is empty
        # by definition on this path, so printing it would say nothing.
        coord_dims = {str(c): da.coords[c].dims for c in da.coords}
        raise ValueError(
            f"'{layer_dim}' is not a MultiIndex and no layer-level coords exist to rebuild it. "
            f"Expected coords with dims=({layer_dim!r},); available coords and their dims: {coord_dims}"
        )

    if level_order is not None:
        level_order = list(level_order)
        layer_level_vars = [c for c in level_order if c in layer_level_vars] + \
                           [c for c in layer_level_vars if c not in level_order]

    # NOTE(review): with a single level variable, set_index may produce a plain
    # index rather than a MultiIndex -- confirm against the xarray version used.
    return da.set_index({layer_dim: layer_level_vars})


def filter_all_from_dims(
    obj: Union[xr.Dataset, xr.DataArray],
    layer_dim: str = "layer",
    strict_layer_multiindex: bool = True,
    layer_level_order: Optional[Iterable[str]] = None,
) -> Union[xr.Dataset, xr.DataArray]:
    """
    Drop every entry labelled 'ALL', both on ordinary dimensions and on the
    level coordinates of the stacked `layer_dim` dimension.

    Two passes are applied to each DataArray:
    1) For every dimension with a string/object coordinate, positions whose
       coordinate value equals 'ALL' are removed.
    2) If `layer_dim` is present and has level coordinates (coords with
       dims == (layer_dim,)), every layer where any level equals 'ALL' is
       removed.

    Parameters
    ----------
    obj : xr.Dataset or xr.DataArray
        Object to filter. For a Dataset each data variable is filtered
        separately and the results are combined into a new Dataset carrying
        the original attrs.
    layer_dim : str, default "layer"
        Name of the stacked dimension.
    strict_layer_multiindex : bool, default True
        If True, `layer_dim` is first turned into a MultiIndex and the filter
        uses the index levels; a ValueError is raised if that is not possible.
        If False, the level coordinates are compared directly and the index is
        left as it is.
    layer_level_order : iterable of str, optional
        Preferred level order, used only when the MultiIndex has to be rebuilt.

    Returns
    -------
    xr.Dataset or xr.DataArray
        Same kind of object as `obj`, with the 'ALL' entries removed.

    Raises
    ------
    ValueError
        In strict mode, when `layer_dim` is not a MultiIndex after the rebuild
        attempt.
    """
    # Materialise once: a one-shot iterator would otherwise be exhausted after
    # the first variable of a Dataset.
    level_order = None if layer_level_order is None else list(layer_level_order)

    def _filter_da(da: xr.DataArray) -> xr.DataArray:
        out = da

        # ---------- A) drop 'ALL' on real dimensions ----------
        for dim in list(out.dims):
            if dim in out.coords and out[dim].dtype.kind in ("U", "S", "O"):
                vals = out[dim].values
                if np.isin(vals, ["ALL"]).any():
                    out = out.sel({dim: out[dim] != "ALL"})

        # ---------- B) drop 'ALL' on layer-level coords ----------
        if layer_dim not in out.dims:
            return out

        # Only applies when layer-level coords actually exist.
        layer_level_vars = [
            c for c in out.coords
            if c != layer_dim and out.coords[c].dims == (layer_dim,)
        ]
        if not layer_level_vars:
            return out

        if strict_layer_multiindex:
            out = _ensure_layer_multiindex_strict(
                out, layer_dim=layer_dim, level_order=level_order
            )
            idx = out.indexes[layer_dim]
            if not isinstance(idx, pd.MultiIndex):
                raise ValueError(f"Expected '{layer_dim}' to be MultiIndex after ensure, got {type(idx)}")
            level_values = [idx.get_level_values(lvl) for lvl in idx.names]
        else:
            # Non-strict mode: compare the level coordinates directly, without
            # touching the index. Wrapping in pd.Index gives the same
            # element-wise comparison semantics as the strict branch.
            level_values = [pd.Index(out.coords[c].values) for c in layer_level_vars]

        # A layer is kept only if none of its levels equals 'ALL'.
        keep = np.ones(out.sizes[layer_dim], dtype=bool)
        for values in level_values:
            keep &= np.asarray(values != "ALL", dtype=bool)

        return out.isel({layer_dim: keep})

    # Dataset: filter each data variable, then rebuild the Dataset.
    if isinstance(obj, xr.Dataset):
        new_vars = {}
        for v in obj.data_vars:
            new_vars[v] = _filter_da(obj[v])
        # Variables may end up with different layer lengths after filtering;
        # this is fine when files/variables are processed one at a time.
        return xr.Dataset(new_vars, attrs=obj.attrs)

    # DataArray
    return _filter_da(obj)

from typing import Sequence


def reduce_layered_da(
    da: xr.DataArray,
    dims_to_sum: Sequence[str],
    layer_dim: str = "layer",
    keep_attrs: bool = True,
    skipna: bool = True,
    layer_level_order: Optional[Iterable[str]] = None,
    cell_dim: str = "cell",
) -> xr.DataArray:
    """
    Sum a layered DataArray over selected MultiIndex levels and re-stack it.

    The `layer_dim` dimension is unstacked into one dimension per level, the
    requested dimensions are summed away, and every remaining dimension except
    `cell_dim` is stacked back into `layer_dim`.

    Parameters
    ----------
    da : xr.DataArray
        Input array, typically with dims (cell_dim, layer_dim), where
        `layer_dim` is a MultiIndex or can be rebuilt from level coordinates.
    dims_to_sum : sequence of str
        Dimension / level names to sum over; names that are not present are
        ignored. A single string is treated as one name.
    layer_dim : str, default "layer"
        Name of the stacked dimension.
    keep_attrs, skipna : bool, default True
        Forwarded to `DataArray.sum`.
    layer_level_order : iterable of str, optional
        Preferred level order, used both when rebuilding the MultiIndex and
        when stacking the result.
    cell_dim : str, default "cell"
        Name of the dimension that is kept out of the re-stacked `layer_dim`.

    Returns
    -------
    xr.DataArray
        - `da` unchanged if its dtype is not numeric.
        - A plain sum if `da` has no `layer_dim` dimension.
        - Otherwise the reduced array, re-stacked into `layer_dim`, or the
          unstacked result when nothing is left to stack.

    Raises
    ------
    ValueError
        Propagated from `_ensure_layer_multiindex_strict` when the MultiIndex
        cannot be rebuilt. Errors raised by `unstack` are not caught either.

    Notes
    -----
    NOTE(review): because of the unstack/stack round trip, the returned layer
    index is expected to cover every combination of the remaining level
    values, not only those present in the input -- confirm this is intended.
    """
    if not np.issubdtype(da.dtype, np.number):
        return da

    # A bare string would otherwise be iterated character by character.
    if isinstance(dims_to_sum, str):
        dims_to_sum = [dims_to_sum]

    # Materialise once so a one-shot iterator still works for both uses below.
    level_order = None if layer_level_order is None else list(layer_level_order)

    if layer_dim not in da.dims:
        # No layer dimension: plain sum over whichever requested dims exist.
        present = [d for d in dims_to_sum if d in da.dims]
        return da.sum(dim=present, keep_attrs=keep_attrs, skipna=skipna) if present else da

    # 1) Make sure layer is a MultiIndex (raises if it cannot be rebuilt).
    da = _ensure_layer_multiindex_strict(da, layer_dim=layer_dim, level_order=level_order)

    # 2) Unstack: layer -> one real dimension per level.
    da_u = da.unstack(layer_dim)

    # 3) Sum over the requested dimensions that exist after unstacking.
    present = [d for d in dims_to_sum if d in da_u.dims]
    if present:
        da_u = da_u.sum(dim=present, keep_attrs=keep_attrs, skipna=skipna)

    # 4) Stack everything except the cell dimension back into layer.
    stack_levels = [d for d in da_u.dims if d != cell_dim]
    if level_order is not None:
        stack_levels = [d for d in level_order if d in stack_levels] + \
                       [d for d in stack_levels if d not in level_order]

    if not stack_levels:
        # Every level was summed away; only the cell dimension (if any) remains.
        return da_u

    return da_u.stack({layer_dim: stack_levels})



In [None]:
import xarray as xr
import numpy as np

# ====== 1. Configure the file to inspect ======
file_path = load_path
engine = "h5netcdf"
# NOTE(review): `chunks` is defined here but is not passed to open_dataset
# below, so the file is still loaded eagerly; pass chunks=chunks if lazy
# loading is wanted.
chunks = "auto"

dims_to_sum = (
    'lm', 'source', 'Type', 'GHG_source',
    'Cost type', 'From water-supply', 'To water-supply'
)

layer_level_order = None
# For example:
# layer_level_order = ['am','lm','lu','source','Type']

# ====== 2. Open the file ======
# Use the path and engine configured above instead of bypassing them.
ds = cfxr.decode_compress_to_multi_index(
    xr.open_dataset(file_path, engine=engine), 'layer'
)
ds = ds.fillna(0)

print("\n=== ORIGINAL DATASET ===")
print(ds)

# ====== 3. Run the filter + reduce pipeline on every data variable ======
out_vars = {}
shared_layer_opts = {"layer_dim": "layer", "layer_level_order": layer_level_order}

for v, da in ds.data_vars.items():
    print(f"\n--- Processing variable: {v} ---")

    print("Before processing:")
    print("  dims:", da.dims)
    print("  coords:", list(da.coords))
    if "layer" in da.dims:
        print("  layer index type:", type(da.indexes.get("layer")))

    # (1) Drop 'ALL' entries (including those on layer-level coords).
    da = filter_all_from_dims(da, strict_layer_multiindex=True, **shared_layer_opts)

    print("\nAfter filter_all_from_dims:")
    print("  dims:", da.dims)
    if "layer" in da.dims:
        layer_index = da.indexes["layer"]
        print("  layer index type:", type(layer_index))
        if hasattr(layer_index, "names"):
            print("  layer levels:", layer_index.names)

    # (2) Sum over the configured dims while keeping layer re-stackable.
    da = reduce_layered_da(
        da,
        dims_to_sum=dims_to_sum,
        keep_attrs=True,
        skipna=True,
        **shared_layer_opts,
    )

    print("\nAfter reduce_layered_da:")
    print("  dims:", da.dims)
    if "layer" in da.dims:
        layer_index = da.indexes["layer"]
        print("  layer index type:", type(layer_index))
        print("  layer levels:", layer_index.names)
        print("  number of layers:", da.sizes["layer"])

        # Show the first few rows of the MultiIndex.
        print("\n  layer MultiIndex preview:")
        print(layer_index.to_frame().head())

    out_vars[v] = da

# ====== 4. Collect the results into a Dataset (nothing is saved) ======
out_ds = xr.Dataset(out_vars, attrs=ds.attrs)

print("\n================ FINAL RESULT ================")
print(out_ds)

# Extra detail when the dataset has a variable named "data".
if "data" in out_ds:
    da_final = out_ds["data"]

    print("\n=== FINAL DataArray SUMMARY ===")
    print("dims:", da_final.dims)
    print("sizes:", da_final.sizes)
    print("coords:")
    for c, coord in da_final.coords.items():
        print(f"  {c}: dims={coord.dims}, dtype={coord.dtype}")

    if "layer" in da_final.dims:
        final_layer_index = da_final.indexes["layer"]
        print("\nLayer MultiIndex detail:")
        print("  type:", type(final_layer_index))
        print("  level names:", final_layer_index.names)
        print(final_layer_index.to_frame().head())