## This is the main notebook of the LLM + RAG part of the project

It should be possible to use it for both the filtered and full dataset as well for all major RAG versions with only small adjustments in the code (e.g. setting the arguments of eval_open_world_hits_1_3 function)

## 0. Setup

In [1]:
import socket, os, sys, warnings
print("hostname:", socket.gethostname())
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print(sys.executable)
warnings.filterwarnings("ignore")

hostname: node18.enst.fr
CUDA_VISIBLE_DEVICES: 0
/home/infres/mporwisz-25/miniconda3/envs/tkgc_rag/bin/python


In [2]:
import torch

from transformers import (
    BitsAndBytesConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)

# LLM: https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507
llm_name = "Qwen/Qwen3-4B-Instruct-2507"

# We want to use 4bit quantization to save memory (in case some of you use their own computer)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=False, load_in_4bit=True
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(llm_name, padding_side="left")

# Prevent some transformers specific issues.
tokenizer.use_default_system_prompt = False
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load LLM.
llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    quantization_config=quantization_config,
    device_map={"": 0}, # load all the model layers on GPU 0
    torch_dtype=torch.bfloat16, # float precision
)

# Set LLM on eval mode.
llm.eval()

`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|███████████████████████| 398/398 [00:46<00:00,  8.49it/s, Materializing param=model.norm.weight]


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2560)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear4bit(in_features=2560, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=2560, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=2560, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=2560, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear4bit(in_features=2560, out_features=9728, bias=False)
          (up_proj): Linear4bit(in_features=2560, out_features=9728, bias=False)
          (down_proj): Linear4bit(in_features=9728, out_features=2560, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
 

In [3]:
from transformers import set_seed
import random
import numpy as np
import re

SEED = 42
set_seed(SEED)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
import json

qid_to_label = json.load(open("qid_to_label.json"))
pid_to_label = json.load(open("pid_to_label.json"))
norm_label_to_qids = json.load(open("norm_label_to_qids.json"))

In [5]:
# load and split dataset
import pandas as pd

df = pd.read_csv("tkgl-smallpedia_edgelist.csv")
df["ts"] = df["ts"].astype(int)

train_df = df[df["ts"] < 2008].copy()
test_df  = df[df["ts"] >= 2008].copy()

print(len(train_df), len(test_df))

468790 81586


In [None]:
# # for filtering data
# train_df_filter = df[df["ts"] < 1998].copy()
#
# train_entities = set(pd.concat([train_df_filter["head"], train_df_filter["tail"]]).astype(str).unique())
#
# test_df_filtered = test_df[
#     (test_df["head"].astype(str).isin(train_entities)) &
#     (test_df["tail"].astype(str).isin(train_entities))
# ].copy()
#
# test_df_filtered.drop_duplicates()
# len(test_df_filtered)

### 1. RAG

In [7]:
from collections import defaultdict

def build_hr_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        ts = int(row["ts"])
        idx[(h, r)].append((ts, t))
    # sort by ts so it's easy to retreive close entries
    for k in idx:
        idx[k].sort(key=lambda x: x[0])
    return idx

hr_index = build_hr_index(train_df)

In [8]:
def build_head_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        ts = int(row["ts"])
        idx[h].append((ts, r, t))
    # sort by ts so it's easy to retreive close entries
    for k in idx:
        idx[k].sort(key=lambda x: x[0])
    return idx

head_index = build_head_index(train_df)

In [9]:
def build_entity_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        ts = int(row["ts"])

        idx[h].append((ts, h, r, t))
        idx[t].append((ts, h, r, t))

    for e in idx:
        idx[e].sort(key=lambda x: x[0])
    return idx

entity_index = build_entity_index(train_df)


In [10]:
def build_rel_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        ts = int(row["ts"])
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        idx[r].append((ts, h, t))
    for r in idx:
        idx[r].sort(key=lambda x: x[0])
    return idx

rel_index = build_rel_index(train_df)

In [12]:
def retrieve_facts_new(head_id, rel_id, ts, k=12):
    out = []

    # 1) exact match: (head, rel) -> (ts, h, r, t)
    facts_hr = hr_index.get((head_id, rel_id), [])  # list[(ts, tail_id)]
    if facts_hr:
        ranked = sorted(facts_hr, key=lambda x: abs(x[0] - ts))
        out.extend([(y, head_id, rel_id, tail_id) for (y, tail_id) in ranked])

    # 2) fallback: (head, *) -> already (ts, r, t) OR (ts, r, tail_id)
    if len(out) < k:
        facts_h = head_index.get(head_id, [])  # list[(ts, r, tail_id)]
        if facts_h:
            ranked_h = sorted(facts_h, key=lambda x: abs(x[0] - ts))
            out.extend([(y, head_id, r, tail_id) for (y, r, tail_id) in ranked_h])

    # 3) fallback: facts about entity (as head OR tail)
    if len(out) < k:
        facts_e = entity_index.get(head_id, [])  # list[(ts, h, r, t)]
        if facts_e:
            ranked_e = sorted(facts_e, key=lambda x: abs(x[0] - ts))
            # dodaj tylko tyle ile brakuje
            need = k - len(out)
            out.extend(ranked_e[:need])

    # 4) relation-only examples (global)
    if len(out) < k:
        facts_r = rel_index.get(rel_id, [])  # [(ts, h, t)]
        ranked_r = sorted(facts_r, key=lambda x: abs(x[0] - ts))
        need = k - len(out)
        out.extend([(y, h, rel_id, t) for (y, h, t) in ranked_r[:need]])

    # filter to k
    out = out[:k]

    # change to text
    facts_txt = []
    for y, h, r, t in out:
        h_label = qid_to_label.get(h, h)
        r_label = pid_to_label.get(r, r)
        t_label = qid_to_label.get(t, t)
        facts_txt.append(f"In {y}, {h_label} {r_label} {t_label}.")

    return facts_txt

In [13]:
def normalize_label(s: str) -> str:
    s = (s or "").strip().lower()
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"\.$", "", s)
    return s

def tail_label_to_qids(tail_label: str):
    n = normalize_label(tail_label)
    return norm_label_to_qids.get(n, [])


In [16]:
# Define the function to call the qwen model
def qwen_llm_candidates(prompt):
    # 1. Instruction prompt
    system_prompt = (
        "You solve Temporal Knowledge Graph Completion (TKGC).\n"
        "You will be given a query (head, relation, timestamp), retrieved facts, and a candidate list.\n"
        "Your task is to select the most likely tail entity from the candidate list.\n\n"
        "Rules:\n"
        "1) Choose up to 3 candidates ONLY from the provided candidate list.\n"
        "2) Order them from best to worst.\n"
        "3) Return ONLY valid JSON in ONE line, and nothing else.\n"
        '4) Output format: {"tail_labels": ["...", "...", "..."]}\n'
        '   If no candidate matches, return {"tail_labels": []}.\n'
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {'role': 'user', 'content': prompt}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
        ).to(llm.device)

    with torch.inference_mode():
        outputs = llm.generate(**inputs, max_new_tokens=64, do_sample=False)

    raw = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
    return raw

In [17]:
SYSTEM_PROMPT = (
    "You solve Temporal Knowledge Graph Completion (open-world).\n"
    "Given a query (head, relation, timestamp) and retrieved facts, predict the most likely tail entity.\n"
    "The correct tail may NOT appear in the retrieved facts.\n"
    "Return up to 3 tail entity labels ordered best→worst.\n"
    "Return ONLY one-line valid JSON and nothing else.\n"
    'Format: {"tail_labels": ["...","...","..."]}\n'
    'If you cannot propose any, return {"tail_labels": []}.'
)


In [18]:
# Define the function to call the qwen model
def qwen_llm_open_top3(prompt):

    system_prompt = (
        "You solve Temporal Knowledge Graph Completion (open-world).\n"
        "Given a query (head, relation, timestamp) and retrieved facts, predict the most likely tail entity.\n"
        "The correct tail may NOT appear in the retrieved facts.\n"
        "Return up to 3 tail entity labels ordered best→worst.\n"
        "Return ONLY one-line valid JSON and nothing else.\n"
        'Format: {"tail_labels": ["...","...","..."]}\n'
        'If you cannot propose any, return {"tail_labels": []}.'
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {'role': 'user', 'content': prompt}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
        ).to(llm.device)

    with torch.inference_mode():
        outputs = llm.generate(**inputs, max_new_tokens=64, do_sample=False)

    raw = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
    return raw

In [19]:
def build_user_prompt_candidates_top3(head_label, rel_label, ts, retrieved_facts, candidates):
    facts_block = "\n".join([f"- {f}" for f in retrieved_facts])
    cand_block = "\n".join([f"{i+1}. {c}" for i, c in enumerate(candidates)])

    return (
        f"Query:\nTime: {ts}\nHead: {head_label}\nRelation: {rel_label}\n\n"
        f"Retrieved facts:\n{facts_block}\n\n"
        f"Candidates (pick up to 3):\n{cand_block}\n\n"
        "Rules:\n"
        "1) Choose up to 3 candidates ONLY from the candidate list.\n"
        "2) Order them best→worst.\n"
        '3) Return ONLY JSON: {"tail_labels": ["...","...","..."]}.\n'
        '4) If none match, return {"tail_labels": []}.'
    )

In [20]:
def rag_answer_candidates_top3(head_id: str, rel_id: str, ts: int, k=12):
    head_label = qid_to_label.get(head_id, head_id)
    rel_label  = pid_to_label.get(rel_id, rel_id)

    retrieved_facts = retrieve_facts_new(head_id, rel_id, ts, k=k)
    if not retrieved_facts:
        return [], [], '{"tail_labels": []}'

    candidates = extract_tail_candidates(retrieved_facts, head_label, rel_label)
    if not candidates:
        return [], retrieved_facts, '{"tail_labels": []}'

    user_prompt = build_user_prompt_candidates_top3(head_label, rel_label, ts, retrieved_facts, candidates)
    raw = qwen_llm_candidates(user_prompt)
    pred_labels = extract_tail_labels_topk(raw, k=3) or []
    return pred_labels, retrieved_facts, raw

In [21]:
def extract_tail_candidates(retrieved_facts: list[str], head_label: str, rel_label: str):
    """
    Returns unique tail labels (strings) extracted from facts in format:
      In YEAR, HEAD REL TAIL.
    """
    candidates = []
    key = f"{head_label} {rel_label} "

    for f in retrieved_facts:
        if key not in f:
            continue

        tail = f.split(key, 1)[1].strip()

        # delete dot at the end
        tail = re.sub(r"\.\s*$", "", tail)

        tail = tail.strip()

        if tail:
            candidates.append(tail)

    # unique candidates with retained order
    seen = set()
    uniq = []
    for c in candidates:
        if c not in seen:
            seen.add(c)
            uniq.append(c)
    return uniq


In [54]:
def rag_answer_open_world_top3(head_id: str, rel_id: str, ts: int, k=24):
    head_label = qid_to_label.get(head_id, head_id)
    rel_label  = pid_to_label.get(rel_id, rel_id)

    retrieved_facts = retrieve_facts_new(head_id, rel_id, ts, k=k)

    user_prompt = (
        f"Query:\nTime: {ts}\nHead: {head_label}\nRelation: {rel_label}\n\n"
        "Retrieved facts:\n" + "\n".join([f"- {f}" for f in retrieved_facts]) + "\n\n"
        "Task:\n"
        "Predict up to 3 most likely tail entity labels (best→worst).\n"
        "The correct tail may NOT be in the retrieved facts.\n"
        'Return ONLY JSON: {"tail_labels": ["...","...","..."]}\n'
        'If none, return {"tail_labels": []}.'
    )

    raw = qwen_llm_open_top3(user_prompt)
    pred_labels = extract_tail_labels_topk(raw, k=3)  # None / [] / ["a","b","c"]
    return pred_labels, retrieved_facts, raw


In [46]:
def rag_answer_hybrid_top3(head_id: str, rel_id: str, ts: int, k=24, rag_threshold=1):
    # try candidate-only
    head_label = qid_to_label.get(head_id, head_id)
    rel_label  = pid_to_label.get(rel_id, rel_id)

    retrieved_facts = retrieve_facts_new(head_id, rel_id, ts, k=k)

    candidates = extract_tail_candidates(retrieved_facts, head_label, rel_label)

    if len(candidates) >= rag_threshold:
        # candidate-only top3
        pred_labels, retrieved_facts2, raw = rag_answer_candidates_top3(head_id, rel_id, ts, k=k)
        return pred_labels, retrieved_facts2, raw, "candidates"

    # fallback: open-world top3
    pred_labels, retrieved_facts2, raw = rag_answer_open_world_top3(head_id, rel_id, ts, k=k)
    return pred_labels, retrieved_facts2, raw, "open_world"


## 2. Evaluation

In [48]:
_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)

def extract_tail_labels_topk(raw: str, k=3):
    if not raw:
        return None

    cleaned = raw.strip()
    cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()
    cleaned = re.sub(r"```$", "", cleaned).strip()

    m = _JSON_RE.search(cleaned)
    if not m:
        return None

    block = m.group(0)
    try:
        obj = json.loads(block)
    except json.JSONDecodeError: 
        # just in case for apostrophes
        try:
            obj = json.loads(block.replace("'", '"'))
        except json.JSONDecodeError:
            return None

    labels = obj.get("tail_labels", None)
    if labels is None:
        return None
    if not isinstance(labels, list):
        return None

    labels = [x.strip() for x in labels if isinstance(x, str) and x.strip()]
    return labels[:k]


In [49]:
def gold_in_retrieved(gold_tail_id: str, retrieved_facts: list[str]) -> bool:
    gold_label = qid_to_label.get(gold_tail_id, gold_tail_id)
    gl = gold_label.lower()
    return any(gl in f.lower() for f in retrieved_facts)

In [60]:
def eval_open_world_hits_1_3(df, rag_answer, max_examples=None, log_every=100, rag_threshold=1):
    h1 = 0
    h3 = 0
    total = 0
    parse_fail = 0
    empty = 0
    no_map = 0

    rag_cand = 0
    rag_open = 0
    recall_hits = 0
    cand_h1 = 0
    cand_h3 = 0
    open_h1 = 0
    open_h3 = 0

    N = len(df) if max_examples is None else min(len(df), max_examples)

    for i in range(N):
        row = df.iloc[i]
        ts = int(row["ts"])
        head_id = str(row["head"])
        rel_id  = str(row["relation_type"])
        gold_tail = str(row["tail"])

        answer_ver = None

        if rag_answer == rag_answer_hybrid_top3:
            pred_labels, retrieved_facts, raw, answer_ver = rag_answer(head_id, rel_id, ts, k=24, rag_threshold=rag_threshold)
    
            if answer_ver == "candidates":
                rag_cand += 1
                # recall@k
                if gold_in_retrieved(gold_tail, retrieved_facts):
                    recall_hits += 1
    
            if answer_ver == "open_world":
                rag_open += 1
        else:
            pred_labels, retrieved_facts, raw = rag_answer(head_id, rel_id, ts, k=24)
            
        if pred_labels is None:
            parse_fail += 1
            total += 1
            continue
        if len(pred_labels) == 0:
            empty += 1
            total += 1
            continue

        # mapping label -> qids
        qids_1 = tail_label_to_qids(pred_labels[0])
        qids_3 = set()
        for lab in pred_labels[:3]:
            qids_3.update(tail_label_to_qids(lab))

        if (not qids_1) and (not qids_3):
            no_map += 1

        if gold_tail in qids_1:
            h1 += 1
            if answer_ver == "candidates":
                cand_h1 += 1
            if answer_ver == "open_world":
                open_h1 += 1

            
        if gold_tail in qids_3:
            h3 += 1
            if answer_ver == "candidates":
                cand_h3 += 1
            if answer_ver == "open_world":
                open_h3 += 1

        total += 1
        if log_every and (i+1) % log_every == 0:
            print(f"{i+1}/{N} H@1={h1/total:.3f} H@3={h3/total:.3f} parse_fail={parse_fail} empty={empty} no_map={no_map}")

    return {"hits@1": h1/total, "hits@3": h3/total, "total": total, "parse_fail": parse_fail, 
            "empty": empty, "no_map": no_map, 
            "rag_candidates": rag_cand, "rag_open_world": rag_open,
            "hits@1_candidates": cand_h1 / max(rag_cand, 1), "hits@1_open_world": open_h1 / max(rag_open, 1),
            "hits@3_candidates": cand_h3 / max(rag_cand, 1), "hits@3_open_world": open_h3 / max(rag_open, 1),
            "recall@k": recall_hits / max(rag_cand, 1)}

### Comparison tests on 500 sample - hybrid with different rag switch thresholds (1, 3, 5)

In [51]:
sample_df = test_df.sample(n=500, random_state=42).reset_index(drop=True)

In [56]:
results = eval_open_world_hits_1_3(df=sample_df, rag_answer=rag_answer_hybrid_top3, 
                                   max_examples=len(sample_df), log_every=20, rag_threshold=1)
results

20/500 H@1=0.450 H@3=0.500 parse_fail=0 empty=0 no_map=0
40/500 H@1=0.375 H@3=0.425 parse_fail=0 empty=0 no_map=0
60/500 H@1=0.367 H@3=0.433 parse_fail=0 empty=0 no_map=0
80/500 H@1=0.312 H@3=0.400 parse_fail=0 empty=0 no_map=1
100/500 H@1=0.330 H@3=0.400 parse_fail=0 empty=0 no_map=2
120/500 H@1=0.342 H@3=0.417 parse_fail=0 empty=0 no_map=2
140/500 H@1=0.329 H@3=0.407 parse_fail=0 empty=0 no_map=2
160/500 H@1=0.312 H@3=0.388 parse_fail=0 empty=0 no_map=5
180/500 H@1=0.328 H@3=0.400 parse_fail=0 empty=0 no_map=5
200/500 H@1=0.330 H@3=0.405 parse_fail=0 empty=0 no_map=5
220/500 H@1=0.345 H@3=0.414 parse_fail=0 empty=0 no_map=8
240/500 H@1=0.325 H@3=0.392 parse_fail=0 empty=0 no_map=9
260/500 H@1=0.319 H@3=0.381 parse_fail=0 empty=0 no_map=9
280/500 H@1=0.325 H@3=0.386 parse_fail=0 empty=0 no_map=10
300/500 H@1=0.330 H@3=0.390 parse_fail=0 empty=0 no_map=10
320/500 H@1=0.334 H@3=0.391 parse_fail=0 empty=0 no_map=10
340/500 H@1=0.326 H@3=0.385 parse_fail=0 empty=0 no_map=10
360/500 H@1=0.

{'hits@1': 0.342,
 'hits@3': 0.39,
 'total': 500,
 'parse_fail': 0,
 'empty': 0,
 'no_map': 15,
 'rag_candidates': 376,
 'rag_open_world': 124,
 'hits@1_candidates': 0.43882978723404253,
 'hits@1_open_world': 0.04838709677419355,
 'hits@3_candidates': 0.4946808510638298,
 'hits@3_open_world': 0.07258064516129033,
 'recall@k': 0.5212765957446809}

In [57]:
results = eval_open_world_hits_1_3(df=sample_df, rag_answer=rag_answer_hybrid_top3, 
                                   max_examples=len(sample_df), log_every=20, rag_threshold=3)
results

20/500 H@1=0.400 H@3=0.500 parse_fail=0 empty=0 no_map=0
40/500 H@1=0.350 H@3=0.425 parse_fail=0 empty=0 no_map=0
60/500 H@1=0.350 H@3=0.433 parse_fail=0 empty=0 no_map=0
80/500 H@1=0.300 H@3=0.400 parse_fail=0 empty=0 no_map=1
100/500 H@1=0.320 H@3=0.400 parse_fail=0 empty=0 no_map=2
120/500 H@1=0.333 H@3=0.417 parse_fail=0 empty=0 no_map=2
140/500 H@1=0.314 H@3=0.407 parse_fail=0 empty=0 no_map=2
160/500 H@1=0.287 H@3=0.381 parse_fail=1 empty=0 no_map=5
180/500 H@1=0.306 H@3=0.394 parse_fail=1 empty=0 no_map=5
200/500 H@1=0.305 H@3=0.400 parse_fail=1 empty=0 no_map=5
220/500 H@1=0.318 H@3=0.409 parse_fail=1 empty=0 no_map=8
240/500 H@1=0.296 H@3=0.383 parse_fail=1 empty=0 no_map=9
260/500 H@1=0.292 H@3=0.373 parse_fail=1 empty=0 no_map=9
280/500 H@1=0.296 H@3=0.375 parse_fail=1 empty=0 no_map=11
300/500 H@1=0.300 H@3=0.380 parse_fail=1 empty=0 no_map=11
320/500 H@1=0.306 H@3=0.381 parse_fail=1 empty=0 no_map=11
340/500 H@1=0.300 H@3=0.379 parse_fail=1 empty=0 no_map=11
360/500 H@1=0.

{'hits@1': 0.324,
 'hits@3': 0.388,
 'total': 500,
 'parse_fail': 3,
 'empty': 0,
 'no_map': 18,
 'rag_candidates': 131,
 'rag_open_world': 369,
 'hits@1_candidates': 0.25190839694656486,
 'hits@1_open_world': 0.34959349593495936,
 'hits@3_candidates': 0.3816793893129771,
 'hits@3_open_world': 0.3902439024390244,
 'recall@k': 0.42748091603053434}

In [58]:
results = eval_open_world_hits_1_3(df=sample_df, rag_answer=rag_answer_hybrid_top3, 
                                   max_examples=len(sample_df), log_every=20, rag_threshold=5)
results

20/500 H@1=0.400 H@3=0.500 parse_fail=0 empty=0 no_map=0
40/500 H@1=0.350 H@3=0.425 parse_fail=0 empty=0 no_map=0
60/500 H@1=0.350 H@3=0.433 parse_fail=0 empty=0 no_map=0
80/500 H@1=0.287 H@3=0.388 parse_fail=0 empty=0 no_map=1
100/500 H@1=0.310 H@3=0.390 parse_fail=0 empty=0 no_map=2
120/500 H@1=0.325 H@3=0.408 parse_fail=0 empty=0 no_map=2
140/500 H@1=0.307 H@3=0.400 parse_fail=0 empty=0 no_map=2
160/500 H@1=0.281 H@3=0.375 parse_fail=1 empty=0 no_map=5
180/500 H@1=0.300 H@3=0.389 parse_fail=1 empty=0 no_map=5
200/500 H@1=0.305 H@3=0.395 parse_fail=1 empty=0 no_map=5
220/500 H@1=0.318 H@3=0.405 parse_fail=1 empty=0 no_map=8
240/500 H@1=0.296 H@3=0.379 parse_fail=1 empty=0 no_map=9
260/500 H@1=0.292 H@3=0.369 parse_fail=1 empty=0 no_map=9
280/500 H@1=0.296 H@3=0.371 parse_fail=1 empty=0 no_map=11
300/500 H@1=0.300 H@3=0.377 parse_fail=1 empty=0 no_map=11
320/500 H@1=0.306 H@3=0.378 parse_fail=1 empty=0 no_map=11
340/500 H@1=0.300 H@3=0.376 parse_fail=1 empty=0 no_map=11
360/500 H@1=0.

{'hits@1': 0.32,
 'hits@3': 0.386,
 'total': 500,
 'parse_fail': 3,
 'empty': 0,
 'no_map': 18,
 'rag_candidates': 58,
 'rag_open_world': 442,
 'hits@1_candidates': 0.1896551724137931,
 'hits@1_open_world': 0.33710407239819007,
 'hits@3_candidates': 0.3103448275862069,
 'hits@3_open_world': 0.39592760180995473,
 'recall@k': 0.41379310344827586}

### open-world

In [35]:
results = eval_open_world_hits_1_3(df=sample_df, rag_answer=rag_answer_open_world_top3, max_examples=len(sample_df), log_every=50)
results

50/500 H@1=0.320 H@3=0.420 parse_fail=0 empty=0 no_map=0
100/500 H@1=0.300 H@3=0.390 parse_fail=0 empty=0 no_map=2
150/500 H@1=0.287 H@3=0.387 parse_fail=1 empty=0 no_map=3
200/500 H@1=0.300 H@3=0.395 parse_fail=1 empty=0 no_map=5
250/500 H@1=0.288 H@3=0.380 parse_fail=1 empty=0 no_map=9
300/500 H@1=0.297 H@3=0.380 parse_fail=1 empty=0 no_map=11
350/500 H@1=0.300 H@3=0.380 parse_fail=2 empty=0 no_map=11
400/500 H@1=0.295 H@3=0.372 parse_fail=2 empty=0 no_map=15
450/500 H@1=0.302 H@3=0.378 parse_fail=3 empty=0 no_map=16
500/500 H@1=0.316 H@3=0.388 parse_fail=3 empty=0 no_map=18


{'hits@1': 0.316,
 'hits@3': 0.388,
 'total': 500,
 'parse_fail': 3,
 'empty': 0,
 'no_map': 18}

## 3. Cached evaluation on full test split

In [None]:
import os, json, time
from pathlib import Path

def count_lines(path):
    if not os.path.exists(path):
        return 0
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for _ in f)

def append_jsonl(path, obj):
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

def eval_checkpoint_generic(
    df,
    rag_answer,
    out_jsonl="runs/eval_results.jsonl",
    k_retrieval=24,
    max_examples=None,
    start_from="auto",
    log_every=100,
    debug_first_n=3,
    store_raw=False,  
):
    Path(os.path.dirname(out_jsonl) or ".").mkdir(parents=True, exist_ok=True)
    N = len(df) if max_examples is None else min(len(df), max_examples)

    start_i = count_lines(out_jsonl) if start_from == "auto" else int(start_from)
    print(f"Will process [{start_i} .. {N-1}] (N={N}). Output -> {out_jsonl}")

    debug_printed = 0
    t0 = time.time()

    for i in range(start_i, N):
        row = df.iloc[i]
        ts = int(row["ts"])
        head_id = str(row["head"])
        rel_id  = str(row["relation_type"])
        gold_tail = str(row["tail"])

        orig_idx = int(row["orig_idx"])

        # --- model call ---
        pred_labels, retrieved_facts, raw = rag_answer(head_id, rel_id, ts, k=k_retrieval)
        # If rag_answer returns raw in format tail_label (single), not tail_labels,
        # we parse it here:
        if pred_labels is None:
            pred_labels = extract_tail_labels_topk(raw, k=3)

        status = "ok"
        if pred_labels is None:
            status = "parse_fail"
        elif len(pred_labels) == 0:
            status = "empty"

        if status == "parse_fail" and debug_printed < debug_first_n:
            print("RAW (parse_fail example):\n", raw)
            debug_printed += 1

        rec = {
            "i": i,
            "orig_idx": orig_idx,
            "ts": ts,
            "head": head_id,
            "rel": rel_id,
            "gold_tail": gold_tail,
            "pred_labels": pred_labels,   # None / [] / ["a","b","c"]
            "status": status,
            "k": k_retrieval,
        }
        if store_raw:
            rec["raw"] = raw

        append_jsonl(out_jsonl, rec)

        if log_every and (i + 1) % log_every == 0:
            dt = time.time() - t0
            speed = (i + 1 - start_i) / max(dt, 1e-9)
            eta = (N - (i + 1)) / max(speed, 1e-9)
            print(f"{i+1}/{N} saved | speed={speed:.3f} ex/s | ETA~{eta/60:.1f} min")

    print("Done.")


In [None]:
eval_checkpoint_generic(
    test_df,                        # full test
    rag_answer=rag_answer_hybrid_top3,
    out_jsonl="runs/hybrid_k24.jsonl",
    k_retrieval=24,
    start_from="auto",
    log_every=100,
    store_raw=False
)


In [None]:
def summarize_jsonl(path):
    h1 = 0
    h3 = 0
    total = 0
    parse_fail = 0
    empty = 0
    no_map = 0

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            gold_tail = str(rec["gold_tail"])
            pred_labels = rec.get("pred_labels", None)
            status = rec.get("status", None)

            total += 1

            if status == "parse_fail" or pred_labels is None:
                parse_fail += 1
                continue
            if status == "empty" or len(pred_labels) == 0:
                empty += 1
                continue

            # mapping label -> QIDs
            qids_1 = tail_label_to_qids(pred_labels[0])
            qids_3 = set()
            for lab in pred_labels[:3]:
                qids_3.update(tail_label_to_qids(lab))

            if (not qids_1) and (not qids_3):
                no_map += 1

            if gold_tail in qids_1:
                h1 += 1
            if gold_tail in qids_3:
                h3 += 1

    return {
        "hits@1": h1 / max(total, 1),
        "hits@3": h3 / max(total, 1),
        "total": total,
        "parse_fail": parse_fail,
        "empty": empty,
        "no_map": no_map
    }