### Summary of results stored in JSONL format

This analysis can also be performed at the end of the notebook. We repeat it here because it is more convenient when you run the Python script (RAG_full.py) or had to split the computation across several sessions. All you need to run it are the results in a JSONL file, the id->label dictionaries, and the normalized-label->QIDs dictionary.

At the end we also include code that reconstructs retrieval recall@k in case it was not saved directly in the results. That part additionally needs the edge-list CSV, from which the retrieval indexes are rebuilt.

In [4]:
import json
import re

In [9]:
# Normalized label -> QID list, used to map predicted tail labels back to
# entity ids. Read inside ``with`` so the handle is closed right away, and
# as UTF-8 (the JSON standard encoding) rather than the platform default.
with open("norm_label_to_qids.json", encoding="utf-8") as fh:
    norm_label_to_qids = json.load(fh)

def normalize_label(s: str) -> str:
    """Canonicalize a label for dictionary lookup.

    Falsy input (``None``, ``""``) gives ``""``. Otherwise the label is
    trimmed and lowercased, whitespace runs become a single space, and one
    trailing period is dropped.
    """
    if not s:
        return ""
    collapsed = re.sub(r"\s+", " ", s.strip().lower())
    # Only a single final "." is removed ("a.." -> "a.").
    return collapsed[:-1] if collapsed.endswith(".") else collapsed

def tail_label_to_qids(tail_label: str):
    """Return the QIDs stored under the normalized form of ``tail_label``.

    Looks ``normalize_label(tail_label)`` up in ``norm_label_to_qids``;
    labels not present there yield an empty list.
    """
    return norm_label_to_qids.get(normalize_label(tail_label), [])

# id -> label maps for entities (QIDs) and relations (PIDs), used to
# verbalize retrieved facts. Files are closed immediately and decoded as
# UTF-8 instead of the platform default encoding.
with open("qid_to_label.json", encoding="utf-8") as fh:
    qid_to_label = json.load(fh)
with open("pid_to_label.json", encoding="utf-8") as fh:
    pid_to_label = json.load(fh)

In [5]:
def summarize_jsonl(out_jsonl):
    """Compute hits@1 / hits@3 for predictions stored in a JSONL results file.

    Each record must contain ``gold_tail`` (compared as ``str``) and may
    contain ``pred_labels``, a ranked list of predicted tail labels. Labels
    are mapped to QIDs with ``tail_label_to_qids``. A record hits@1 when
    ``gold_tail`` is among the QIDs of the first label, and hits@3 when it is
    among the QIDs of any of the first three labels. One normalized label can
    map to several QIDs, so any of them matching counts as a hit.

    Records whose ``pred_labels`` is missing or ``None`` are counted in
    ``parse_fail``. Records with an empty list are counted in ``null``. Both
    stay in ``total`` and count as misses. Blank lines are skipped.

    Args:
        out_jsonl: Path to the JSONL results file.

    Returns:
        dict with ``hits@1`` and ``hits@3`` (fractions of ``total``; 0.0 for
        a file without records), plus ``total``, ``parse_fail`` and ``null``.
    """
    h1 = 0
    h3 = 0
    total = 0
    parse_fail = 0
    null_count = 0

    with open(out_jsonl, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # e.g. an extra newline at the end of the file
            rec = json.loads(line)
            gold_tail = str(rec["gold_tail"])
            pred_labels = rec.get("pred_labels", None)

            total += 1

            if pred_labels is None:
                parse_fail += 1
                continue

            if len(pred_labels) == 0:
                null_count += 1
                continue

            pred_qids_1 = tail_label_to_qids(pred_labels[0])
            pred_qids_3 = set()
            for lab in pred_labels[:3]:
                pred_qids_3.update(tail_label_to_qids(lab))

            if gold_tail in pred_qids_1:
                h1 += 1
            if gold_tail in pred_qids_3:
                h3 += 1

    return {
        "hits@1": h1 / max(total, 1),
        "hits@3": h3 / max(total, 1),
        "total": total,
        "parse_fail": parse_fail,
        "null": null_count,
    }

In [6]:
# Hits@1 / hits@3 for the "eval_k24" results file (presumably retrieval k=24,
# per the file name). The bare name on the last line displays the dict.
summary = summarize_jsonl("eval_k24.jsonl")
summary

{'hits@1': 0.6219026071967249,
 'hits@3': 0.6866515837104072,
 'total': 18564,
 'parse_fail': 0,
 'null': 1441}

In [7]:
import pandas as pd

# Load the temporal edge list with an integer "ts" column and split it
# chronologically: facts before 2008 form train_df (the corpus the retrieval
# indexes are built from), facts from 2008 onward form test_df.
df = pd.read_csv("tkgl-smallpedia_edgelist.csv").astype({"ts": int})

train_df, test_df = (
    df.loc[df["ts"] < 2008].copy(),
    df.loc[df["ts"] >= 2008].copy(),
)

print(len(train_df), len(test_df))

468790 81586


In [8]:
from collections import defaultdict

def build_hr_index(df):
    """Index facts by (head, relation).

    Args:
        df: DataFrame with ``head``, ``relation_type``, ``tail`` and ``ts``
            columns.

    Returns:
        defaultdict mapping ``(str(head), str(relation))`` to a list of
        ``(int(ts), str(tail))`` tuples sorted by ascending ``ts``. The sort
        is stable, so equal timestamps keep row order.
    """
    idx = defaultdict(list)
    # Zip the columns instead of DataFrame.iterrows(): iterrows builds a
    # Series per row, which is slow on large frames and does not preserve
    # column dtypes.
    for h, r, t, ts in zip(df["head"], df["relation_type"], df["tail"], df["ts"]):
        idx[(str(h), str(r))].append((int(ts), str(t)))
    # sort by ts
    for facts in idx.values():
        facts.sort(key=lambda x: x[0])
    return idx

# (head, relation) -> [(ts, tail), ...] built from train_df (facts with ts < 2008).
hr_index = build_hr_index(train_df)

def build_head_index(df):
    """Index facts by head entity.

    Args:
        df: DataFrame with ``head``, ``relation_type``, ``tail`` and ``ts``
            columns.

    Returns:
        defaultdict mapping ``str(head)`` to a list of
        ``(int(ts), str(relation), str(tail))`` tuples sorted by ascending
        ``ts`` (stable for equal timestamps).
    """
    idx = defaultdict(list)
    # Column zip instead of iterrows(): no per-row Series, dtypes preserved.
    for h, r, t, ts in zip(df["head"], df["relation_type"], df["tail"], df["ts"]):
        idx[str(h)].append((int(ts), str(r), str(t)))
    # sort by ts
    for facts in idx.values():
        facts.sort(key=lambda x: x[0])
    return idx

# head -> [(ts, relation, tail), ...] built from train_df.
head_index = build_head_index(train_df)

def build_entity_index(df):
    """Index facts by every entity they mention (as head or as tail).

    Args:
        df: DataFrame with ``head``, ``relation_type``, ``tail`` and ``ts``
            columns.

    Returns:
        defaultdict mapping an entity id (str) to a list of
        ``(int(ts), head, relation, tail)`` tuples, all str except ``ts``,
        sorted by ascending ``ts`` (stable for equal timestamps).

    Note:
        A self-loop fact (head == tail) is appended twice under that entity.
        This is kept as is because downstream retrieval results depend on it.
    """
    idx = defaultdict(list)
    # Column zip instead of iterrows(): no per-row Series, dtypes preserved.
    for h, r, t, ts in zip(df["head"], df["relation_type"], df["tail"], df["ts"]):
        fact = (int(ts), str(h), str(r), str(t))
        idx[fact[1]].append(fact)
        idx[fact[3]].append(fact)

    for facts in idx.values():
        facts.sort(key=lambda x: x[0])
    return idx

# entity -> [(ts, head, relation, tail), ...] for train_df facts where the entity is head or tail.
entity_index = build_entity_index(train_df)

def build_rel_index(df):
    """Index facts by relation.

    Args:
        df: DataFrame with ``head``, ``relation_type``, ``tail`` and ``ts``
            columns.

    Returns:
        defaultdict mapping ``str(relation)`` to a list of
        ``(int(ts), str(head), str(tail))`` tuples sorted by ascending ``ts``
        (stable for equal timestamps).
    """
    idx = defaultdict(list)
    # Column zip instead of iterrows(): no per-row Series, dtypes preserved.
    for ts, h, r, t in zip(df["ts"], df["head"], df["relation_type"], df["tail"]):
        idx[str(r)].append((int(ts), str(h), str(t)))
    for facts in idx.values():
        facts.sort(key=lambda x: x[0])
    return idx

# relation -> [(ts, head, tail), ...] built from train_df.
rel_index = build_rel_index(train_df)

In [10]:
def retrieve_facts_new(head_id, rel_id, ts, k=12, *, dedupe=False):
    """Retrieve up to ``k`` training facts for a (head, relation, ts) query as text.

    Facts are gathered in four tiers. Each tier is sorted by temporal
    distance ``abs(fact_ts - ts)``; the sort is stable, so ties keep the
    index's ascending-ts order:

    1. exact ``(head_id, rel_id)`` matches from ``hr_index``;
    2. any fact with ``head_id`` as head, from ``head_index``;
    3. any fact mentioning ``head_id`` as head or tail, from ``entity_index``;
    4. facts with relation ``rel_id`` and any head, from ``rel_index``.

    Tiers 2-4 are consulted only while fewer than ``k`` facts are collected.
    Tiers 1-2 are added in full, tiers 3-4 only up to the missing count, and
    the result is truncated to ``k``.

    Args:
        head_id: Head entity id, in the same string form as the index keys.
        rel_id: Relation id.
        ts: Query timestamp. It must support subtraction with the indexed
            (int) timestamps.
        k: Maximum number of facts returned.
        dedupe: If True, a fact already collected by an earlier tier is
            skipped, so all returned facts are distinct. The default False
            keeps the original behavior, in which tiers overlap: tier 2
            contains tier 1's facts and tier 3 contains tier 2's, so one fact
            can fill several of the ``k`` slots. Keep the default when
            reconstructing results of the original retrieval.

    Returns:
        List of at most ``k`` strings ``"In {ts}, {head} {relation} {tail}."``.
        Labels come from ``qid_to_label`` / ``pid_to_label``, falling back to
        the raw id when no label is known.
    """
    out = []  # (ts, head, relation, tail) tuples
    seen = set()

    def by_distance(facts):
        # Stable sort by temporal distance to the query timestamp.
        return sorted(facts, key=lambda x: abs(x[0] - ts))

    def extend(facts, limit=None):
        # Append facts to ``out`` (skipping repeats when dedupe is on),
        # stopping after ``limit`` appends when a limit is given.
        added = 0
        for fact in facts:
            if limit is not None and added >= limit:
                break
            if dedupe:
                if fact in seen:
                    continue
                seen.add(fact)
            out.append(fact)
            added += 1

    # 1) exact match: hr_index[(head, rel)] -> [(ts, tail), ...]
    facts_hr = hr_index.get((head_id, rel_id), [])
    extend((y, head_id, rel_id, tail_id) for y, tail_id in by_distance(facts_hr))

    # 2) same head, any relation: head_index[head] -> [(ts, rel, tail), ...]
    if len(out) < k:
        facts_h = head_index.get(head_id, [])
        extend((y, head_id, r, tail_id) for y, r, tail_id in by_distance(facts_h))

    # 3) entity as head OR tail: entity_index[e] -> [(ts, h, r, t), ...];
    #    add only the number that is missing
    if len(out) < k:
        facts_e = entity_index.get(head_id, [])
        extend(by_distance(facts_e), limit=k - len(out))

    # 4) relation-only examples (global): rel_index[rel] -> [(ts, h, t), ...]
    if len(out) < k:
        facts_r = rel_index.get(rel_id, [])
        extend(
            ((y, h, rel_id, t) for y, h, t in by_distance(facts_r)),
            limit=k - len(out),
        )

    # verbalize the first k facts
    facts_txt = []
    for y, h, r, t in out[:k]:
        h_label = qid_to_label.get(h, h)
        r_label = pid_to_label.get(r, r)
        t_label = qid_to_label.get(t, t)
        facts_txt.append(f"In {y}, {h_label} {r_label} {t_label}.")

    return facts_txt


In [11]:
def extract_tail_candidates(retrieved_facts: list[str], head_label: str, rel_label: str, *, anchored: bool = False):
    """Extract unique tail labels for a given head/relation from verbalized facts.

    Facts are expected in the format::

      In YEAR, HEAD REL TAIL.

    A fact contributes a candidate when it contains the key
    ``f"{head_label} {rel_label} "``. The candidate is the text after the
    key, with surrounding whitespace and one trailing period removed. Empty
    results are dropped. Duplicates are removed, keeping first-seen order.

    Args:
        retrieved_facts: Verbalized facts.
        head_label: Label of the query head entity.
        rel_label: Label of the query relation.
        anchored: If False (default, the original behavior), the key may
            appear anywhere in a fact. This also matches facts whose head
            label merely ends with ``head_label``: for head "Paris",
            "In 2004, New Paris twinned administrative body Berlin." yields
            "Berlin". If True, the key must immediately follow the
            ``"In YEAR, "`` prefix (YEAR an optionally negative integer), so
            only facts whose head label equals ``head_label`` are used. The
            default is kept so that candidate reconstruction is unchanged.

    Note:
        Matching is purely textual. When ``rel_label`` is a prefix of another
        relation label (e.g. "country" vs "country of citizenship"), facts of
        the longer relation also match and yield tails such as
        "of citizenship Germany". Neither mode can disambiguate this.

    Returns:
        List of unique, non-empty tail label strings in first-seen order.
    """
    key = f"{head_label} {rel_label} "
    prefix = re.compile(r"In -?\d+, " + re.escape(key)) if anchored else None

    candidates = []
    for fact in retrieved_facts:
        if prefix is not None:
            m = prefix.match(fact)
            if m is None:
                continue
            tail = fact[m.end():]
        elif key in fact:
            tail = fact.split(key, 1)[1]
        else:
            continue

        # strip, then delete one dot at the end of string ("TAIL." -> "TAIL")
        tail = re.sub(r"\.\s*$", "", tail.strip()).strip()

        if tail:
            candidates.append(tail)

    # dict keys keep insertion order: unique candidates, first occurrence wins
    return list(dict.fromkeys(candidates))


In [12]:
# Reconstruction of retrieval candidates and recall@k (for runs that did not store it).

def recall_at_k_from_jsonl(jsonl_path, k=24, max_lines=None):
    """Reconstruct retrieval recall@k for the queries in a results JSONL file.

    For each record, retrieval is re-run with ``retrieve_facts_new`` (``k``
    facts) and candidates are extracted with ``extract_tail_candidates``.
    Candidates are mapped to QIDs through ``tail_label_to_qids``. A query
    is a hit when ``gold_tail`` is among those QIDs.

    Args:
        jsonl_path: Path to a JSONL file whose records have ``ts``, ``head``,
            ``rel`` and ``gold_tail`` fields.
        k: Number of facts retrieved per query. It should equal the k used to
            produce the file for the reconstruction to be faithful.
        max_lines: If not None, only the first ``max_lines`` lines of the file
            are read (blank lines included in the count); ``0`` reads nothing.

    Returns:
        dict with ``recall@k_candidates`` (hits / total, 0.0 when no record
        was read), ``total`` (records read) and ``no_candidates`` (records
        for which no candidate was extracted; they count as misses).
    """
    total = 0
    hit = 0
    no_candidates = 0

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for j, line in enumerate(f):
            # Compare against None so max_lines=0 reads nothing instead of
            # silently meaning "no limit".
            if max_lines is not None and j >= max_lines:
                break
            if not line.strip():
                continue  # tolerate blank lines, e.g. an extra trailing newline
            rec = json.loads(line)

            ts = int(rec["ts"])
            head_id = str(rec["head"])
            rel_id = str(rec["rel"])
            gold_tail = str(rec["gold_tail"])

            # Re-run retrieval and candidate extraction for this query.
            head_label = qid_to_label.get(head_id, head_id)
            rel_label = pid_to_label.get(rel_id, rel_id)

            retrieved_facts = retrieve_facts_new(head_id, rel_id, ts, k=k)
            candidates = extract_tail_candidates(retrieved_facts, head_label, rel_label)

            total += 1
            if not candidates:
                no_candidates += 1
                continue

            # A label may map to several QIDs; any of them matching is a hit.
            cand_qids = set()
            for lab in candidates:
                cand_qids.update(tail_label_to_qids(lab))

            if gold_tail in cand_qids:
                hit += 1

    return {
        "recall@k_candidates": hit / max(total, 1),
        "total": total,
        "no_candidates": no_candidates,
    }


In [14]:
# Reconstruct retrieval recall for the same results file (default k=24).
recall = recall_at_k_from_jsonl("eval_k24.jsonl")

In [15]:
# Display the reconstructed recall dict.
recall

{'recall@k_candidates': 0.7106765783236372,
 'total': 18564,
 'no_candidates': 1441}

In [16]:
# End-to-end hits@k divided by retrieval recall@k: a rough accuracy on
# queries whose gold tail was retrievable. It is not an exact conditional
# rate, because summarize_jsonl scores pred_labels without checking them
# against the retrieved candidates. The ratio is therefore not bounded by 1.
h1, h3 = (summary[key] / recall['recall@k_candidates'] for key in ('hits@1', 'hits@3'))
print(h1, h3)

0.8750852724929887 0.9661941938906996


In [17]:
# Same hits@k summary for the "hybrid_k24" results file (overwrites `summary`
# from the eval cell above); the bare name displays the dict.
summary = summarize_jsonl("hybrid_k24.jsonl")
summary

{'hits@1': 0.31071309378128825,
 'hits@3': 0.3702302814551118,
 'total': 8989,
 'parse_fail': 6,
 'null': 0}