In [1]:
# Notebook setup: make the project checkout importable and build the relation
# similarity helper used by the cells below.
import sys, os
# NOTE(review): hard-coded absolute path to a local checkout; the import below
# only resolves on a machine where this directory exists. It is placed first on
# sys.path so it takes precedence over other importable `src` packages.
sys.path.insert(0, r"d:\claimpkg\claimpkg-clone")
from src.utils.sim import Similarity

# Shared Similarity instance; its `sim` method is passed as `sim_func` further down.
sim = Similarity()

In [None]:
from heapq import nlargest
from typing import Callable, Dict, List, Tuple

def score(
    candidate_entity: str,
    explicit_entities: List[str],
    pseudo_relations: List[str],
    KG: Dict[str, List[Tuple[str, str]]],
    sim_func: Callable[[str, str], float],
    normalize: bool = True
) -> float:
    """
    Score how well `candidate_entity` fits an unknown node of a pseudo-subgraph
    (intended to follow Eq. (5) of ClaimPKG).

    Each explicit entity is paired positionally with its pseudo-relation. For
    every KG edge that leaves the explicit entity and ends at the candidate,
    the similarity between the pseudo-relation and the KG relation is added
    to the score.

    Parameters
    ----------
    candidate_entity : str
        KG entity evaluated as a replacement for the unknown node.
    explicit_entities : List[str]
        Known entities attached to the unknown node; used as head keys of `KG`.
        Entities that are not keys of `KG` contribute nothing.
    pseudo_relations : List[str]
        Pseudo-relation for each explicit entity, in the same order. If the two
        lists differ in length, the extra items of the longer one are ignored.
    KG : Dict[str, List[Tuple[str, str]]]
        Knowledge graph as { head_entity: [(relation, tail_entity), ...], ... }.
    sim_func : Callable[[str, str], float]
        Called as sim_func(pseudo_relation, kg_relation) for every matching edge.
    normalize : bool, optional
        When True (default), divide the summed similarity by the number of
        matching edges, i.e. return the mean similarity per matched edge.

    Returns
    -------
    float
        The summed similarity (normalize=False) or the mean similarity over
        matched edges (normalize=True); 0.0 when no edge reaches the candidate.
    """
    # One similarity value per KG edge explicit_entity --relation--> candidate.
    edge_sims = [
        sim_func(pseudo_rel, kg_rel)
        for anchor, pseudo_rel in zip(explicit_entities, pseudo_relations)
        for kg_rel, tail in KG.get(anchor, [])
        if tail == candidate_entity
    ]

    # Explicit running total so the additions happen one by one in edge order.
    accumulated = 0.0
    for value in edge_sims:
        accumulated += value

    # Nothing to average when normalization is off or nothing matched.
    if not normalize or not edge_sims:
        return accumulated

    return accumulated / len(edge_sims)



def rank_candidates(
    candidate_sets: List[List[str]],
    explicit_entities: List[str],
    pseudo_relations: List[str],
    KG: Dict[str, List[Tuple[str, str]]],
    sim_func: Callable[[str, str], float],
    k1: int = 3,
    normalize: bool = True,
    aggregate: str = "max"
) -> List[Tuple[str, float]]:
    """
    Rank candidate entities for an unknown-entity group and keep the best `k1`
    (intended to follow Eq. (5)-(6) of ClaimPKG).

    Every distinct candidate is scored with `score()` against all explicit
    entities and pseudo-relations. A candidate that appears several times
    (within one set or across sets) has its occurrences merged according to
    `aggregate`.

    Parameters
    ----------
    candidate_sets : List[List[str]]
        A list of candidate lists, e.g. [[cand1, cand2], [cand3, cand4]].
    explicit_entities : List[str]
        Entities directly connected to the unknown entity in the pseudo-subgraph.
    pseudo_relations : List[str]
        Pseudo-relation for each explicit entity, in the same order.
    KG : Dict[str, List[Tuple[str, str]]]
        Knowledge graph as { head_entity: [(relation, tail_entity), ...], ... }.
    sim_func : Callable[[str, str], float]
        Function measuring similarity between two relation strings.
    k1 : int, optional
        Maximum number of candidates to return (default=3).
    normalize : bool, optional
        Forwarded to `score()` (default=True).
    aggregate : str, optional
        How to merge the scores of a candidate that occurs more than once:
        - "max"  : maximum over occurrences
        - "mean" : average over occurrences
        - "sum"  : sum over occurrences
        A candidate's score does not depend on the set it was listed in, so
        "max" and "mean" equal its single score, while "sum" scales with the
        number of occurrences.

    Returns
    -------
    List[Tuple[str, float]]
        Up to `k1` tuples (candidate_entity, score), sorted by descending score.

    Raises
    ------
    ValueError
        If `aggregate` is not one of "max", "mean" or "sum". This is checked
        before any scoring, so it is raised even when there are no candidates.
    """
    # Resolve the aggregation mode first: a misspelled mode must not pass
    # silently just because the candidate sets happen to be empty.
    if aggregate == "max":
        combine = max
    elif aggregate == "mean":
        combine = lambda vals: sum(vals) / len(vals)
    elif aggregate == "sum":
        combine = sum
    else:
        raise ValueError(f"Unknown aggregate mode: {aggregate}")

    # candidate -> one score entry per occurrence. The score is computed only
    # for the first occurrence and reused for repeats, which avoids redundant
    # (potentially expensive) sim_func calls.
    scored: Dict[str, List[float]] = {}
    for candidates in candidate_sets:
        for c in candidates:
            seen = scored.get(c)
            if seen is None:
                scored[c] = [
                    score(c, explicit_entities, pseudo_relations, KG, sim_func, normalize)
                ]
            else:
                seen.append(seen[0])

    # Collapse the per-occurrence scores into a single value per candidate.
    aggregated_scores = {c: combine(vals) for c, vals in scored.items()}

    # Select the k1 highest scoring candidates, best first.
    return nlargest(k1, aggregated_scores.items(), key=lambda x: x[1])

# Small mock KG for the demo: { head_entity: [(relation, tail_entity), ...] }
KG = {
    "Ho Chi Minh": [
        ("birthPlace", "Nghe An"),
        ("leaderTitle", "President"),
        ("leaderCountry", "Vietnam")
    ],
    "Vo Nguyen Giap": [
        ("birthPlace", "Quang Binh"),
        ("leaderTitle", "General"),
        ("leaderCountry", "Vietnam")
    ],
    "Ngo Dinh Diem": [
        ("birthPlace", "Quang Binh"),
        ("leaderTitle", "President"),
        ("leaderCountry", "South Vietnam")
    ]
}

# Group 1: unknown_0 -> find the entity that is the birthplace
explicit_entities_0 = ["Ho Chi Minh"]
pseudo_relations_0  = ["birth place"]
candidate_sets_0 = [["Nghe An", "Quang Binh", "Hanoi"]]

# Group 2: unknown_1 -> find the entity that is the leader's title/country
# NOTE(review): "Vietnam" is not a head key of KG, so score() finds no edges for
# the ("Vietnam", "leader country") pair; only the "Ho Chi Minh" pair contributes.
explicit_entities_1 = ["Ho Chi Minh", "Vietnam"]
pseudo_relations_1  = ["leader title", "leader country"]
candidate_sets_1 = [["President", "General"], ["Vietnam", "South Vietnam"]]

# Rank the candidates for unknown_0 (default aggregate="max", normalize=True)
topk_0 = rank_candidates(
    candidate_sets_0, explicit_entities_0, pseudo_relations_0, KG, sim.sim, k1=3
)

# Rank the candidates for unknown_1 (same settings)
topk_1 = rank_candidates(
    candidate_sets_1, explicit_entities_1, pseudo_relations_1, KG, sim.sim, k1=3
)

# Print each ranking, best candidate first, with scores rounded to 3 decimals
print("Top candidates for unknown_0 (birth place):")
for c, s in topk_0:
    print(f"  {c:<12}  score = {s:.3f}")

print("\nTop candidates for unknown_1 (leader info):")
for c, s in topk_1:
    print(f"  {c:<12}  score = {s:.3f}")

Top candidates for unknown_0 (birth place):
  Nghe An       score = 0.826
  Quang Binh    score = 0.000
  Hanoi         score = 0.000

Top candidates for unknown_1 (leader info):
  President     score = 0.960
  Vietnam       score = 0.813
  General       score = 0.000
