In [2]:
import os
import sys
sys.path.insert(0, '..')


In [3]:
import re
from typing import Dict, Any, List, Tuple

remove_underscore = True

def _norm(s):
    if s is None:
        return s
    return s.replace("_", " ") if remove_underscore else s

def _new_unknown(unk_list):
    u = f"unknown_{len(unk_list)}"
    unk_list.append(u)
    return u

def generate(claim: str, sample: Dict[str, Any]):
    # Normalize entities by adding the underscore
    entity_set = [_norm(e) for e in sample.get("Entity_set", [])]

    # Get evidences
    evidence = sample.get("Evidence", {})

    # Normalize keys
    evidence_map = { _norm(k): v for k, v in evidence.items() }

    # --------------------------------------------------------------
    # STEP 1 — Build direct matches (A --rel--> B)
    # --------------------------------------------------------------
    direct_edges = []

    for e1, groups1 in evidence_map.items():
        for rels1 in groups1:
            for r in rels1:
                inverse = r.startswith("~")
                rel = r[1:] if inverse else r

                # CASE 1: forward r → look for entity with ~r
                if not inverse:
                    for e2, groups2 in evidence_map.items():
                        if e1 == e2: continue
                        for rels2 in groups2:
                            if f"~{rel}" in rels2:
                                direct_edges.append((e1, rel, e2))

                # CASE 2: inverse ~r → look for entity with forward r
                else:
                    for e2, groups2 in evidence_map.items():
                        if e1 == e2: continue
                        for rels2 in groups2:
                            if rel in rels2:
                                direct_edges.append((e2, rel, e1))

    # --------------------------------------------------------------
    # STEP 2 — Now build multi-hop paths (using unknown_i)
    # --------------------------------------------------------------
    unknown_list = []
    triplets = []

    # dictionary to quickly check for direct pairs
    direct_set = set((h, r, t) for h, r, t in direct_edges)

    # Helper to detect if last hop of unit should link directly
    def find_direct_target(rel, exclude):
        for h, r, t in direct_edges:
            if r == rel:
                if h != exclude:
                    return t
        return None

    for ent, rel_groups in evidence_map.items():
        for rel_path in rel_groups:

            # If path has only 1 rel, and direct match exists → already captured
            if len(rel_path) == 1:
                r = rel_path[0]
                inv = r.startswith("~")
                rel = r[1:] if inv else r

                # direct already handled → skip
                continue

            # Multi-hop path → must use unknowns
            curr = ent
            for i, r in enumerate(rel_path):
                inv = r.startswith("~")
                rel = r[1:] if inv else r

                # last hop? try direct match
                if i == len(rel_path) - 1:
                    target = find_direct_target(rel, curr)
                    if target is None:
                        target = _new_unknown(unknown_list)
                else:
                    # internal hop → always unknown
                    target = _new_unknown(unknown_list)

                # build triplet
                if inv:
                    triplets.append((target, rel, curr))
                else:
                    triplets.append((curr, rel, target))

                curr = target

    # --------------------------------------------------------------
    # STEP 3 — Combine direct + multi-hop, deduplicate
    # --------------------------------------------------------------
    triplets.extend(direct_edges)
    final = list(dict.fromkeys(triplets))

    sample["triplet"] = final
    return sample


# --------------------------------------------------------------
def linearize(triplets: List[Tuple[str, str, str]]) -> str:
    return "\n".join(f"<e>{h}</e> || {r} || <e>{t}</e>" for h, r, t in triplets)



def process_data(data: dict, remove_underscore: bool = True) -> Tuple[Dict, List]:
    from tqdm import tqdm

    updated_data = {}
    distinct_entities = set()
    keys = list(data.keys())

    for key in tqdm(keys, desc="Processing data"):
        updated = generate(key, data[key])
        updated_data[key] = updated

        for h, r, t in updated["triplet"]:
            distinct_entities.add(h)
            distinct_entities.add(t)

    return updated_data, list(distinct_entities)


In [4]:
DATA_DIR = 'resources'
# Data dir = (1) working directory, (2) move out of test, (3) move out of src, and append to resources
DATA_DIR = os.path.join(os.getcwd(), '..', 'resources')
print("Data Directory:", DATA_DIR)

TRAIN_FILE = 'factkg_train_5k.pickle'
TEST_FILE = 'factkg_test_1k.pickle'
VALID_FILE = 'factkg_val_300.pickle'

TRAIN_FILE_PATH = os.path.join(DATA_DIR, TRAIN_FILE)
TEST_FILE_PATH = os.path.join(DATA_DIR, TEST_FILE)
VALID_FILE_PATH = os.path.join(DATA_DIR, VALID_FILE)

import pickle

train_data = None
test_data = None
valid_data = None
with open(TRAIN_FILE_PATH, 'rb') as f:
    train_data = pickle.load(f)
with open(TEST_FILE_PATH, 'rb') as f:
    test_data = pickle.load(f)
with open(VALID_FILE_PATH, 'rb') as f:
    valid_data = pickle.load(f)

train_updated_data, train_distinct_entities = process_data(train_data, remove_underscore=True)

test_updated_data, test_distinct_entities = process_data(test_data, remove_underscore=True)

valid_updated_data, valid_distinct_entities = process_data(valid_data, remove_underscore=True)

Data Directory: d:\claimpkg\claimpkg-clone\src\notebooks\..\resources


Processing data: 100%|██████████| 5000/5000 [00:00<00:00, 95791.42it/s]
Processing data: 100%|██████████| 1000/1000 [00:00<00:00, 93111.57it/s]
Processing data: 100%|██████████| 300/300 [00:00<00:00, 54337.40it/s]


In [5]:
# Concat the 3 set train, test, valid
concat_data = {}
concat_data.update(train_updated_data)
concat_data.update(test_updated_data)
concat_data.update(valid_updated_data)

print("Total number of items in concatenated data:", len(concat_data))

key_list = list(concat_data.keys())

Total number of items in concatenated data: 6300


In [6]:
# Get all types items from train, test, and valid datasets
types = set()
for data in [train_updated_data, test_updated_data, valid_updated_data]:
    for item in data:
        for element in data[item]['types']:
            types.add(element)
types = list(types)
print("Distinct types found:", len(types))
types

Distinct types found: 13


['coll:presup',
 'coll:model',
 'question',
 'existence',
 'num4',
 'written',
 'num3',
 'num2',
 'multi hop',
 'negation',
 'substitution',
 'num1',
 'multi claim']

# Explain

| Type           | Category            | Ý nghĩa                                |
| -------------- | ------------------- | -------------------------------------- |
| `written`      | claim style         | Văn phong tự nhiên                     |
| `coll:model`   | claim style         | Văn nói do model sinh                  |
| `coll:presup`  | claim style         | Dạng câu hỏi giả định (presupposition) |
| `num1`         | reasoning           | One-hop                                |
| `multi claim`  | reasoning           | Contains multiple facts                |
| `existence`    | reasoning           | Hỏi về sự tồn tại                      |
| `multi hop`    | reasoning           | Multi-hop reasoning                    |
| `negation`     | reasoning           | Phủ định                               |
| `num2`         | reasoning (complex) | Multi relation 2 chiều                 |
| `num3`         | reasoning (complex) | Multi relation 3 chiều                 |
| `substitution` | generation          | Claim tạo bằng thay thế thông tin      |


So, these won't work or need to concern.
(1) existence
(2) num[```i```]
(3) multi-hop

The next step is to find number of rows to concern

In [7]:
is_concern = []
# Get all types items from train, test, and valid datasets
for data in [train_updated_data, test_updated_data, valid_updated_data]:
    for item in data:
        if 'multi-hop' in data[item]['types']:
            is_concern.append(True)
            continue
        if 'num2' in data[item]['types']:
            is_concern.append(True)
            continue
        if 'num3' in data[item]['types']:
            is_concern.append(True)
            continue
        if 'num4' in data[item]['types']:
            is_concern.append(True)
            continue
        is_concern.append(False)

print("Number of concerned types:", sum(is_concern))
print("Total types checked:", len(is_concern))

Number of concerned types: 4066
Total types checked: 6300


In [8]:
INDEX = 20
print(f"Concern state: {is_concern[INDEX]}")
print(f"Key: {key_list[INDEX]}")
print(f"DATA: {concat_data[key_list[INDEX]]}")

Concern state: False
Key: It was Micol Fontana who did not have an award.
DATA: {'Label': [True], 'Entity_set': ['Micol_Fontana'], 'Evidence': {'Micol_Fontana': [['award']]}, 'types': ['coll:model', 'negation', 'existence'], 'triplet': []}


# Analyzing the difficult level of multi-hop or num-i

In [9]:
import re

PRONOUNS = {"he", "she", "they", "them", "his", "her", "their",
            "its", "it", "this artist", "the artist", "the city",
            "the governor", "the musician"}

def contains_pronoun(claim: str):
    text = claim.lower()
    return any(p in text for p in PRONOUNS)

def classify_multihop_complexity(claim: str, sample: dict):
    """
    Classify multi-hop difficulty into: easy, medium, hard
    """
    types = sample["types"]
    evidence = sample.get("Evidence")

    # Not multi-hop => always easy
    if "multi hop" not in types:
        return "easy"

    # Count relations per entity
    rel_counts = {ent: len(paths) for ent, paths in evidence.items()}

    # Count relation-types inside each hop
    hop_complexity = []
    for ent, paths in evidence.items():
        for hop in paths:
            hop_complexity.append(len(hop))

    max_rels_per_entity = max(rel_counts.values()) if rel_counts else 0
    max_hop_width = max(hop_complexity) if hop_complexity else 0

    # Number of entities
    num_entities = len(evidence)

    # Detect cycles -- if entity A has ~r followed by B has r again
    def has_inverse_cycles(evidence):
        inverse_pairs = {}
        for ent, paths in evidence.items():
            for hop in paths:
                for rel in hop:
                    if rel.startswith("~"):
                        inverse_pairs.setdefault(ent, []).append(rel[1:])
        # If multiple entities share inverse forms => likely cycle
        inverse_map = {}
        for ent, relations in inverse_pairs.items():
            for r in relations:
                inverse_map.setdefault(r, []).append(ent)
        return any(len(v) > 1 for v in inverse_map.values())

    # Rule: Easy cases
    if max_rels_per_entity == 1 and max_hop_width == 1:
        # no ambiguity, no complex inverse chains
        return "easy"

    # Rule: Hard cases (requiring GPT)
    # 1. Implicit subject/object → pronoun
    if contains_pronoun(claim):
        return "hard"

    # 2. Entity has > 2 relation paths → ambiguous multi-hop (num2, num3)
    if max_rels_per_entity >= 3:
        return "hard"

    # 3. Any hop uses >= 3 relations (multi-path)
    if max_hop_width >= 3:
        return "hard"

    # 4. Inverse cycles or loops
    if has_inverse_cycles(evidence):
        return "hard"

    # 5. Too many entities in multi-hop (structure ambiguous)
    if num_entities >= 4:
        return "hard"

    # If not easy, not hard → medium
    return "medium"

complexity_counts = {"easy": [], "medium": [], "hard": []}

for i, key in enumerate(concat_data):
    if is_concern[i] == False:
        continue

    item = concat_data[key]
    difficulty = classify_multihop_complexity(key, item)
    complexity_counts[difficulty].append(key)

In [10]:
# Count and show percentages
total_counts = sum(len(v) for v in complexity_counts.values())
for level, items in complexity_counts.items():
    count = len(items)
    percentage = (count / total_counts) * 100 if total_counts > 0 else 0
    print(f"{level.capitalize()}: {count} items ({percentage:.2f}%)")

Easy: 2315 items (56.94%)
Medium: 113 items (2.78%)
Hard: 1638 items (40.29%)


# Inspecting each classification results

In [11]:
INDEX = 3
key = complexity_counts['medium'][INDEX]
print(f"Key: {key}")
print(f"DATA: {concat_data[key]}")

Key: Born on April 27, 1937 and died on December 9th,1991.
DATA: {'Label': [True], 'Entity_set': ['"1991-12-09"', '"1937-04-27"'], 'Evidence': {'"1937-04-27"': [['~birthDate', 'deathDate']], '"1991-12-09"': [['~deathDate', 'birthDate']]}, 'types': ['coll:model', 'num2', 'multi hop'], 'triplet': [('unknown_0', 'birthDate', '"1937-04-27"'), ('unknown_0', 'deathDate', '"1991-12-09"'), ('unknown_1', 'deathDate', '"1991-12-09"'), ('unknown_1', 'birthDate', '"1937-04-27"'), ('"1991-12-09"', 'birthDate', '"1937-04-27"'), ('"1937-04-27"', 'deathDate', '"1991-12-09"')]}


In [None]:
import re
from typing import Dict, Any, List, Tuple

_UNKNOWN_RE = re.compile(r"^unknown[\s_\-]?(\d+)$", flags=re.I)

def _is_unknown_token(s: str):
    m = _UNKNOWN_RE.match(str(s))
    return bool(m)




def _unknown_index(s: str):
    m = _UNKNOWN_RE.match(str(s))
    return int(m.group(1)) if m else None

def _find_representative_unknown(triplets: List[Tuple[str,str,str]], keep_idx: int):
    """
    Find an existing string in triplets that corresponds to unknown with index keep_idx,
    preserving original formatting (underscore/space/hyphen).
    If not found, return canonical 'unknown_0' style.
    """
    for x,y,z in triplets:
        for token in (x,y,z):
            if _is_unknown_token(token) and _unknown_index(token) == keep_idx:
                return token
    return f"unknown_{keep_idx}"

def apply_fix_medium(sample: Dict[str, Any],
                     is_eliminate_reverse: bool = True,
                     remove_underscore: bool = True) -> Dict[str, Any]:
    """
    Fix medium-case triplets in-place (updates sample["triplet"]).
    - Keeps only unknown_0 related triplets (and canonicalizes unknown token)
    - Removes unknown_i (i>0)
    - Removes direct entity->entity edges when unknown_0 exists
    - If is_eliminate_reverse=True, removes entity->rel->unknown_0 triples (reverse)
    """

    triplets = list(sample.get("triplet", []))
    evidence = sample.get("Evidence", {})

    # Normalize function for comparison if needed
    def norm_name(s: str) -> str:
        if not isinstance(s, str):
            return s
        if remove_underscore:
            return s.replace("_", " ")
        return s

    # 1) detect any unknown indices present
    used_idxs = set()
    for h,r,t in triplets:
        if _is_unknown_token(h):
            idx = _unknown_index(h)
            if idx is not None:
                used_idxs.add(idx)
        if _is_unknown_token(t):
            idx = _unknown_index(t)
            if idx is not None:
                used_idxs.add(idx)

    if not used_idxs:
        # Nothing to do: no unknown tokens present. Just return sample unchanged.
        return sample

    # choose keep index = smallest used index (prefer unknown_0 if present)
    keep_idx = min(used_idxs)

    # find representative string for keep unknown (to preserve formatting)
    rep_unknown = _find_representative_unknown(triplets, keep_idx)

    # Build set of entities that appear ONLY as ~relations in evidence (tail-only)
    ev_map = { norm_name(k): v for k,v in evidence.items() }
    tail_only_entities = set()
    for ent, groups in ev_map.items():
        has_fwd = False
        has_inv = False
        for g in groups:
            for rel in g:
                if isinstance(rel, str) and rel.startswith("~"):
                    has_inv = True
                else:
                    has_fwd = True
        if has_inv and not has_fwd:
            tail_only_entities.add(ent)

    # 2) Filter triplets:
    cleaned: List[Tuple[str,str,str]] = []
    for h, r, t in triplets:
        h_s = str(h)
        t_s = str(t)

        # A) Remove any unknown token whose index != keep_idx
        if _is_unknown_token(h_s):
            if _unknown_index(h_s) != keep_idx:
                continue
        if _is_unknown_token(t_s):
            if _unknown_index(t_s) != keep_idx:
                continue

        # B) If both head and tail are non-unknown (direct entity->entity), drop it
        if (not _is_unknown_token(h_s)) and (not _is_unknown_token(t_s)):
            # If unknown_0 exists anywhere, we drop direct edges to avoid bypass.
            # We enforce dropping direct edges here (Option A).
            continue

        # C) If eliminate_reverse and triplet is entity -> rel -> unknown_0, drop it
        #     (i.e., head is non-unknown, tail is the keep unknown rep)
        if is_eliminate_reverse:
            if (not _is_unknown_token(h_s)) and (_is_unknown_token(t_s) and _unknown_index(t_s)==keep_idx):
                continue

        # D) Normalize the keep unknown token to the representative string
        #    so final output uses consistent unknown token (rep_unknown).
        h_out = h_s
        t_out = t_s
        if _is_unknown_token(h_out) and _unknown_index(h_out) == keep_idx:
            h_out = rep_unknown
        if _is_unknown_token(t_out) and _unknown_index(t_out) == keep_idx:
            t_out = rep_unknown

        cleaned.append((h_out, r, t_out))

    # 3) Deduplicate preserving order
    final = []
    seen = set()
    for tpl in cleaned:
        if tpl not in seen:
            seen.add(tpl)
            final.append(tpl)

    # 4) If final is empty but evidence suggests we should keep a chain,
    #    attempt to construct minimal unknown_0 chain from evidence:
    #    (This is a safe fallback, but typically not needed.)
    if not final:
        # Build simple chain: find an entity that has forward rels and match partner with ~rel
        # We'll create rep_unknown -> rel -> partner for such cases
        # This fallback is conservative.
        for ent, groups in ev_map.items():
            for rel_group in groups:
                # consider only multi-hop/one-hop forward relations
                for rel in rel_group:
                    inv = isinstance(rel,str) and rel.startswith("~")
                    rel_clean = rel[1:] if inv else rel
                    if not inv:
                        # find partner that has ~rel
                        partner = None
                        for e2, g2 in ev_map.items():
                            if e2 == ent: continue
                            for g2part in g2:
                                if f"~{rel_clean}" in g2part:
                                    partner = e2
                                    break
                            if partner: break
                        if partner:
                            final.append((ent, rel_clean, rep_unknown))
                            final.append((rep_unknown, "ethnicGroup" if "ethnic" in rel_clean else rel_clean, partner))
                            break
            if final:
                break

    sample["triplet"] = final
    return sample



# def apply_fix_medium(sample, is_eliminate_reverse=True, remove_underscore=True):
#     """
#     Medium-level fixing of triplets generated by generate_basic.
#     - Supports 'unknown_0' AND 'unknown 0'
#     - Removes reverse-direction triples
#     - Removes spurious unknown_1+
#     - Keeps only the correct multi-hop structure
#     """

#     triplets = sample["triplet"]
#     evidence = sample["Evidence"]

#     # Normalize entities
#     def norm(x):
#         if isinstance(x, str) and remove_underscore:
#             return x.replace("_", " ")
#         return x

#     # Identify all unknowns: matches "unknown_0", "unknown 0", "unknown-0"
#     UNKNOWN_RE = re.compile(r"^unknown[\s_]?(\d+)$")

#     def is_unknown(x):
#         return isinstance(x, str) and UNKNOWN_RE.match(x) is not None

#     def unknown_index(x):
#         m = UNKNOWN_RE.match(x)
#         return int(m.group(1)) if m else None

#     # Normalize evidence keys
#     ev = {norm(k): v for k, v in evidence.items()}

#     # ---------------------------------------------------
#     # STEP 1 — Find tail-only entities (only have ~rel)
#     # ---------------------------------------------------
#     tail_only_entities = set()
#     for ent, groups in ev.items():
#         has_forward = False
#         has_inverse = False
#         for g in groups:
#             for r in g:
#                 if r.startswith("~"):
#                     has_inverse = True
#                 else:
#                     has_forward = True
#         if is_eliminate_reverse and has_inverse and not has_forward:
#             tail_only_entities.add(norm(ent))

#     # ---------------------------------------------------
#     # STEP 2 — Identify all unknown indexes used
#     # ---------------------------------------------------
#     used_unknowns = set()
#     for h, r, t in triplets:
#         if is_unknown(h):
#             used_unknowns.add(unknown_index(h))
#         if is_unknown(t):
#             used_unknowns.add(unknown_index(t))

#     # Keep only the *lowest* unknown index (usually unknown_0)
#     keep_idx = min(used_unknowns) if used_unknowns else None

#     # ---------------------------------------------------
#     # STEP 3 — Filter triplets
#     # ---------------------------------------------------
#     cleaned = []

#     for h, r, t in triplets:

#         h_n = norm(h) if isinstance(h, str) else h
#         t_n = norm(t) if isinstance(t, str) else t

#         # A) Remove unknown_i where i > keep_idx
#         if is_unknown(h_n):
#             if unknown_index(h_n) != keep_idx:
#                 continue
#         if is_unknown(t_n):
#             if unknown_index(t_n) != keep_idx:
#                 continue

#         # B) Remove reverse-direction triples
#         if is_eliminate_reverse and h_n in tail_only_entities:
#             continue

#         # C) Remove direct Ahmet → African Americans that bypass unknown
#         if keep_idx is not None and t_n in tail_only_entities and not is_unknown(t_n):
#             # remove h → tail-only entity direct if unknown exists
#             if not is_unknown(h_n):
#                 continue

#         cleaned.append((h_n, r, t_n))

#     # ---------------------------------------------------
#     # STEP 4 — Deduplicate preserving order
#     # ---------------------------------------------------
#     final = []
#     seen = set()
#     for tpl in cleaned:
#         if tpl not in seen:
#             seen.add(tpl)
#             final.append(tpl)

#     sample["triplet"] = final
#     return sample


# Test the fix on a medium complexity example
INDEX = 2
key = complexity_counts['medium'][INDEX]
print(f"Before Fix - Key: {key}")
print(f"DATA: {concat_data[key]}")

after_fix_sample = apply_fix_medium(concat_data[key], is_eliminate_reverse=False)
after_fix = after_fix_sample['triplet']
print(f"After Fix - Triplets: {after_fix}")

Before Fix - Key: Ahmet Ertegun is from a country that has an ethnic group called African Americans.
DATA: {'Label': [True], 'Entity_set': ['Ahmet_Ertegun', 'African_Americans'], 'Evidence': {'Ahmet_Ertegun': [['hometown', 'ethnicGroup']], 'African_Americans': [['~ethnicGroup', '~hometown']]}, 'types': ['coll:model', 'num2', 'multi hop'], 'triplet': [('Ahmet Ertegun', 'hometown', 'unknown_0'), ('unknown_0', 'ethnicGroup', 'African Americans'), ('unknown_1', 'ethnicGroup', 'African Americans'), ('African Americans', 'hometown', 'unknown_1'), ('Ahmet Ertegun', 'hometown', 'African Americans'), ('Ahmet Ertegun', 'ethnicGroup', 'African Americans')]}
After Fix - Triplets: [('Ahmet Ertegun', 'hometown', 'unknown_0'), ('unknown_0', 'ethnicGroup', 'African Americans')]
