# In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import re
import sep
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt

# ================== 1. 滤波器定义 ==================

def normalize_filter(weight_array, epsilon=1e-10):
    """Return *weight_array* rescaled so that its entries sum to one.

    Raises ValueError when the sum is below *epsilon* (zero, tiny or negative
    total weight), since the normalisation would then be meaningless.
    """
    weight_sum = np.sum(weight_array)
    # Written as "not (<)" so a NaN sum takes the same path as in a plain "<" test.
    if not weight_sum < epsilon:
        return weight_array / weight_sum
    raise ValueError(
        f"Cannot normalize: filter weights sum to {weight_sum}, below threshold {epsilon}."
    )

def smooth_center_taper(x, epsilon=1e-5):
    """Taper weights towards zero at the centre: x^2 / (x^2 + epsilon)."""
    x_sq = x * x
    return x_sq / (x_sq + epsilon)

def schirmer_weight(x):
    """Schirmer-type aperture filter Q(x), with x = r / Rs.

    Q(x) = [1 + exp(a - b x) + exp(d x - c)]^-1 * tanh(x / xc) / (x / xc),
    using a, b, c, d, xc = 6, 150, 47, 50, 0.15 (the Schirmer et al. 2007 form,
    as the function name indicates). The tanh(u)/u factor is set to its limit 1 at x = 0.

    Accepts scalars or arrays; always returns an ndarray (0-d for scalar input).
    """
    x = np.asarray(x, dtype=float)
    a, b, c, d, xc = 6.0, 150.0, 47.0, 50.0, 0.15
    # exp() overflows to inf for |x| far from the filter range; that correctly
    # drives Q to 0, so the overflow warning is only noise.
    with np.errstate(over="ignore"):
        Q = 1.0 / (1.0 + np.exp(a - b * x) + np.exp(d * x - c))
    with np.errstate(divide="ignore", invalid="ignore"):
        u = x / xc
        # np.where (not in-place item assignment) also works for 0-d input.
        fx = np.where(u == 0.0, 1.0, np.tanh(u) / u)
    return Q * fx

def gaussian_weight(x, sigma=0.3):
    """Centred Gaussian filter exp(-x^2 / (2 sigma^2)), with x = r / Rs.

    ``sigma`` defaults to 0.3, the previously hard-coded width, so existing
    one-argument callers (e.g. through ``filter_bank``) are unaffected.
    """
    return np.exp(-np.square(x) / (2.0 * sigma**2))

def shifted_gaussian_weight(x, mu=0.5, sigma=0.3):
    """Gaussian filter centred at ``mu`` with width ``sigma`` (x = r / Rs)."""
    offset = x - mu
    return np.exp(-(offset**2) / (2 * sigma**2))

def top_hat_weight(x, smoothing_width=0.05):
    """Top-hat filter with a linear roll-off, with x = r / Rs.

    Returns 1 for x <= 1, falls linearly from 1 to 0 over
    1 < x <= 1 + smoothing_width, and returns 0 beyond that.

    The input is converted to float first. Otherwise an integer array would
    produce an integer output and the fractional roll-off values would be
    truncated to 0.
    """
    x = np.asarray(x, dtype=float)
    Q = np.ones_like(x)
    ramp = (x > 1.0) & (x <= 1.0 + smoothing_width)
    Q[ramp] = 1.0 - (x[ramp] - 1.0) / smoothing_width
    Q[x > 1.0 + smoothing_width] = 0.0
    return Q

def nfw_matched_weight(x):
    """NFW-profile filter Q(x), with x = r / Rs (x clipped below at 1e-6).

    For x < 1 and x > 1 the profile is
        2 / (x^2 - 1) * (1 - 2 / sqrt(|1 - x^2|) * F(x)),
    where F = arctanh(sqrt((1-x)/(1+x))) for x < 1 and
    F = arctan(sqrt((x-1)/(x+1))) for x > 1. Both branches tend to 2/3 as
    x -> 1, and that limit is used at x == 1 so the filter is continuous.
    (The earlier constant 10/3 - 4 ln 2 did not match the limit of these branches.)

    NOTE(review): these branches have the form of the projected NFW surface
    density. 10/3 - 4 ln 2 appears to be the x = 1 value of the NFW
    *tangential-shear* function instead. If a shear-matched filter was
    intended, both branches need to change, not only this point — confirm.

    The input is converted to float, so integer arrays are not truncated.
    """
    eps = 1e-6
    x = np.clip(np.asarray(x, dtype=float), eps, None)
    Q = np.zeros_like(x)  # stays 0 where x is NaN (no mask matches)

    inner = x < 1.0
    outer = x > 1.0
    at_one = x == 1.0

    xi = x[inner]
    Q[inner] = (2.0 / (xi**2 - 1.0)) * (
        1.0 - (2.0 / np.sqrt(1.0 - xi**2)) *
        np.arctanh(np.sqrt((1.0 - xi) / (1.0 + xi)))
    )
    xo = x[outer]
    Q[outer] = (2.0 / (xo**2 - 1.0)) * (
        1.0 - (2.0 / np.sqrt(xo**2 - 1.0)) *
        np.arctan(np.sqrt((xo - 1.0) / (xo + 1.0)))
    )
    # Continuous limit of both branches above.
    Q[at_one] = 2.0 / 3.0
    return Q

def no_filter_weight(x):
    """Flat (unit) weights with the same shape and dtype as *x*."""
    return np.full_like(x, 1)

# Registry mapping filter names (as used in FILTER_LIST) to weight functions Q(r / Rs).
filter_bank = dict(
    schirmer=schirmer_weight,
    gaussian=gaussian_weight,
    shifted_gaussian=shifted_gaussian_weight,
    tophat=top_hat_weight,
    nfw=nfw_matched_weight,
    none=no_filter_weight,
)


# ================== 2. 配置区 ==================

CATALOG_DIR = Path("C:/Users/skyma/Desktop/t1/csv")   # directory containing the catalog *.csv files
PEAKS_DIR   = Path("C:/Users/skyma/Desktop/t1/batch_sep")      # directory containing the matching *_peaks.csv files
OUT_DIR     = Path("C:/Users/skyma/Desktop/t1/blend_out")  # output root (png/, maps/, result and checkpoint CSVs)

# Output folders are created at import time, not only when the pipeline runs.
# NOTE(review): nothing in this file writes into "maps"; it may be vestigial.
OUT_DIR.mkdir(parents=True, exist_ok=True)
(OUT_DIR / "maps").mkdir(exist_ok=True)
(OUT_DIR / "png").mkdir(exist_ok=True)

# bin_size values (in original catalog pixels). Each value sets both:
# 1) the pixel size of the mass map, and
# 2) the filter radius via Rs_input_pix = Rs_rel * bin_size_pix
BIN_SIZE_LIST = [500.0, 1000.0, 2000.0]

# Relative (dimensionless) filter scale: Rs_rel = Rs_input / bin_size
REL_RS_LIST   = [5.0, 10.0, 20.0, 40.0]

# Filters to run (keys of filter_bank)
FILTER_LIST   = ["schirmer", "gaussian", "tophat", "nfw", "none"]

# Maximum aperture integration radius
RMAX_FACTOR = 3.0  # integrate out to Rmax = RMAX_FACTOR * Rs_input_pix

# Descending SEP threshold settings (defaults of run_sep_on_map_adaptive)
SNR_THRESH_MAIN = 3.0
SNR_THRESH_MIN  = 1.0
SNR_N_STEPS     = 4
MINAREA         = 1  # passed to sep.extract as minarea

# S/N thresholds scanned automatically, high to low (earlier entries take priority)
SNR_SCAN_LIST = [4.0, 3.5, 3.0, 2.5, 2.0, 1.5]

# Truth-matching tolerances, in mass-map pixel-index units
MATCH_TOL = 3.0  # max detection-to-truth distance
MID_FRAC  = 0.3  # midpoint tolerance = max(MATCH_TOL, MID_FRAC * true separation)

# Parameters of the S/N-map fallback classification
FALLBACK_SN_MIN    = 1.0
FALLBACK_DELTA_MID = 0.6

# Minimum mass-map size in pixels (avoids e.g. a 1x1 map when bin_size is large)
MIN_NPIX_X = 16
MIN_NPIX_Y = 16


# ===== Checkpoint 设置 =====
CKPT_PATH = OUT_DIR / "checkpoint_done.csv"

# Load the checkpoint table (one row per finished catalog/bin_size/Rs_rel/filter
# combination), or start an empty one.
try:
    checkpoint_df = pd.read_csv(CKPT_PATH)
except (FileNotFoundError, pd.errors.EmptyDataError):
    # No checkpoint yet, or a zero-byte file left by an interrupted write
    # (it holds no rows, so starting fresh loses nothing).
    checkpoint_df = pd.DataFrame(
        columns=["catalog", "bin_size", "Rs_rel", "filter"]
    )

def is_done(cat_name, bin_size, Rs_rel, filter_name):
    """Return True if this (catalog, bin_size, Rs_rel, filter) is already in the checkpoint table."""
    if not checkpoint_df.empty:
        hit = (
            checkpoint_df["catalog"].eq(cat_name)
            & checkpoint_df["bin_size"].eq(bin_size)
            & checkpoint_df["Rs_rel"].eq(Rs_rel)
            & checkpoint_df["filter"].eq(filter_name)
        )
        return bool(hit.any())
    return False

def save_checkpoint(cat_name, bin_size, Rs_rel, filter_name):
    """Record a finished combination in the global checkpoint table and rewrite CKPT_PATH."""
    global checkpoint_df
    record = {
        "catalog": cat_name,
        "bin_size": bin_size,
        "Rs_rel": Rs_rel,
        "filter": filter_name,
    }
    checkpoint_df = pd.concat(
        [checkpoint_df, pd.DataFrame([record])], ignore_index=True
    )
    checkpoint_df.to_csv(CKPT_PATH, index=False)


# ================== 3. 文件 & 真值工具 ==================

def base_name_from_catalog(path: Path) -> str:
    """Return the file stem with one trailing ``.rN`` and then one ``_rN`` suffix removed."""
    name = path.stem
    for suffix_pattern in (r"\.r\d+$", r"_r\d+$"):
        name = re.sub(suffix_pattern, "", name)
    return name

def load_truth_peaks(peaks_dir: Path, cat_path: Path) -> pd.DataFrame:
    """Read ``<base>_peaks.csv`` for *cat_path* from *peaks_dir*.

    Raises FileNotFoundError when that file does not exist.
    """
    peaks_path = peaks_dir / (base_name_from_catalog(cat_path) + "_peaks.csv")
    if peaks_path.exists():
        return pd.read_csv(peaks_path)
    raise FileNotFoundError(f"missing peaks file: {peaks_path}")

def true_sep_from_peaks(peaks_df: pd.DataFrame) -> float:
    """Distance (original pixels) between the first two truth peaks.

    Uses rows with cluster_id 1 or 2 when at least two exist, otherwise the
    first two rows. Returns NaN when there are fewer than two peaks.
    """
    if len(peaks_df) < 2:
        return np.nan
    pair = peaks_df.iloc[:2]
    if "cluster_id" in peaks_df.columns:
        tagged = peaks_df[peaks_df["cluster_id"].isin([1, 2])]
        if len(tagged) >= 2:
            pair = tagged
    first, second = pair.iloc[0], pair.iloc[1]
    return float(np.hypot(second["x_peak"] - first["x_peak"],
                          second["y_peak"] - first["y_peak"]))


# ================== 4. mass map：由 bin_size 决定网格 ==================

def build_aperture_mass_map_from_catalog(
    df: pd.DataFrame,
    Rs_pix: float,
    filter_name: str,
    bin_size_pix: float,
    x_col: str = "x",
    y_col: str = "y",
    e1_col: str = "e1",
    e2_col: str = "e2",
    use_center_taper: bool = False,
):
    """
    Build an aperture-mass map on a regular grid from a shape catalog.

    Grid (resolution set by bin_size_pix):
      - pixel width = bin_size_pix (original catalog pixels)
      - centre of pixel k: x_min + (k + 0.5) * bin_size_pix (same along y)
      - at least MIN_NPIX_X x MIN_NPIX_Y pixels. When that minimum exceeds the
        data extent, the extra pixels extend past x_max / y_max.

    For every grid centre, galaxies within Rmax = RMAX_FACTOR * Rs_pix
    contribute their tangential ellipticity e_t = -(e1 cos 2phi + e2 sin 2phi),
    weighted by filter_bank[filter_name](r / Rs_pix). The weights are optionally
    multiplied by smooth_center_taper and are normalised to sum to one. Pixels
    with no galaxies in range, or whose weights cannot be normalised, stay 0.
    Callers choose Rs_pix = Rs_rel * bin_size_pix.

    Parameters
    ----------
    df : catalog with position and ellipticity columns; rows with any
        non-finite value in those columns are ignored.
    Rs_pix : filter scale radius in original pixels (> 0).
    filter_name : key of filter_bank.
    bin_size_pix : map pixel size in original pixels (> 0).
    x_col, y_col, e1_col, e2_col : column names.
    use_center_taper : multiply the filter by smooth_center_taper(r / Rs).

    Returns
    -------
    map_ap : float32 array of shape (ny, nx).
    xs_grid, ys_grid : 1-D arrays of pixel-centre coordinates (original pixels).

    Raises
    ------
    ValueError : non-positive bin_size_pix / Rs_pix, or no finite catalog rows.
    KeyError : filter_name is not in filter_bank.
    """
    if not bin_size_pix > 0:
        raise ValueError(f"bin_size_pix must be positive, got {bin_size_pix}")
    if not Rs_pix > 0:
        raise ValueError(f"Rs_pix must be positive, got {Rs_pix}")
    # Look the filter up first so an unknown name fails before any heavy work.
    filt_func = filter_bank[filter_name]

    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)
    e1 = df[e1_col].to_numpy(dtype=float)
    e2 = df[e2_col].to_numpy(dtype=float)
    good = np.isfinite(x) & np.isfinite(y) & np.isfinite(e1) & np.isfinite(e2)
    x, y, e1, e2 = x[good], y[good], e1[good], e2[good]
    if x.size == 0:
        raise ValueError(
            f"catalog has no rows with finite {x_col}, {y_col}, {e1_col}, {e2_col}"
        )

    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()

    # Number of map pixels (at least MIN_NPIX_* along each axis).
    nx = max(MIN_NPIX_X, int(np.ceil((x_max - x_min) / bin_size_pix)))
    ny = max(MIN_NPIX_Y, int(np.ceil((y_max - y_min) / bin_size_pix)))

    xs_grid = x_min + (np.arange(nx) + 0.5) * bin_size_pix
    ys_grid = y_min + (np.arange(ny) + 0.5) * bin_size_pix

    map_ap = np.zeros((ny, nx), dtype=np.float32)
    Rmax = RMAX_FACTOR * Rs_pix

    for iy, y0 in enumerate(ys_grid):
        dy = y - y0
        for ix, x0 in enumerate(xs_grid):
            dx = x - x0
            r = np.hypot(dx, dy)
            m = r <= Rmax
            if not np.any(m):
                continue

            rm = r[m]
            phi = np.arctan2(dy[m], dx[m])
            # Tangential ellipticity relative to the grid centre.
            e_t = -(e1[m] * np.cos(2.0 * phi) + e2[m] * np.sin(2.0 * phi))

            x_rel = rm / Rs_pix  # r / Rs
            Q_raw = filt_func(x_rel)
            if use_center_taper:
                Q_raw = Q_raw * smooth_center_taper(x_rel)

            try:
                Q = normalize_filter(Q_raw)
            except ValueError:
                # Weights sum to ~0 (or less): leave this pixel at 0.
                continue

            map_ap[iy, ix] = np.sum(e_t * Q)

    return map_ap, xs_grid, ys_grid


def make_sn_map(map_ap: np.ndarray, smooth_sigma_pix: float = 1.0) -> np.ndarray:
    """Standardise a map into S/N units: optional Gaussian smoothing, then (img - mean) / std.

    The input is never modified. A non-positive std is replaced by 1 to avoid
    dividing by zero (a constant map therefore maps to all zeros).
    """
    field = (
        gaussian_filter(map_ap, sigma=smooth_sigma_pix)
        if smooth_sigma_pix > 0
        else map_ap.copy()
    )
    scale = np.std(field)
    if scale <= 0:
        scale = 1.0
    return (field - np.mean(field)) / scale


# ================== 5. SEP + 真值映射 ==================

def run_sep_on_map_adaptive(
    sn_map: np.ndarray,
    snr_main: float = SNR_THRESH_MAIN,
    snr_min: float = SNR_THRESH_MIN,
    n_steps: int = SNR_N_STEPS,
    minarea: int = MINAREA,
):
    """Run SEP with a descending S/N threshold until something is detected.

    Thresholds go from snr_main down to snr_min in n_steps (a single threshold
    snr_main when n_steps <= 1), in units of the map's global std after median
    subtraction.

    When ``err`` is given, SEP interprets ``thresh`` as a multiple of err. The
    threshold is therefore passed as ``snr`` itself; the previous ``snr * rms``
    gave an effective threshold of snr * rms**2.

    Returns
    -------
    (objects, snr) : the first non-empty detection and its threshold. If
        nothing is found, the (empty) result and threshold of the last step.
    """
    data = np.ascontiguousarray(sn_map.astype(np.float32))
    bkg_val = np.median(data)
    rms = float(np.std(data))
    if not rms > 0:  # also catches NaN
        rms = 1.0
    data_sub = data - bkg_val

    sep.set_extract_pixstack(5_000_000)

    snr_list = [snr_main] if n_steps <= 1 else np.linspace(snr_main, snr_min, n_steps)

    objs = None
    snr = snr_list[-1]
    for snr in snr_list:
        objs = sep.extract(data_sub, thresh=float(snr), err=rms, minarea=minarea)
        if len(objs) > 0:
            break
    return objs, snr


def run_sep_auto_threshold(
    sn_map: np.ndarray,
    truth_xy_small,
    true_sep_small: float,
    snr_list=None,
    minarea: int = MINAREA,
):
    """
    Scan several S/N thresholds on one sn_map and classify each detection set.

    Thresholds are tried in the order given (SNR_SCAN_LIST by default, high
    to low). The first threshold whose classify_merge_custom result has
    detection_error == False is used. If every threshold gives
    detection_error, the last attempt is returned so the caller can fall back.

    Thresholds are in units of the map's global std after median subtraction.
    With ``err`` given, SEP treats ``thresh`` as a multiple of err, so
    ``thresh=snr`` (not ``snr * rms``) is the intended detection level.

    Returns
    -------
    (objects, snr, cls) : SEP objects, the threshold used and the
        classification dict. With an empty snr_list, objects and snr are None
        and cls has pattern "auto_snr_all_error".
    """
    if snr_list is None:
        snr_list = SNR_SCAN_LIST

    # Prepare the data and a global noise estimate.
    data = np.ascontiguousarray(sn_map.astype(np.float32))
    bkg_val = np.median(data)
    rms = float(np.std(data))
    if not rms > 0:  # also catches NaN
        rms = 1.0
    data_sub = data - bkg_val

    sep.set_extract_pixstack(5_000_000)

    objs = None
    snr = None
    cls = None
    for snr in snr_list:
        objs = sep.extract(data_sub, thresh=float(snr), err=rms, minarea=minarea)
        cls = classify_merge_custom(
            truth_xy_small=truth_xy_small,
            objects=objs,
            true_sep_small=true_sep_small,
            match_tol_pix=MATCH_TOL,
            mid_frac=MID_FRAC,
        )
        # Use the highest threshold that gives an unambiguous merge/non-merge.
        if not cls["detection_error"]:
            break

    if cls is None:  # snr_list was empty: nothing was attempted
        cls = {
            "n_true": len(truth_xy_small),
            "n_detected_peaks": 0 if objs is None else len(objs),
            "merged": None,
            "detection_error": True,
            "pattern": "auto_snr_all_error",
        }

    return objs, snr, cls


def compute_truth_small_coords(
    peaks_df: pd.DataFrame,
    xs_grid: np.ndarray,
    ys_grid: np.ndarray,
):
    """
    Map truth peak positions (x_peak, y_peak) onto mass-map pixel indices (ix, iy).

    Each coordinate is assigned to the nearest pixel centre in xs_grid / ys_grid.
    Rows with cluster_id 1 or 2 are used when at least two exist (all of them,
    if more than two); otherwise the first two rows.

    Returns
    -------
    (truth_xy_small, true_sep_small) : list of (ix, iy) int tuples, and the
        index-space distance between the first two (NaN if fewer than two).
    """
    if len(peaks_df) == 0:
        return [], np.nan

    chosen = peaks_df.iloc[:2]
    if "cluster_id" in peaks_df.columns:
        tagged = peaks_df[peaks_df["cluster_id"].isin([1, 2])]
        if len(tagged) >= 2:
            chosen = tagged

    truth_xy_small = [
        (int(np.argmin(np.abs(xs_grid - xt))), int(np.argmin(np.abs(ys_grid - yt))))
        for xt, yt in zip(chosen["x_peak"].to_numpy(), chosen["y_peak"].to_numpy())
    ]

    if len(truth_xy_small) < 2:
        return truth_xy_small, np.nan

    (ix_a, iy_a), (ix_b, iy_b) = truth_xy_small[0], truth_xy_small[1]
    # Separation in mass-map pixels.
    return truth_xy_small, float(np.hypot(ix_b - ix_a, iy_b - iy_a))


# ================== 6. merge 判定（你的规则） ==================

def classify_merge_custom(
    truth_xy_small,
    objects,
    true_sep_small: float,
    match_tol_pix: float = MATCH_TOL,
    mid_frac: float = MID_FRAC,
):
    """
    Classify a SEP detection pattern as merge / non-merge relative to the truth positions.

    Compares the SEP detections in ``objects`` (read via ``objects["x"]`` and
    ``objects["y"]``) with the truth positions ``truth_xy_small`` and their
    midpoint. Returns a dict with:
      - n_true: number of truth positions passed in
      - n_detected_peaks: number of SEP objects (0 if ``objects`` is None)
      - merged: True / False / None
      - detection_error: True when the pattern is ambiguous and the caller
        should fall back to another method
      - pattern: label of the SEP pattern (for debugging / plotting)

    All coordinates and distances are in mass-map pixel-index units. Only the
    first two truth positions are used when more are given.

    A detection matches a truth position when it is within ``match_tol_pix`` of
    it. It matches the midpoint when within
    ``max(match_tol_pix, mid_frac * true_sep_small)``, or within
    ``1.5 * match_tol_pix`` if ``true_sep_small`` is not finite.
    """
    result = {
        "n_true": len(truth_xy_small),
        "n_detected_peaks": 0 if objects is None else len(objects),
        "merged": None,
        "detection_error": False,
        "pattern": "none",
    }

    n_true = len(truth_xy_small)
    n_det  = result["n_detected_peaks"]

    # --------- 0. No truth positions (should not happen in practice) ---------
    if n_true == 0:
        result["detection_error"] = True
        result["pattern"] = "no_truth"
        return result

    # --------- 1. Only one true cluster ---------
    if n_true == 1:
        if n_det == 0:
            # Nothing detected -> clear error
            result["detection_error"] = True
            result["pattern"] = "1truth_0det"
            return result

        # At least one peak detected for a single cluster: there is nothing
        # to merge, so this is classified directly as non-merge.
        result["merged"] = False
        result["pattern"] = "1truth_ge1det_nonmerge"
        return result

    # Two or more truths: only the first two are used for the merge decision
    truths = np.array(truth_xy_small[:2])
    t1, t2 = truths[0], truths[1]
    mid = 0.5 * (t1 + t2)

    # No SEP peak at all
    if n_det == 0:
        result["detection_error"] = True
        result["pattern"] = "0det"
        return result

    # SEP peak coordinates, shape (n_det, 2)
    det_xy = np.vstack((objects["x"], objects["y"])).T

    # Distances from every detection to t1 / t2 / midpoint
    d1   = np.hypot(det_xy[:, 0] - t1[0], det_xy[:, 1] - t1[1])
    d2   = np.hypot(det_xy[:, 0] - t2[0], det_xy[:, 1] - t2[1])
    dmid = np.hypot(det_xy[:, 0] - mid[0], det_xy[:, 1] - mid[1])

    # Tolerance around each truth position: match_tol_pix
    tol_true = match_tol_pix
    # Midpoint tolerance: max(match_tol_pix, mid_frac * true_sep_small)
    if np.isfinite(true_sep_small):
        tol_mid = max(match_tol_pix, mid_frac * true_sep_small)
    else:
        tol_mid = match_tol_pix * 1.5

    # Indices of detections inside each tolerance circle. When the truths are
    # close, these circles overlap and one peak can fall in several of them.
    near_t1_idx  = np.where(d1   <= tol_true)[0]
    near_t2_idx  = np.where(d2   <= tol_true)[0]
    near_mid_idx = np.where(dmid <= tol_mid )[0]

    # ======================================================================
    # Case A: exactly one peak detected
    # ======================================================================
    if n_det == 1:
        i0 = 0
        cond_mid = (i0 in near_mid_idx)
        cond_t1  = (i0 in near_t1_idx)
        cond_t2  = (i0 in near_t2_idx)

        # Preferred merge: near the midpoint AND near t1 or t2.
        # (Same merged=True outcome as the next branch; only the pattern label differs.)
        if cond_mid and (cond_t1 or cond_t2):
            result["merged"] = True
            result["pattern"] = "1det_mid_and_truth_merge"
            return result

        # Only inside the midpoint region -> merge
        if cond_mid:
            result["merged"] = True
            result["pattern"] = "1det_mid_merge"
            return result

        # Only close to one truth position: this could physically be an
        # off-centre single peak -> let the fallback decide
        if cond_t1 or cond_t2:
            result["detection_error"] = True
            result["pattern"] = "1det_truth_only_ambiguous"
            return result

        # Far from everything -> error
        result["detection_error"] = True
        result["pattern"] = "1det_far_ambiguous"
        return result

    # ======================================================================
    # Case B: two or more peaks detected
    # ======================================================================
    has_mid        = len(near_mid_idx) > 0
    has_t1         = len(near_t1_idx) > 0
    has_t2         = len(near_t2_idx) > 0
    has_pair_truth = has_t1 and has_t2

    has_merge_pattern    = has_mid              # at least one peak in the midpoint region
    has_nonmerge_pattern = has_pair_truth      # peaks near both t1 and t2

    # ---- B1: merge pattern only ----
    if has_merge_pattern and not has_nonmerge_pattern:
        # Even with other stray peaks, one peak in the midpoint region is
        # enough to call it a merge
        result["merged"] = True
        result["pattern"] = "multi_merge_only"
        return result

    # ---- B2: non-merge pattern only ----
    if has_nonmerge_pattern and not has_merge_pattern:
        # Peaks close to t1 and to t2 respectively -> non-merge
        result["merged"] = False
        result["pattern"] = "multi_nonmerge_only"
        return result

    # ---- B3: merge and non-merge patterns both present (conflict) ----
    if has_merge_pattern and has_nonmerge_pattern:
        # Hand over to the fallback, which decides from the S/N values
        result["detection_error"] = True
        result["pattern"] = "multi_merge_and_nonmerge_conflict"
        return result

    # ---- B4: no pattern at all ----
    # e.g. peaks only near t1 but none near t2, or all peaks far away
    result["detection_error"] = True
    result["pattern"] = "multi_no_pattern"
    return result


# ================== 7. fallback：从 S/N map 直接看 ==================

def fallback_classify_from_snmap(
    sn_map: np.ndarray,
    truth_xy_small,
    true_sep_small: float,
    sn_min: float = FALLBACK_SN_MIN,
    delta_mid: float = FALLBACK_DELTA_MID,
):
    """Classify merge / non-merge directly from S/N values at the truth positions.

    Samples sn_map at the first two truth positions (sn1, sn2) and at their
    midpoint (sn_mid), with sn_max = max(sn1, sn2):

    * non-merge (False): sn1 > sn_min, sn2 > sn_min and
      sn_mid < sn_max - delta_mid (a clear dip between the peaks);
    * merge (True): sn_mid > sn_min, sn_mid >= sn_max - delta_mid and
      sn_mid > sn_max;
    * otherwise inconclusive (None).

    Positions are rounded to the nearest pixel (halves round up) and clipped
    to the map. ``true_sep_small`` is unused and kept for interface compatibility.

    Returns a dict with keys merged_fb, sn1, sn2, sn_mid. The S/N values are
    NaN and merged_fb is None when fewer than two truth positions are given.
    """
    res = {
        "merged_fb": None,
        "sn1": np.nan,
        "sn2": np.nan,
        "sn_mid": np.nan,
    }

    if len(truth_xy_small) < 2:
        return res

    (x1, y1), (x2, y2) = truth_xy_small[0], truth_xy_small[1]
    xm = 0.5 * (x1 + x2)
    ym = 0.5 * (y1 + y2)

    def _sample_sn(x, y):
        # Round half up via floor(v + 0.5). The built-in round() rounds half to
        # even, so a midpoint at *.5 (common for integer truth pixels) would move
        # left or right depending on parity.
        xi = int(np.clip(np.floor(x + 0.5), 0, sn_map.shape[1] - 1))
        yi = int(np.clip(np.floor(y + 0.5), 0, sn_map.shape[0] - 1))
        return float(sn_map[yi, xi])

    sn1 = _sample_sn(x1, y1)
    sn2 = _sample_sn(x2, y2)
    sn_mid = _sample_sn(xm, ym)

    res["sn1"], res["sn2"], res["sn_mid"] = sn1, sn2, sn_mid

    sn_max = max(sn1, sn2)

    if (sn1 > sn_min) and (sn2 > sn_min) and (sn_mid < sn_max - delta_mid):
        res["merged_fb"] = False
        return res

    # NOTE(review): for delta_mid >= 0 the middle clause is implied by
    # sn_mid > sn_max, so values in [sn_max - delta_mid, sn_max] end up
    # inconclusive. Confirm whether "sn_mid >= sn_max - delta_mid" alone was intended.
    if (sn_mid > sn_min) and (sn_mid >= sn_max - delta_mid) and (sn_mid > sn_max):
        res["merged_fb"] = True
        return res

    return res


# ================== 8. PNG 输出（标注最终结果） ==================

def plot_diagnostic_map(
    sn_map: np.ndarray,
    objects,
    truth_xy_small,
    out_path: Path,
    title_main: str = "",
    merged_final=None,
    merged_sep_rule=None,
    detection_error_sep=False,
    used_fallback=False,
    snr_used=None,
    n_sep_peaks=None,
    Rs_rel=None,
    Rs_pix=None,
    bin_size_pix=None,
    filter_name=None,
    pattern_sep_rule: str = None,
):
    """Save a PNG of the S/N map with SEP detections, truth positions and the classification.

    SEP objects are drawn as cyan circles and truth positions as red crosses.
    A text box lists the final and SEP-rule classifications plus run parameters.
    Optional parameters left as None are omitted from the text box (previously
    a None Rs_pix raised TypeError). The figure is always closed, even if
    plotting or saving fails.
    """
    has_sep = objects is not None and len(objects) > 0

    def _fmt(val):
        if val is True:
            return "MERGE"
        if val is False:
            return "NON-MERGE"
        return "None"

    rs_parts = []
    if Rs_rel is not None:
        rs_parts.append(f"Rs_rel = {Rs_rel}")
    if Rs_pix is not None:
        rs_parts.append(f"Rs_pix = {Rs_pix:.0f}")

    text_lines = [
        f"merged_final = {_fmt(merged_final)}",
        f"merged_sep_rule = {_fmt(merged_sep_rule)}",
        f"pattern_sep_rule = {pattern_sep_rule}" if pattern_sep_rule is not None else "",
        f"SEP peaks = {n_sep_peaks}",
        f"detection_error_sep = {detection_error_sep}",
        f"used_fallback = {used_fallback}",
        f"snr_used = {snr_used:.2f}" if snr_used is not None else "",
        f"bin_size = {bin_size_pix} pix" if bin_size_pix is not None else "",
        ", ".join(rs_parts),
        f"filter = {filter_name}" if filter_name is not None else "",
    ]
    text_lines = [t for t in text_lines if t]

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        im = ax.imshow(sn_map, origin="lower", cmap="inferno")
        fig.colorbar(im, ax=ax, label="M_ap / sigma")

        if has_sep:
            ax.scatter(
                objects["x"], objects["y"],
                s=60, facecolors="none", edgecolors="cyan",
                linewidths=1.5, label="SEP peaks"
            )

        if truth_xy_small:
            ax.scatter(
                [p[0] for p in truth_xy_small],
                [p[1] for p in truth_xy_small],
                s=80, marker="x", color="red",
                linewidths=2.0, label="truth"
            )

        ax.text(
            0.02, 0.98,
            "\n".join(text_lines),
            transform=ax.transAxes,
            fontsize=8, color="white",
            ha="left", va="top",
            bbox=dict(facecolor="black", alpha=0.35, edgecolor="none", pad=6)
        )

        ax.set_xlabel("mass-map pixel x")
        ax.set_ylabel("mass-map pixel y")
        ax.set_title(title_main, fontsize=10)

        if has_sep or truth_xy_small:
            ax.legend(loc="upper right", fontsize=7)

        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)

def plot_merge_probability_heatmaps(summary_df: pd.DataFrame, out_dir: Path):
    """
    Draw one 2-D heatmap per filter and save it to out_dir:
      x axis = Rs_rel
      y axis = bin_size_pix
      colour = P(merged_final = True)

    Only rows with merged_final True/False are used. Cells are laid out by
    index, with the actual Rs_rel / bin_size values as tick labels. The
    parameter values are not evenly spaced (e.g. 5, 10, 20, 40), so a
    value-based extent would misplace the ticks relative to the cells.
    Combinations with no data are drawn as NaN (blank).
    """
    required = {"merged_final", "filter", "bin_size_pix", "Rs_rel"}
    # A summary with no rows (e.g. every combination skipped) has no columns at all.
    if summary_df.empty or not required.issubset(summary_df.columns):
        print("No valid merged_final data to make heatmaps.")
        return

    # Keep only samples with a definite merge / non-merge result.
    df = summary_df[summary_df["merged_final"].isin([True, False])].copy()
    if df.empty:
        print("No valid merged_final data to make heatmaps.")
        return

    df["is_merged"] = df["merged_final"].astype(int)

    for f in sorted(df["filter"].unique()):
        sub = df[df["filter"] == f]
        if sub.empty:
            continue

        # Merge probability per (bin_size, Rs_rel); missing combinations -> NaN.
        prob = (
            sub.pivot_table(index="bin_size_pix", columns="Rs_rel",
                            values="is_merged", aggfunc="mean")
            .sort_index(axis=0)
            .sort_index(axis=1)
        )
        bin_sizes = list(prob.index)
        Rs_vals = list(prob.columns)
        mat = prob.to_numpy(dtype=float)

        fig, ax = plt.subplots(figsize=(6, 5))
        try:
            im = ax.imshow(
                mat,
                origin="lower",
                aspect="auto",
                vmin=0.0, vmax=1.0,
                cmap="viridis",
            )
            cbar = fig.colorbar(im, ax=ax)
            cbar.set_label("P(merged_final = True)")

            ax.set_xlabel("Rs_rel")
            ax.set_ylabel("bin_size_pix")
            ax.set_title(f"Merge probability heatmap (filter={f})")

            # One tick per cell (cell i is centred at index i), labelled with its value.
            ax.set_xticks(range(len(Rs_vals)))
            ax.set_xticklabels([f"{v:g}" for v in Rs_vals])
            ax.set_yticks(range(len(bin_sizes)))
            ax.set_yticklabels([f"{v:g}" for v in bin_sizes])

            fig.tight_layout()
            out_png = out_dir / f"merge_prob_heatmap_filter_{f}.png"
            fig.savefig(out_png, dpi=150)
        finally:
            plt.close(fig)

        print(f"Saved heatmap for filter={f} to {out_png}")


# ================== 9. 主循环：catalog × bin_size × Rs_rel × filter ==================

def _result_key(file_name, bin_size_pix, rel_rs, filter_name):
    """Hashable key identifying one (catalog, bin_size, Rs_rel, filter) combination."""
    return (str(file_name), float(bin_size_pix), float(rel_rs), str(filter_name))


def _load_previous_results(out_csv: Path) -> list:
    """Return result rows written by earlier (possibly interrupted) runs, or [] if none are usable."""
    if not out_csv.exists():
        return []
    try:
        prev = pd.read_csv(out_csv)
    except pd.errors.EmptyDataError:
        return []
    if not {"file", "bin_size_pix", "Rs_rel", "filter"}.issubset(prev.columns):
        return []
    return prev.to_dict("records")


def _write_results(rows: list, out_csv: Path) -> pd.DataFrame:
    """Write all rows to out_csv, keeping the newest row per combination, and return the table."""
    summary = pd.DataFrame(rows)
    if not summary.empty:
        summary = summary.drop_duplicates(
            subset=["file", "bin_size_pix", "Rs_rel", "filter"], keep="last"
        ).reset_index(drop=True)
    summary.to_csv(out_csv, index=False)
    return summary


def _fmt_num(val, spec=".2f"):
    """Format a number for log output, tolerating None."""
    return "None" if val is None else format(val, spec)


def _process_combination(df, peaks_df, cat_path, bin_size_pix, rel_rs,
                         filter_name, true_sep_orig):
    """Run one (catalog, bin_size, Rs_rel, filter) combination and return its result row.

    Builds the aperture-mass and S/N maps, maps the truth peaks onto the map
    grid, runs the SEP threshold scan and merge classification, applies the
    S/N fallback when SEP is ambiguous, and saves the diagnostic PNG.
    """
    Rs_pix = rel_rs * bin_size_pix
    print(f"       * filter = {filter_name}")

    # 1) mass map
    map_ap, xs_grid, ys_grid = build_aperture_mass_map_from_catalog(
        df,
        Rs_pix=Rs_pix,
        filter_name=filter_name,
        bin_size_pix=bin_size_pix,
        x_col="x",
        y_col="y",
        e1_col="e1",
        e2_col="e2",
        use_center_taper=False,
    )

    # 2) S/N map
    sn_map = make_sn_map(map_ap, smooth_sigma_pix=1.0)

    # 3) truth positions -> mass-map pixel coordinates
    truth_xy_small, true_sep_small = compute_truth_small_coords(
        peaks_df, xs_grid, ys_grid
    )

    # 4) automatic threshold scan + SEP + first-pass classification
    objects, snr_used, cls = run_sep_auto_threshold(
        sn_map,
        truth_xy_small,
        true_sep_small,
        snr_list=SNR_SCAN_LIST,
        minarea=MINAREA,
    )
    n_sep_peaks = 0 if objects is None else len(objects)

    merged_final = cls["merged"]
    used_fallback = False

    # 5) SEP pattern still ambiguous -> S/N-map fallback
    if cls["detection_error"]:
        fb = fallback_classify_from_snmap(
            sn_map,
            truth_xy_small,
            true_sep_small=true_sep_small,
            sn_min=FALLBACK_SN_MIN,
            delta_mid=FALLBACK_DELTA_MID,
        )
        if fb["merged_fb"] is not None:
            merged_final = fb["merged_fb"]
            used_fallback = True
            print(f"         fallback: sn1={fb['sn1']:.2f}, "
                  f"sn2={fb['sn2']:.2f}, sn_mid={fb['sn_mid']:.2f} "
                  f"-> merged={merged_final}")
        else:
            print("         fallback inconclusive, merged stays None")

    print(f"         result: pattern={cls['pattern']}, "
          f"merged_sep_rule={cls['merged']}, "
          f"merged_final={merged_final}, "
          f"n_det={cls['n_detected_peaks']}, "
          f"snr_used={_fmt_num(snr_used)}")

    # 6) diagnostic PNG. ":g" keeps integer Rs_rel names unchanged (5.0 -> "5")
    # while non-integer values no longer collide (2.5 -> "2.5", not "2").
    png_name = (f"{cat_path.stem}_bin{int(bin_size_pix)}"
                f"_Rsrel{rel_rs:g}_{filter_name}.png")
    png_path = OUT_DIR / "png" / png_name
    title = (f"{cat_path.name}\n"
             f"bin={bin_size_pix} pix, Rs_rel={rel_rs}, "
             f"Rs_pix={Rs_pix:.0f}, filter={filter_name}")

    plot_diagnostic_map(
        sn_map,
        objects,
        truth_xy_small,
        png_path,
        title_main=title,
        merged_final=merged_final,
        merged_sep_rule=cls["merged"],
        detection_error_sep=cls["detection_error"],
        used_fallback=used_fallback,
        snr_used=snr_used,
        n_sep_peaks=n_sep_peaks,
        Rs_rel=rel_rs,
        Rs_pix=Rs_pix,
        bin_size_pix=bin_size_pix,
        filter_name=filter_name,
        pattern_sep_rule=cls["pattern"],
    )

    # 7) result row
    row = {
        "file": cat_path.name,
        "base_name": base_name_from_catalog(cat_path),
        "bin_size_pix": bin_size_pix,
        "Rs_rel": rel_rs,
        "Rs_pix": Rs_pix,
        "filter": filter_name,
        "true_sep_pix_orig": true_sep_orig,
        "true_sep_pix_small": true_sep_small,
        "n_true": cls["n_true"],
        "n_detected_peaks": cls["n_detected_peaks"],
        "n_sep_peaks": n_sep_peaks,
        "snr_used": snr_used,
        "merged_sep_rule": cls["merged"],
        "merged_final": merged_final,
        "detection_error_sep": cls["detection_error"],
        "used_fallback": used_fallback,
        "pattern_sep_rule": cls["pattern"],
    }
    for k, (xt, yt) in enumerate(truth_xy_small[:2], start=1):
        row[f"c{k}_x_true_small"] = xt
        row[f"c{k}_y_true_small"] = yt
    return row


def process_all_catalogs():
    """Run every catalog x bin_size x Rs_rel x filter combination and summarise the results.

    Results go to OUT_DIR / "results_apmap_sep_binRsFilter.csv". Rows from
    earlier runs are kept, and the file is rewritten after every finished
    combination, *before* that combination is checkpointed. An interrupted
    run can therefore be resumed without losing rows, which the old
    write-once-at-the-end scheme could not guarantee. A combination is
    skipped only when it is both checkpointed and present in the saved
    results. Heatmaps are drawn from the full accumulated table.
    """
    cat_files = sorted(CATALOG_DIR.glob("*.csv"))
    cat_files = [f for f in cat_files if not f.name.endswith("_peaks.csv")]

    print(f"Found {len(cat_files)} catalog files")

    out_csv = OUT_DIR / "results_apmap_sep_binRsFilter.csv"
    rows = _load_previous_results(out_csv)
    saved_keys = {
        _result_key(r["file"], r["bin_size_pix"], r["Rs_rel"], r["filter"])
        for r in rows
    }

    for i, cat_path in enumerate(cat_files, start=1):
        print(f"\n===== [{i}/{len(cat_files)}] {cat_path.name} =====")

        combos = [(b, r, f) for b in BIN_SIZE_LIST
                  for r in REL_RS_LIST for f in FILTER_LIST]
        done = {
            c for c in combos
            if is_done(cat_path.name, *c)
            and _result_key(cat_path.name, *c) in saved_keys
        }
        if len(done) == len(combos):
            # Avoid reading the catalog when there is nothing left to do.
            print("  all combinations done (checkpoint), skipping catalog")
            continue

        df = pd.read_csv(cat_path)

        try:
            peaks_df = load_truth_peaks(PEAKS_DIR, cat_path)
        except FileNotFoundError as e:
            print("  WARNING:", e)
            peaks_df = pd.DataFrame()

        true_sep_orig = true_sep_from_peaks(peaks_df)
        print(f"  true sep (orig pix) = {true_sep_orig}")

        for bin_size_pix in BIN_SIZE_LIST:
            print(f"  === bin_size = {bin_size_pix} pix ===")

            for rel_rs in REL_RS_LIST:
                Rs_pix = rel_rs * bin_size_pix
                print(f"    -- Rs_rel={rel_rs:.1f}, Rs_pix={Rs_pix:.1f} --")

                for filter_name in FILTER_LIST:
                    if (bin_size_pix, rel_rs, filter_name) in done:
                        print(f"       * {filter_name}: SKIP (checkpoint)")
                        continue

                    row = _process_combination(
                        df, peaks_df, cat_path, bin_size_pix, rel_rs,
                        filter_name, true_sep_orig,
                    )
                    rows.append(row)
                    saved_keys.add(
                        _result_key(cat_path.name, bin_size_pix, rel_rs, filter_name)
                    )

                    # Persist the row BEFORE the checkpoint. A crash between the
                    # two then only causes recomputation, never a checkpointed
                    # combination whose row was never saved.
                    _write_results(rows, out_csv)
                    save_checkpoint(cat_path.name, bin_size_pix, rel_rs, filter_name)

    summary = _write_results(rows, out_csv)
    print("\nAll done, results saved to:", out_csv)

    # Merge-probability heatmaps over all accumulated results
    plot_merge_probability_heatmaps(summary, OUT_DIR)


if __name__ == "__main__":
    # Run the full sweep only when executed as a script. Note that importing the
    # module still creates the output folders and loads the checkpoint table.
    process_all_catalogs()
