In [28]:
# Change the kernel's working directory; the relative paths and `src.*` imports
# in later cells are resolved from here.
# NOTE(review): this is an absolute, machine-specific path — consider deriving it
# from a configurable project-root variable so the notebook runs elsewhere.
%cd /Users/ymxu/Workspace/MuDocU/DTAgent

/Users/ymxu/Workspace/MuDocU/DTAgent


In [37]:
import json
from typing import Any, Dict, List

# Closed vocabularies that validate_policy uses to sanitize a policy object.
# Strategy names accepted as-is; any other value is replaced by "hybrid".
ALLOWED_STRATEGY = {"direct_label","direct_page","sparse","dense","hybrid","leaf","heading"}
# Role names kept in the policy's "roles" list; any other entry is dropped.
ALLOWED_ROLES = {"section","image","table","paragraph","caption"}
# Expansion types kept in the policy's "expand.types" list; any other entry is dropped.
ALLOWED_EXPAND = {"child","same_page","parent","ref"}

# System prompt (written in Chinese) that llm_policy sends as the "system" message.
# It tells the model to act as a retrieval-strategy planner and to reply with strict
# JSON only, using the keys strategy / roles / selectors / queries / need_leaf /
# expand / topK, followed by guidelines for choosing a strategy.
# The pieces below are adjacent string literals, so Python joins them into one
# string with no newline or other separator between fragments.
POLICY_SYS = (
"你是检索策略规划器。仅返回严格 JSON。"
"根据问题选择一种策略，并尽量填充直达选择器（如有）。"
"JSON 结构："
"{"
"\"strategy\":\"direct_label|direct_page|sparse|dense|hybrid|leaf|heading\","
"\"roles\":[\"section|image|table|paragraph|caption\"],"
"\"selectors\":{\"label\":\"\",\"figure_no\":\"\",\"table_no\":\"\",\"page\":0},"
"\"queries\":{\"dense\":\"\",\"sparse\":\"\"},"
"\"need_leaf\":true|false,"
"\"expand\":{\"types\":[\"child\",\"same_page\",\"parent\",\"ref\"],\"depth\":0},"
"\"topK\":40"
"}"
"准则："
"- 若问题含图/表编号（Figure/Table/图/表 + 序号），用 direct_label 并在 selectors.label/figure_no/table_no 填值。"
"- 若问题含页码（Page/p./第X页），用 direct_page 并填 selectors.page。"
"- 含专有名词/缩写/关键短语明显 → sparse；语义问法不明显 → dense；不确定 → hybrid。"
"- 需要段落/图注级别内容 → need_leaf=true 或 strategy=leaf；只找章节名 → heading。"
"- expand.depth 建议 0~1，types 默认 [\"child\",\"same_page\"]。"
"仅输出 JSON，不要解释。"
)

# Strings that an LLM may emit for a boolean and that must be read as False.
_FALSE_STRINGS = {"", "false", "0", "no", "n", "off", "none", "null"}


def _coerce_int(value: Any, default: int) -> int:
    """Best-effort integer conversion that never raises.

    Falsy values (None, 0, "", ...) yield ``default``. Numeric strings such as
    "12" or "3.0" are converted; anything else (e.g. "p. 3", a list) yields
    ``default``.
    """
    try:
        return int(value or default)
    except (TypeError, ValueError, OverflowError):
        pass
    try:
        # Second chance for float-looking strings like "3.0".
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


def _allowed_items(raw: Any, allowed) -> List[str]:
    """Return the string entries of ``raw`` that belong to ``allowed``, in order.

    A bare string is treated as a one-element list (instead of being iterated
    character by character); non-sequence values yield an empty list.
    """
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, (list, tuple, set, frozenset)):
        return []
    # isinstance check first: unhashable items (dicts/lists) would make `in` raise.
    return [item for item in raw if isinstance(item, str) and item in allowed]


def validate_policy(obj: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize a (possibly malformed) policy object into a well-formed dict.

    The input typically comes from an LLM, so every field is treated as
    untrusted: wrong types or unparsable values fall back to defaults instead
    of raising.

    Args:
        obj: Parsed policy JSON.

    Returns:
        A new dict with exactly these keys:
        ``strategy`` (member of ALLOWED_STRATEGY, default "hybrid"),
        ``roles`` (subset of ALLOWED_ROLES, default section/image/table),
        ``selectors`` (label/figure_no/table_no as str, page as int),
        ``queries`` (dense/sparse as str),
        ``need_leaf`` (bool),
        ``expand`` (types subset of ALLOWED_EXPAND, depth clamped to 0..2),
        ``topK`` (clamped to 1..200, default 40).

    Raises:
        ValueError: if ``obj`` itself is not a dict.
    """
    if not isinstance(obj, dict):
        raise ValueError("policy must be object")

    strat = obj.get("strategy")
    if not isinstance(strat, str) or strat not in ALLOWED_STRATEGY:
        strat = "hybrid"

    roles = _allowed_items(obj.get("roles"), ALLOWED_ROLES) or ["section", "image", "table"]

    raw_selectors = obj.get("selectors")
    if not isinstance(raw_selectors, dict):
        raw_selectors = {}
    selectors = {
        "label": str(raw_selectors.get("label") or ""),
        "figure_no": str(raw_selectors.get("figure_no") or ""),
        "table_no": str(raw_selectors.get("table_no") or ""),
        "page": _coerce_int(raw_selectors.get("page"), 0),
    }

    raw_queries = obj.get("queries")
    if not isinstance(raw_queries, dict):
        raw_queries = {}
    queries = {
        "dense": str(raw_queries.get("dense") or ""),
        "sparse": str(raw_queries.get("sparse") or ""),
    }

    need_leaf = obj.get("need_leaf", False)
    if isinstance(need_leaf, str):
        # bool("false") is True, so textual booleans need explicit handling.
        need_leaf = need_leaf.strip().lower() not in _FALSE_STRINGS
    else:
        need_leaf = bool(need_leaf)

    raw_expand = obj.get("expand")
    if not isinstance(raw_expand, dict):
        raw_expand = {}
    types = _allowed_items(raw_expand.get("types"), ALLOWED_EXPAND) or ["child", "same_page"]
    depth = _coerce_int(raw_expand.get("depth"), 0)

    topK = _coerce_int(obj.get("topK"), 40)

    return {
        "strategy": strat,
        "roles": roles,
        "selectors": selectors,
        "queries": queries,
        "need_leaf": need_leaf,
        "expand": {"types": types, "depth": max(0, min(2, depth))},
        "topK": max(1, min(200, topK)),
    }

def llm_policy(llm_call, question: str) -> Dict[str, Any]:
    """Ask the LLM for a retrieval policy for ``question`` and return it validated.

    The conversation is POLICY_SYS as the system message plus the question as
    the user message; the parsed reply (or an empty dict when the reply is
    falsy) is passed through validate_policy.
    """
    conversation = [
        {"role": "system", "content": POLICY_SYS},
        {"role": "user", "content": question},
    ]
    # Any JSON-producing helper could be swapped in here; the project's existing one is reused.
    from src.agents.planner import default_llm_json
    parsed = default_llm_json(conversation, llm_call, max_tokens=240)
    if not parsed:
        parsed = {}
    return validate_policy(parsed)

def pretty_print_policy(p: Dict[str, Any]) -> None:
    """Print ``p`` as 2-space-indented JSON, leaving non-ASCII characters unescaped."""
    rendered = json.dumps(p, indent=2, ensure_ascii=False)
    print(rendered)

In [42]:
from src.utils.llm_clients import gpt_llm_call
# Smoke test: request a policy for one sample question and print it.
# The lambda adapts gpt_llm_call to the callable that llm_policy passes on to
# default_llm_json.
# NOTE(review): the lambda accepts json_mode / max_tokens / **kw but forwards none
# of them — gpt_llm_call always gets json_mode=True and no max_tokens. Confirm
# that dropping the caller's max_tokens is intended.
p = llm_policy(
          lambda messages, json_mode=True, max_tokens=240, **kw: gpt_llm_call(messages, model="gpt-4o-mini", json_mode=True),
          "what is EBITDA  for costco in FY2021?",
        )
pretty_print_policy(p)

{
  "strategy": "sparse",
  "roles": [
    "section",
    "image",
    "table"
  ],
  "selectors": {
    "label": "",
    "figure_no": "",
    "table_no": "",
    "page": 0
  },
  "queries": {
    "dense": "",
    "sparse": "EBITDA Costco FY2021"
  },
  "need_leaf": false,
  "expand": {
    "types": [
      "child",
      "same_page"
    ],
    "depth": 1
  },
  "topK": 40
}


In [None]:
from typing import Any, Dict, List, Set



In [44]:
# 简单 RRF 融合（后续可替换为你们 planner.py 的 fuse）
def _rrf_fuse(dense_cands, sparse_cands, k=60):
    rd = {c.node_id:i+1 for i,c in enumerate(dense_cands)}
    rs = {c.node_id:i+1 for i,c in enumerate(sparse_cands)}
    pool = {}
    for c in dense_cands + sparse_cands:
        pool.setdefault(c.node_id, c)
    fused = []
    for nid, c in pool.items():
        r1, r2 = rd.get(nid, 10**6), rs.get(nid, 10**6)
        score = 1.0/(k+r1) + 1.0/(k+r2)
        fused.append((score, c))
    fused.sort(key=lambda x: x[0], reverse=True)
    return [t[1] for t in fused]

def _roles_or_default(roles: List[str] | Set[str] | None) -> Set[str]:
    base = {"section","image","table"}
    if not roles:
        return base
    r = set(str(x) for x in roles)
    return r if r else base

def _int_or(value: Any, default: int) -> int:
    """Return ``int(value)``, or ``default`` when value is falsy or not convertible."""
    try:
        return int(value or default)
    except (TypeError, ValueError, OverflowError):
        return default


def execute_policy(policy: Dict[str, Any], R) -> Dict[str, Any]:
    """Run the retrieval described by ``policy`` against retriever ``R``.

    Dispatch by ``policy["strategy"]``:
      * ``direct_label`` - resolve a label through ``R.idmap_lookup``; when no
        label is given it is built as "Table <table_no>" or, failing that,
        "Figure <figure_no>".
      * ``direct_page``  - ``R.filter_nodes`` per role on ``page_idx == page``,
        de-duplicated and truncated to topK.
      * ``sparse`` / ``dense`` - a single ``R.sparse`` / ``R.dense`` call; each
        falls back to the other query string when its own is empty.
      * ``heading``      - sparse search restricted to the "section" role.
      * ``leaf``         - hybrid search with "paragraph" and "caption" added to roles.
      * anything else    - hybrid: sparse + dense over a pool of
        ``min(200, topK*5)`` fused with ``_rrf_fuse``; reported mode is "hybrid".

    The policy's ``need_leaf`` and ``expand`` entries are not consulted here.
    Non-numeric ``page`` / ``topK`` values are treated as missing, and topK is
    forced to at least 1 so slicing never drops results from the end.

    Args:
        policy: Policy dict (ideally the output of validate_policy).
        R: Retriever exposing idmap_lookup, filter_nodes, sparse and dense.

    Returns:
        ``{"mode": <strategy run>, "candidates": [node ids], "policy": policy}``.
    """
    strat = policy.get("strategy") or "hybrid"
    roles = _roles_or_default(policy.get("roles"))
    sel = policy.get("selectors") or {}
    queries = policy.get("queries") or {}
    qd = queries.get("dense") or ""
    qs = queries.get("sparse") or ""
    topK = max(1, _int_or(policy.get("topK"), 40))

    def _result(mode: str, candidates: List[str]) -> Dict[str, Any]:
        return {"mode": mode, "candidates": candidates, "policy": policy}

    def _hybrid_ids(search_roles) -> List[str]:
        # Over-fetch from both retrievers, then fuse and cut back to topK.
        pool_size = min(200, topK * 5)
        sp = R.sparse(qs or qd, search_roles, pool_size)
        de = R.dense(qd or qs, search_roles, pool_size)
        return [c.node_id for c in _rrf_fuse(de, sp)[:topK]]

    # 1) Direct lookups
    if strat == "direct_label":
        label = sel.get("label") or ""
        if not label:
            # No explicit label: assemble one; table_no takes precedence over figure_no.
            if sel.get("table_no"):
                label = f"Table {sel['table_no']}"
            elif sel.get("figure_no"):
                label = f"Figure {sel['figure_no']}"
        nid = R.idmap_lookup(label) if label else None
        return _result(strat, [nid] if nid else [])

    if strat == "direct_page":
        page = _int_or(sel.get("page"), 0)
        if page <= 0:
            return _result(strat, [])
        # NOTE(review): the page number is matched against "page_idx" unchanged —
        # confirm the index uses the same (0- vs 1-based) numbering as questions.
        flt = [{"field":"page_idx","op":"=","value":page}]
        out: List[str] = []
        # roles is never empty here (_roles_or_default guarantees a default set).
        for r in roles:
            out.extend(R.filter_nodes(r, flt))
        # De-duplicate preserving order, then truncate.
        return _result(strat, list(dict.fromkeys(out))[:topK])

    # 2) Single-retriever strategies
    if strat == "sparse":
        sp = R.sparse(qs or qd, roles, topK)
        return _result(strat, [c.node_id for c in sp])

    if strat == "dense":
        de = R.dense(qd or qs, roles, topK)
        return _result(strat, [c.node_id for c in de])

    if strat == "heading":
        sp = R.sparse((qs or qd) + " heading title section", {"section"}, topK)
        return _result(strat, [c.node_id for c in sp])

    # Leaf strategy: runs as a hybrid search with paragraph/caption roles added
    # (harmless if the index has no leaf-level view).
    if strat == "leaf":
        return _result(strat, _hybrid_ids(roles | {"paragraph","caption"}))

    # Default: hybrid
    return _result("hybrid", _hybrid_ids(roles))


In [48]:
from src.agents.retriever_impl import JsonlRetriever
import numpy as np
import re
def build_hash_encoder(dim: int = 384):
    """Build a dependency-free bag-of-tokens encoder based on the hashing trick.

    Tokens are runs of ASCII letters, digits and ``% . _ -`` (case-sensitive).
    Each token increments one of ``dim`` buckets chosen by CRC32, and every row
    is L2-normalized. CRC32 is used instead of the built-in ``hash()`` because
    string hashes are salted per interpreter process, which would make the
    vectors differ between kernel restarts.

    Args:
        dim: Number of buckets (embedding width); must be positive.

    Returns:
        ``encode(texts)`` - maps an iterable of strings (``None`` is treated as
        empty) to a float32 array of shape ``(len(texts), dim)``.

    Raises:
        ValueError: if ``dim`` is not positive.
    """
    import zlib  # local import keeps this cell self-contained

    if dim <= 0:
        raise ValueError("dim must be a positive integer")

    tok = re.compile(r"[A-Za-z0-9%._-]+")

    def encode(texts):
        texts = list(texts)  # accept any iterable, not only sized sequences
        X = np.zeros((len(texts), dim), dtype=np.float32)
        for i, s in enumerate(texts):
            for w in tok.findall(s or ""):
                # Stable across processes, unlike hash(w).
                X[i, zlib.crc32(w.encode("utf-8")) % dim] += 1.0
        # L2 normalization; the epsilon keeps all-zero rows at zero instead of NaN.
        n = np.linalg.norm(X, axis=1, keepdims=True) + 1e-8
        return (X / n).astype(np.float32)

    return encode

# Directory holding the index files for one document, given relative to the
# kernel's current working directory.
# NOTE(review): the path is user- and document-specific; consider moving it to a
# config cell. What files JsonlRetriever expects inside it is not visible here.
INDEX_DIR = "./../../../data/users/yiming/dtagent/MinerU_25_MMLB/8e7c4cb542ad160f80fb3d795ada35d8/indexes"   # or: indexes
# 384-dimensional hashing encoder handed to the retriever as its encode function.
encode_fn = build_hash_encoder(dim=384)
# Retriever instance used by execute_policy / inspect_candidates in later cells.
R = JsonlRetriever(INDEX_DIR, encode_fn)

In [49]:
def inspect_candidates(R, node_ids: List[str], k: int = 10,
                        fields: tuple[str, ...] = ("label","parent_title","chart_type","figure_no","table_no","level")):
    """Print a one-line summary for each of the first ``k`` node ids.

    Each line contains, separated by two spaces: the node id, ``p=<page>``,
    ``role=<role>`` (when known), ``label=`` / ``parent=`` (when the attribute
    exists), ``title=`` (section nodes only), and then every other non-None
    attribute requested in ``fields`` as ``<field>=<value>``.

    Args:
        R: Retriever exposing ``page_of(nid)`` and ``get_attr(nid, key)``.
            Role/title are read opportunistically from ``R._dense_by_id`` or
            ``R.sparse_docs`` when those attributes exist.
        node_ids: Candidate node ids, best first.
        k: Maximum number of ids to show.
        fields: Attribute names to fetch through ``R.get_attr``.

    Returns:
        None; output goes to stdout.
    """
    for nid in node_ids[:k]:
        pg = R.page_of(nid)
        # Collect the requested attributes through the public accessor.
        attrs = {}
        for key in fields:
            v = R.get_attr(nid, key)
            if v is not None:
                attrs[key] = v

        # Soft fallback for role/title: peek at the retriever's stored docs,
        # never failing if those internals are absent.
        try:
            d = getattr(R, "_dense_by_id", {}).get(nid)
        except Exception:
            d = None
        if not d:
            d = next((sd for sd in getattr(R, "sparse_docs", []) if sd.get("id") == nid), None)
        role = d.get("role") if d else None
        title = d.get("title") if d else None

        parts = [nid, f"p={pg}"]
        if role: parts.append(f"role={role}")
        if "label" in attrs: parts.append(f"label={attrs['label']}")
        if "parent_title" in attrs: parts.append(f"parent={attrs['parent_title']}")
        if title and role == "section": parts.append(f"title={title}")
        # Show the remaining fetched attributes too (previously collected but never printed).
        for key, v in attrs.items():
            if key not in ("label", "parent_title"):
                parts.append(f"{key}={v}")
        print("  ".join(parts))

In [53]:
# End-to-end check: plan a policy for one question, execute it against R, and
# show the top candidates.
# NOTE(review): the lambda accepts json_mode / max_tokens / **kw but forwards none
# of them to gpt_llm_call — confirm that is intended.
p = llm_policy(
          lambda messages, json_mode=True, max_tokens=240, **kw: gpt_llm_call(messages, model="gpt-4o-mini", json_mode=True),
          "What is the residential capacity of Staten Island from 2003 to 2007? Give me an integer.",
        )
pretty_print_policy(p)

# Run the planned policy and report the mode plus the number of candidates.
# NOTE(review): the recorded run printed "sparse 0" with roles == ["paragraph"];
# presumably the index holds no nodes under that role — verify against the index.
res = execute_policy(p, R)
print(res["mode"], len(res["candidates"]))
inspect_candidates(R, res["candidates"], 5)

{
  "strategy": "sparse",
  "roles": [
    "paragraph"
  ],
  "selectors": {
    "label": "",
    "figure_no": "",
    "table_no": "",
    "page": 0
  },
  "queries": {
    "dense": "",
    "sparse": "residential capacity Staten Island 2003 to 2007"
  },
  "need_leaf": true,
  "expand": {
    "types": [
      "child",
      "same_page"
    ],
    "depth": 1
  },
  "topK": 40
}
sparse 0
