# Error-aware Negative-enhanced Ranking (ENR) framework for Biological Process concept in Gene Ontology (GO) and Phenotypic Abnormality concept in Human Phenotype Ontology (HPO) recognition — Offline Inference Notebook

This notebook runs a **bi-encoder retrieval + cross-encoder reranker** pipeline on local text.

You can:
- Tag a **single passage** (a string), or
- Tag **many passages from a file** (txt/json/jsonl)

It outputs predicted **GO-Biological Process** and/or **HPO-Phenotypic Abnormality** concepts per passage.

---

## What you need

1. **Hugging Face model repos**
   - Bi-encoder checkpoint (HF `AutoModel`) used for retrieval
   - Cross-encoder checkpoint (HF `AutoModelForSequenceClassification`, `num_labels=1`) used for rerank
2. **Concept catalogs** (GO and/or HPO) with `id` + `name`
   - JSON dict: `{ "GO_...": {"name":"..."}, ... }`
   - JSON list: `[{"id":"...", "name":"..."}, ...]`
   - JSONL: each line `{"id":"...", "name":"..."}`

> Tip: See `docs/data.md` for how to obtain the catalog files. For example, `data/mm-go/meta/biological_process_concept.json` and `data/mm-hpo/meta/phenotypic_abnormality_concept.json` are valid catalog files.

---

## Notes on speed

- GO/HPO catalogs can be large. We compute and cache concept embeddings once per session.
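- If `faiss` is installed and enabled (`use_faiss=True` in `TaggerConfig`), retrieval can be faster. Otherwise, or if FAISS raises an error, the notebook falls back to a pure PyTorch implementation. This notebook is meant to run without `faiss` to avoid FAISS segfaults; set `use_faiss=False` when creating the taggers to guarantee the PyTorch path.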


In [None]:
# If running in a fresh environment, uncomment:
# !pip -q install transformers accelerate sentencepiece huggingface_hub tqdm

# Optional (recommended for large catalogs):
# !pip -q install faiss-cpu


In [None]:
import os
import json
import unicodedata
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification


## 1) Configure models, catalogs, and inference parameters

Fill in your own Hugging Face repo IDs (or local paths).  
You may set either GO, HPO, or both.


In [None]:
# -----------------------------
# User configuration
# -----------------------------
# Replace any value below with your own Hugging Face repo ID or a local path.

# --- Stage 1: bi-encoder retrievers (loaded with AutoModel) ---
GO_BIENCODER = "Samantha633/enr-recognizer-biological-process-retriever"
HPO_BIENCODER = "Samantha633/enr-recognizer-phenotypic-abnormality-retriever"

# --- Stage 2: cross-encoder rerankers (AutoModelForSequenceClassification) ---
GO_CROSSENCODER = "Samantha633/enr-recognizer-biological-process-reranker"
HPO_CROSSENCODER = "Samantha633/enr-recognizer-phenotypic-abnormality-reranker"

# --- Concept catalogs: .json or .jsonl files (see docs/data.md) ---
GO_CATALOG_PATH = "data/mm-go/meta/biological_process_concept.json"
HPO_CATALOG_PATH = "data/mm-hpo/meta/phenotypic_abnormality_concept.json"

# --- Candidate counts ---
RETRIEVE_TOPK = 100   # concepts fetched per passage by the bi-encoder
RERANK_TOPK = 100     # at most this many of them go through the cross-encoder
FINAL_TOPN = 20       # maximum predictions kept per passage

# --- Fixed score thresholds applied after reranking ---
# Set to None to disable thresholding and rely on FINAL_TOPN alone.
GO_SCORE_THRESHOLD = 7.3111958329057485
HPO_SCORE_THRESHOLD = 5.729438811082113

# --- Max sequence lengths (tokens) and batch sizes ---
Q_MAX_LEN = 448       # passages (queries)
T_MAX_LEN = 64        # concept names (terms)
ENCODE_BS = 64        # bi-encoder embedding batch size
RERANK_BS = 64        # cross-encoder scoring batch size

# --- Compute device: the chosen GPU when CUDA is available, else CPU ---
GPU_ID = 0
DEVICE = torch.device(f"cuda:{GPU_ID}") if torch.cuda.is_available() else torch.device("cpu")
DEVICE


## 2) Utilities: IO, normalization, catalog loading

In [None]:
def load_json(path: Union[str, Path]) -> Any:
    """Parse a UTF-8 encoded JSON file and return the decoded object."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


def load_jsonl(path: Union[str, Path]) -> List[dict]:
    """Parse a JSON Lines file (UTF-8), one JSON value per line.

    Lines that are empty or contain only whitespace are ignored; every other
    line must be valid JSON (``json.JSONDecodeError`` otherwise).
    """
    with Path(path).open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


@dataclass(eq=True, frozen=True)
class Concept:
    """Immutable catalog entry: one ontology concept ID and its name."""

    cid: str   # concept identifier as stored in the catalog
    name: str  # concept name; used as the term text for retrieval and reranking


def _catalog_record_fields(record: Any) -> Tuple[Any, Any]:
    """Return the raw (id, name) of a JSONL / JSON-list catalog record; (None, None) for non-dicts."""
    if not isinstance(record, dict):
        return None, None
    cid = record.get("id") or record.get("concept_id") or record.get("term_id")
    name = record.get("name") or record.get("term_text") or record.get("label")
    return cid, name


def _catalog_value_name(value: Any) -> Any:
    """Return the raw name stored under one key of a JSON-dict catalog, or None.

    Accepts a dict with the same name keys as list records, or a plain string
    that is itself the concept name.
    """
    if isinstance(value, dict):
        return value.get("name") or value.get("term_text") or value.get("label")
    if isinstance(value, str):
        return value
    return None


def _collect_concepts(pairs: Iterable[Tuple[Any, Any]]) -> List[Concept]:
    """Turn raw (id, name) pairs into Concepts, keeping the first occurrence of each ID.

    Pairs with a None field, or whose id/name is empty after ``str(...).strip()``,
    are skipped (and do not mark their ID as seen).
    """
    concepts: List[Concept] = []
    seen: set = set()
    for raw_id, raw_name in pairs:
        if raw_id is None or raw_name is None:
            continue
        cid = str(raw_id).strip()
        name = str(raw_name).strip()
        if cid and name and cid not in seen:
            seen.add(cid)
            concepts.append(Concept(cid=cid, name=name))
    return concepts


def load_concept_catalog(path: Union[str, Path]) -> List[Concept]:
    """Load a concept catalog (concept ID + name) as a de-duplicated list.

    Supported layouts (``.jsonl`` is chosen by suffix, everything else is
    parsed as JSON and dispatched on the root type):
      - JSONL: each line {"id": ..., "name": ...}
      - JSON dict: {id: {"name": ...}, ...} or {id: "name", ...}
      - JSON list: [{"id": ..., "name": ...}, ...]

    In records, the ID is taken from ``id`` / ``concept_id`` / ``term_id`` and
    the name from ``name`` / ``term_text`` / ``label`` (first truthy key wins).
    IDs and names are whitespace-stripped; incomplete records are skipped, and
    for duplicate IDs the first occurrence wins. Input order is preserved.

    Raises:
        ValueError: if a non-JSONL file's root is neither a dict nor a list.
    """
    path = Path(path)

    if path.suffix.lower() == ".jsonl":
        return _collect_concepts(_catalog_record_fields(r) for r in load_jsonl(path))

    obj = load_json(path)
    if isinstance(obj, dict):
        return _collect_concepts((k, _catalog_value_name(v)) for k, v in obj.items())
    if isinstance(obj, list):
        return _collect_concepts(_catalog_record_fields(r) for r in obj)

    raise ValueError(f"Unsupported catalog format: {path}")


## 3) Utilities: bi-encoder embeddings + retrieval (FAISS optional)

In [None]:
def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the positions selected by the attention mask.

    Args:
        last_hidden_state: token embeddings, shape ``[B, L, H]``.
        attention_mask: shape ``[B, L]``; 0 entries (padding) are excluded.

    Returns:
        ``[B, H]`` masked mean in the dtype of ``last_hidden_state``. Rows
        whose mask is all zeros yield zeros rather than NaN.
    """
    # Cast the (usually integer) mask to the embedding dtype so the sum and the
    # division are floating point. Clamping an integer tensor with a float bound
    # depends on dtype-promotion rules that have varied across PyTorch versions.
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)  # [B, H]
    # Smallest positive normal value of the dtype: avoids 0/0 for all-padding
    # rows and, unlike a fixed 1e-9, does not underflow to 0 in float16.
    denom = mask.sum(dim=1).clamp(min=torch.finfo(mask.dtype).tiny)  # [B, 1]
    return summed / denom


@torch.no_grad()
def encode_texts(
    model,
    tokenizer,
    texts: Sequence[str],
    max_len: int,
    device: torch.device,
    batch_size: int,
) -> torch.Tensor:
    """Embed ``texts`` with a bi-encoder in batches of ``batch_size``.

    Each batch is tokenized (padded, truncated to ``max_len``), moved to
    ``device``, mean-pooled over the attention mask and L2-normalized.

    Returns:
        A CPU float32 tensor with one row per text, in input order, or
        ``torch.empty(0)`` when ``texts`` is empty.
    """
    per_batch: List[torch.Tensor] = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="Encoding", leave=False):
        batch_texts = list(texts[start : start + batch_size])
        inputs = tokenizer(batch_texts, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
        inputs = {name: tensor.to(device, non_blocking=True) for name, tensor in inputs.items()}
        hidden = model(**inputs).last_hidden_state
        embeddings = F.normalize(mean_pooling(hidden, inputs["attention_mask"]), dim=-1)
        per_batch.append(embeddings.detach().cpu())

    if not per_batch:
        return torch.empty(0)
    return torch.cat(per_batch, dim=0).float()


def _try_build_faiss_index(vectors: np.ndarray, use_gpu: bool, gpu_id: int):
    try:
        import faiss  # type: ignore
    except Exception:
        return None, None

    dim = vectors.shape[1]
    cpu_index = faiss.IndexFlatIP(dim)
    cpu_index.add(vectors)

    if not use_gpu:
        return faiss, cpu_index

    if not hasattr(faiss, "StandardGpuResources"):
        return faiss, cpu_index

    try:
        res = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(res, gpu_id, cpu_index)
        return faiss, gpu_index
    except Exception:
        return faiss, cpu_index


def retrieve_topk(
    term_vecs: torch.Tensor,   # [Nt, D] CPU
    query_vecs: torch.Tensor,  # [Nq, D] CPU
    topk: int,
    use_faiss: bool = True,
    use_faiss_gpu: bool = False,
    faiss_gpu_id: int = 0,
) -> Tuple[np.ndarray, np.ndarray]:
    """Find the ``topk`` highest inner-product terms for every query.

    FAISS is tried first when ``use_faiss`` is true. If it is unavailable or
    raises, the function falls back to a dense PyTorch matmul + ``topk``.

    Returns:
        ``(scores, indices)``, each of shape ``[Nq, k]`` with
        ``k = min(topk, Nt)`` (never negative), sorted by descending score.
        ``indices`` are row indices into ``term_vecs``. An empty query set
        yields two ``[0, k]`` arrays (float32 scores, int64 indices).
    """
    term_np = term_vecs.numpy().astype(np.float32, copy=False)
    query_np = query_vecs.numpy().astype(np.float32, copy=False)

    # Clamp k once so both backends return the same width. FAISS would
    # otherwise pad rows beyond Nt with placeholder hits.
    k = max(0, min(int(topk), term_np.shape[0]))

    if query_np.size == 0:
        return np.zeros((0, k), dtype=np.float32), np.zeros((0, k), dtype=np.int64)

    if use_faiss:
        try:
            faiss_mod, index = _try_build_faiss_index(term_np, use_gpu=use_faiss_gpu, gpu_id=faiss_gpu_id)
            if index is not None:
                # Optional: limit FAISS to one OpenMP thread for the search to
                # reduce the chance of native crashes (the index is already built).
                if faiss_mod is not None and hasattr(faiss_mod, "omp_set_num_threads"):
                    faiss_mod.omp_set_num_threads(1)

                scores, indices = index.search(query_np, k)
                return scores, indices
        except Exception as e:
            print(f"[FAISS] failed -> fallback to torch. err={repr(e)}")

    sims = torch.from_numpy(query_np) @ torch.from_numpy(term_np).t()
    vals, idxs = torch.topk(sims, k=k, dim=1, largest=True, sorted=True)
    return vals.numpy(), idxs.numpy()


## 4) Utilities: cross-encoder rerank

We score `(term_text, query_text)` pairs: the concept name is the first segment and the passage is the second, matching the pair order used by the evaluation scripts.


In [None]:
@torch.no_grad()
def rerank_with_cross_encoder(
    tokenizer,
    model,
    query_texts: List[str],
    candidate_texts: List[List[str]],
    max_len: int,
    batch_size: int,
    device: torch.device,
) -> List[np.ndarray]:
    """Score each query's candidate concept names with a cross-encoder.

    Candidates are tokenized as text pairs ``(candidate, query)`` (candidate
    first), truncated to ``max_len`` tokens, and scored in batches of
    ``batch_size``. The model is switched to eval mode.

    Returns:
        One float32 score array per (query, candidates) pair, aligned with
        ``candidate_texts[i]``; an empty array for queries with no candidates.
    """
    model.eval()
    out_scores: List[np.ndarray] = []

    for query, cands in tqdm(list(zip(query_texts, candidate_texts)), desc="Reranking"):
        if not cands:
            out_scores.append(np.zeros((0,), dtype=np.float32))
            continue

        chunk_scores: List[np.ndarray] = []
        for start in range(0, len(cands), batch_size):
            chunk = cands[start : start + batch_size]
            enc = tokenizer(
                chunk,
                [query] * len(chunk),
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors="pt",
            )
            enc = {k: v.to(device, non_blocking=True) for k, v in enc.items()}
            # Assumes a single-logit head (num_labels=1): [B, 1] -> [B].
            logits = model(**enc).logits.squeeze(-1)
            # .float() before .numpy(): NumPy has no bfloat16, and a half-precision
            # model would otherwise yield float16 scores instead of float32.
            chunk_scores.append(logits.detach().float().cpu().numpy())
        out_scores.append(np.concatenate(chunk_scores, axis=0))

    return out_scores


## 5) Build a reusable tagger class (GO or HPO)

In [None]:
@dataclass
class TaggerConfig:
    """Settings for one ConceptTagger (one concept catalog, e.g. GO or HPO)."""

    biencoder_name_or_path: str  # HF repo ID or local path, loaded with AutoModel
    crossencoder_name_or_path: Optional[str]  # reranker; None/empty disables reranking
    catalog_path: str  # concept catalog file (.json / .jsonl)
    retrieve_topk: int = 100  # candidates retrieved per passage by the bi-encoder
    rerank_topk: int = 100  # at most this many candidates are reranked / kept
    final_topn: Optional[int] = 20  # max predictions returned; None or <= 0 keeps all
    score_threshold: Optional[float] = None  # keep predictions with score >= this; None disables
    q_max_len: int = 448  # max tokens for passages (also the cross-encoder pair length)
    t_max_len: int = 64  # max tokens for concept names
    encode_bs: int = 64  # bi-encoder batch size
    rerank_bs: int = 64  # cross-encoder batch size
    use_faiss: bool = True  # try FAISS for retrieval (PyTorch fallback otherwise)
    faiss_gpu: bool = False  # put the FAISS index on GPU (only if CUDA is available)
    gpu_id: int = 0  # GPU used for the FAISS GPU index


class ConceptTagger:
    """Recognize concepts from one catalog (GO or HPO) in free-text passages.

    Pipeline per call to :meth:`predict`:
      1. bi-encoder retrieval of the ``cfg.retrieve_topk`` concept names most
         similar to each passage,
      2. optional cross-encoder reranking of up to ``cfg.rerank_topk`` of them,
      3. sorting, optional score threshold, and truncation to ``cfg.final_topn``.

    Concept-name embeddings are computed on first use and cached on CPU for the
    lifetime of the instance.
    """

    def __init__(self, cfg: TaggerConfig, device: torch.device):
        """Load the catalog, the bi-encoder and (if configured) the cross-encoder.

        Models are moved to ``device`` and put in eval mode.

        Raises:
            ValueError: if the catalog at ``cfg.catalog_path`` yields no concepts.
        """
        self.cfg = cfg
        self.device = device

        # Load catalog. term_texts[i] is the name of concepts[i], so a retrieval
        # index addresses both lists.
        self.concepts: List[Concept] = load_concept_catalog(cfg.catalog_path)
        if not self.concepts:
            raise ValueError(f"Empty catalog: {cfg.catalog_path}")
        self.term_texts = [c.name for c in self.concepts]

        # Load bi-encoder (embeds both concept names and passages).
        self.bi_tok = AutoTokenizer.from_pretrained(cfg.biencoder_name_or_path, use_fast=True)
        self.bi_model = AutoModel.from_pretrained(cfg.biencoder_name_or_path).to(device)
        self.bi_model.eval()

        # Load cross-encoder (optional; skipped when the path is None or empty).
        self.ce_tok = None
        self.ce_model = None
        if cfg.crossencoder_name_or_path:
            self.ce_tok = AutoTokenizer.from_pretrained(cfg.crossencoder_name_or_path, use_fast=True)
            self.ce_model = AutoModelForSequenceClassification.from_pretrained(cfg.crossencoder_name_or_path).to(device)
            self.ce_model.eval()

        # Cache for concept embeddings (CPU tensor), filled lazily by build_index().
        self._term_vecs: Optional[torch.Tensor] = None

    def build_index(self) -> None:
        """Compute and cache concept-name embeddings (CPU); no-op if already cached."""
        if self._term_vecs is not None:
            return
        print(f"[Index] Encoding {len(self.term_texts):,} concepts...")
        self._term_vecs = encode_texts(
            self.bi_model, self.bi_tok,
            self.term_texts,
            max_len=self.cfg.t_max_len,
            device=self.device,
            batch_size=self.cfg.encode_bs,
        )
        print("[Index] Done.")

    def _retrieve(self, query_texts: List[str]) -> Tuple[np.ndarray, np.ndarray]:
        """Return bi-encoder ``(scores, indices)`` of the top concepts for each query."""
        self.build_index()
        assert self._term_vecs is not None

        print(f"[Retrieve] Encoding {len(query_texts)} queries...")
        q_vecs = encode_texts(
            self.bi_model, self.bi_tok,
            query_texts,
            max_len=self.cfg.q_max_len,
            device=self.device,
            batch_size=self.cfg.encode_bs,
        )
        scores, indices = retrieve_topk(
            term_vecs=self._term_vecs,
            query_vecs=q_vecs,
            topk=self.cfg.retrieve_topk,
            use_faiss=self.cfg.use_faiss,
            # A FAISS GPU index is only requested when CUDA is actually available.
            use_faiss_gpu=(self.cfg.faiss_gpu and torch.cuda.is_available()),
            faiss_gpu_id=self.cfg.gpu_id,
        )
        return scores, indices

    def _build_candidates(
        self, ret_scores: np.ndarray, ret_indices: np.ndarray
    ) -> Tuple[List[List[str]], List[List[str]], List[np.ndarray]]:
        """Turn retrieval hits into per-passage (ids, names, bi-encoder scores).

        Indices outside the catalog are skipped together with their score, so
        the three lists stay aligned; each list is capped at ``cfg.rerank_topk``.
        """
        n_concepts = len(self.concepts)
        cap = self.cfg.rerank_topk
        cand_ids: List[List[str]] = []
        cand_texts: List[List[str]] = []
        cand_scores: List[np.ndarray] = []
        for score_row, idx_row in zip(ret_scores, ret_indices):
            ids: List[str] = []
            texts: List[str] = []
            scores: List[float] = []
            for s, j in zip(np.asarray(score_row).tolist(), np.asarray(idx_row).tolist()):
                j = int(j)
                if j < 0 or j >= n_concepts:
                    continue
                concept = self.concepts[j]
                ids.append(concept.cid)
                texts.append(concept.name)
                scores.append(s)
            cand_ids.append(ids[:cap])
            cand_texts.append(texts[:cap])
            cand_scores.append(np.asarray(scores[:cap], dtype=np.float32))
        return cand_ids, cand_texts, cand_scores

    def predict(self, passages: List[str]) -> List[dict]:
        """Return per-passage predictions with term_id and term_text.

        Output per passage:
          {"query_text": "...",
           "items": [{"rank":1,"score":...,"term_id":"...","term_text":"..."}, ...]}

        Score is the cross-encoder logit if a cross-encoder is loaded, else the
        bi-encoder similarity. Items are sorted by descending score (ties keep
        retrieval order), filtered to ``score >= cfg.score_threshold`` when a
        threshold is set, then truncated to ``cfg.final_topn`` when positive.

        Note: non-string and blank passages are dropped before inference, so the
        result can be shorter than ``passages``; element ``i`` corresponds to the
        ``i``-th kept passage. Returns ``[]`` when no passage is left.
        """
        passages = [p for p in passages if isinstance(p, str) and p.strip()]
        if not passages:
            return []

        # 1) bi-encoder retrieval
        ret_scores, ret_indices = self._retrieve(passages)
        cand_ids, cand_texts, cand_bi_scores = self._build_candidates(ret_scores, ret_indices)

        # 2) rerank (optional)
        if self.ce_model is not None and self.ce_tok is not None:
            final_scores = rerank_with_cross_encoder(
                tokenizer=self.ce_tok,
                model=self.ce_model,
                query_texts=passages,
                candidate_texts=cand_texts,
                max_len=self.cfg.q_max_len,
                batch_size=self.cfg.rerank_bs,
                device=self.device,
            )
        else:
            final_scores = cand_bi_scores

        # 3) sort + filter
        threshold = None if self.cfg.score_threshold is None else float(self.cfg.score_threshold)
        topn = self.cfg.final_topn
        outputs: List[dict] = []
        for q, ids, texts, scores in zip(passages, cand_ids, cand_texts, final_scores):
            scores = np.asarray(scores)
            # Defensive: keep ids/texts/scores the same length.
            m = min(len(scores), len(texts))
            ids, texts, scores = ids[:m], texts[:m], scores[:m]

            # Stable sort so equal scores keep retrieval order (deterministic output).
            order = np.argsort(-scores, kind="stable")
            items = [
                {
                    "rank": rank1,
                    "score": float(scores[ii]),
                    "term_id": ids[ii],
                    "term_text": texts[ii],
                }
                for rank1, ii in enumerate(order.tolist(), start=1)
            ]

            if threshold is not None:
                items = [it for it in items if it["score"] >= threshold]

            if topn is not None and topn > 0:
                items = items[: int(topn)]

            outputs.append({"query_text": q, "items": items})

        return outputs


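## 6) Initialize taggers (GO / HPO)

Create `go_tagger` and `hpo_tagger` here from the configuration in section 1. For example:

`go_tagger = ConceptTagger(TaggerConfig(biencoder_name_or_path=GO_BIENCODER, crossencoder_name_or_path=GO_CROSSENCODER, catalog_path=GO_CATALOG_PATH, score_threshold=GO_SCORE_THRESHOLD), DEVICE)`

Set a tagger to `None` to skip that ontology. Section 8 runs only the taggers that are not `None`.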

## 7) Input: single passage OR a file with multiple passages

Supported file formats:
- `.txt`: one passage per non-empty line
- `.jsonl`: each line a dict with key `passage` or `query_text` or `text`
- `.json`: list of strings or list of dicts with `passage`/`query_text`/`text`


In [None]:
def _passage_text(item: Any) -> Optional[str]:
    """Return the stripped passage text of a raw str / dict item, or None if it has none."""
    if isinstance(item, str):
        text = item
    elif isinstance(item, dict):
        text = item.get("passage") or item.get("query_text") or item.get("text")
    else:
        return None
    if isinstance(text, str) and text.strip():
        return text.strip()
    return None


def _take_passages(items: Iterable[Any], limit: Optional[int]) -> List[str]:
    """Collect passages from ``items`` in order, stopping once ``limit`` are gathered."""
    passages: List[str] = []
    # Check the limit before reading anything so that limit <= 0 yields no passages.
    if limit is not None and limit <= 0:
        return passages
    for item in items:
        text = _passage_text(item)
        if text is None:
            continue
        passages.append(text)
        if limit is not None and len(passages) >= limit:
            break
    return passages


def load_passages_from_file(path: Union[str, Path], limit: Optional[int] = None) -> List[str]:
    """Read passages from a ``.txt``, ``.jsonl`` or ``.json`` file.

    - ``.txt``: one passage per non-empty line (streamed; stops early at ``limit``).
    - ``.jsonl``: one record per line; a dict contributes its ``passage``,
      ``query_text`` or ``text`` value (first truthy key wins), and a plain
      string record is used as-is.
    - ``.json``: a list whose items are strings or such dicts.

    Passages are whitespace-stripped, and empty ones are skipped.

    Args:
        path: input file; the format is chosen by its (case-insensitive) suffix.
        limit: maximum number of passages to return; ``None`` reads all and
            ``0`` or a negative value returns ``[]``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: for any other suffix, or a ``.json`` file whose root is not a list.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(path)

    suffix = path.suffix.lower()

    if suffix == ".txt":
        with path.open("r", encoding="utf-8") as f:
            return _take_passages(f, limit)

    if suffix == ".jsonl":
        return _take_passages(load_jsonl(path), limit)

    if suffix == ".json":
        obj = load_json(path)
        if isinstance(obj, list):
            return _take_passages(obj, limit)

    raise ValueError(f"Unsupported file format: {path}")


# Example: single passage (edit this list or extend it with your own text).
# Empty or whitespace-only entries are skipped by ConceptTagger.predict.
passages = [
    "Nonylphenol and short-chain nonylphenol ethoxylates such as NP2 EO are present in aquatic environment as wastewater contaminants, and their toxic effects on aquatic species have been reported. Apoptosis has been shown to be induced by serum deprivation or copper treatment. To understand the toxicity of nonylphenol diethoxylate, we investigated the effects of NP2 EO on apoptosis induced by serum deprivation and copper by using PC12 cell system. Nonylphenol diethoxylate itself showed no toxicity and recovered cell viability from apoptosis. In addition, nonylphenol diethoxylate decreased DNA fragmentation caused by apoptosis in PC12 cells. This phenomenon was confirmed after treating apoptotic PC12 cells with nonylphenol diethoxylate, whereas the cytochrome c release into the cytosol decreased as compared to that in apoptotic cells not treated with nonylphenol diethoxylate s. Furthermore, Bax contents in apoptotic cells were reduced after exposure to nonylphenol diethoxylate. Thus, nonylphenol diethoxylate has the opposite effect on apoptosis in PC12 cells compared to nonylphenol, which enhances apoptosis induced by serum deprivation. The difference in structure of the two compounds is hypothesized to be responsible for this phenomenon. These results indicated that nonylphenol diethoxylate has capability to affect cell differentiation and development and has potentially harmful effect on organisms because of its unexpected impact on apoptosis. © 2015 Wiley Periodicals, Inc. Environ Toxicol 31: 1389-1398, 2016."
]

# Example: load from file
# passages = load_passages_from_file("/path/to/passages.txt", limit=100)

# Display the number of passages and a preview of the first one.
len(passages), passages[0][:80]


## 8) Run inference (GO / HPO)

In [None]:
def run_inference(passages: List[str], do_go: bool = True, do_hpo: bool = True) -> Dict[str, List[dict]]:
    """Run the GO and/or HPO taggers over ``passages``.

    The taggers are read from the notebook-level names ``go_tagger`` and
    ``hpo_tagger``. A name that was never defined is treated like one set to
    ``None``, so a skipped initialization cell surfaces as the RuntimeError
    below rather than as a bare NameError.

    Returns:
        A dict with key ``"go"`` and/or ``"hpo"`` (one per requested ontology),
        each mapping to that tagger's ``predict(passages)`` output.

    Raises:
        RuntimeError: if an ontology is requested but its tagger is missing or None.
    """
    out: Dict[str, List[dict]] = {}
    go = globals().get("go_tagger")
    hpo = globals().get("hpo_tagger")

    if do_go:
        if go is None:
            raise RuntimeError("GO tagger is not initialized. Check GO_* config and catalog path.")
        out["go"] = go.predict(passages)

    if do_hpo:
        if hpo is None:
            raise RuntimeError("HPO tagger is not initialized. Check HPO_* config and catalog path.")
        out["hpo"] = hpo.predict(passages)

    return out


# Run only the ontologies whose tagger exists and is not None. A tagger name
# that was never defined counts as None instead of raising NameError.
results = run_inference(
    passages,
    do_go=globals().get("go_tagger") is not None,
    do_hpo=globals().get("hpo_tagger") is not None,
)
results.keys()


## 9) Pretty-print results for the first passage

In [None]:
def print_top(results: Dict[str, List[dict]], idx: int = 0, topn: int = 10) -> None:
    """Print the first ``topn`` predictions of passage ``idx`` for each ontology in ``results``.

    Ontologies are printed in the order GO, HPO; missing keys are skipped. The
    passage is previewed to 200 characters, with "..." appended when it is
    longer. Raises IndexError if ``idx`` is out of range for a present ontology.
    """
    preview_chars = 200
    for space in ("go", "hpo"):
        if space not in results:
            continue
        record = results[space][idx]
        text = record["query_text"]
        suffix = "..." if len(text) > preview_chars else ""

        print(f"\n=== {space.upper()} ===")
        print("Passage:", text[:preview_chars] + suffix)
        for item in record["items"][:topn]:
            print(f"  {item['rank']:>2}  {item['score']:.4f}  {item['term_id']}  {item['term_text']}")


# Show the top-10 predictions (per ontology present in `results`) for the first passage.
print_top(results, idx=0, topn=10)


## 10) Save outputs to JSONL

We save:
- `pred_go.jsonl` and/or `pred_hpo.jsonl`

Each line:
```json
{"query_text": "...", "items": [{"rank":1,"score":...,"term_id":"...","term_text":"..."}, ...]}
```


In [None]:
def write_jsonl(path: Union[str, Path], rows: Iterable[dict]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines (UTF-8, non-ASCII kept as-is).

    Missing parent directories are created, and an existing file is overwritten.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = (json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    with destination.open("w", encoding="utf-8") as handle:
        handle.writelines(serialized)


# Write one JSONL file per ontology that produced results (existing files are overwritten).
out_dir = Path("./inference_outputs")
out_dir.mkdir(parents=True, exist_ok=True)

if "go" in results:
    write_jsonl(out_dir / "pred_go.jsonl", results["go"])
if "hpo" in results:
    write_jsonl(out_dir / "pred_hpo.jsonl", results["hpo"])

# Display every *.jsonl file now present in the output directory.
list(out_dir.glob("*.jsonl"))


## Optional: combined view (GO + HPO per passage)

This is sometimes nicer for downstream UI/inspection.


In [None]:
def merge_go_hpo(go_rows: Optional[List[dict]], hpo_rows: Optional[List[dict]]) -> List[dict]:
    """Combine per-passage GO and HPO predictions into one record per passage.

    Rows are paired by position. In this notebook, both lists come from
    ``predict`` on the same passages. Each record has ``query_text`` (from the
    GO row when present, otherwise the HPO row) plus ``go_items`` and/or
    ``hpo_items`` for whichever side has a row at that position. If the lists
    differ in length, the extra rows of the longer list are emitted with only
    their own side instead of raising IndexError.
    """
    n = max(len(go_rows or []), len(hpo_rows or []))

    out: List[dict] = []
    for i in range(n):
        rec: Dict[str, Any] = {}
        if go_rows is not None and i < len(go_rows):
            rec["query_text"] = go_rows[i]["query_text"]
            rec["go_items"] = go_rows[i]["items"]
        if hpo_rows is not None and i < len(hpo_rows):
            rec.setdefault("query_text", hpo_rows[i]["query_text"])
            rec["hpo_items"] = hpo_rows[i]["items"]
        out.append(rec)
    return out


combined = merge_go_hpo(results.get("go"), results.get("hpo"))
# Preview the first merged record. Show the (empty) list instead of raising
# IndexError when no passage produced predictions.
combined[0] if combined else combined
