## This is a scratch notebook for LLM + RAG

It has a lot of different versions that we tested as well as functions for debugging. If you want a cleaner version use TKGC_RAG_clean.ipynb or RAG_full.py

In [2]:
import socket, os
print("hostname:", socket.gethostname())
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))


hostname: node19.enst.fr
CUDA_VISIBLE_DEVICES: 0


In [3]:
import sys
print(sys.executable)


/home/infres/mporwisz-25/miniconda3/envs/tkgc_rag/bin/python


## 0. Setup

In [4]:
import warnings
warnings.filterwarnings("ignore")

In [5]:
import torch

from transformers import (
    BitsAndBytesConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)

# LLM: https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507
llm_name = "Qwen/Qwen3-4B-Instruct-2507"

# We want to use 4bit quantization to save memory (in case some of you use their own computer)
quantization_config = BitsAndBytesConfig(
    load_in_8bit=False, load_in_4bit=True
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(llm_name, padding_side="left")

# Prevent some transformers specific issues.
tokenizer.use_default_system_prompt = False
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load LLM.
llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    quantization_config=quantization_config,
    device_map={"": 0}, # load all the model layers on GPU 0
    torch_dtype=torch.bfloat16, # float precision
)

# Set LLM on eval mode.
llm.eval()

`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|███████████████████████| 398/398 [00:10<00:00, 37.82it/s, Materializing param=model.norm.weight]


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2560)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear4bit(in_features=2560, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=2560, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=2560, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=2560, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear4bit(in_features=2560, out_features=9728, bias=False)
          (up_proj): Linear4bit(in_features=2560, out_features=9728, bias=False)
          (down_proj): Linear4bit(in_features=9728, out_features=2560, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
 

In [6]:
from transformers import set_seed
import random
import numpy as np
import torch

SEED = 42
set_seed(SEED)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [7]:
import json

qid_to_label = json.load(open("qid_to_label.json"))
pid_to_label = json.load(open("pid_to_label.json"))
norm_label_to_qids = json.load(open("norm_label_to_qids.json"))


In [8]:
# import re
# from langchain_core.documents import Document

# ENTITY_RE = re.compile(r"^Entity:\s*(.+?)\s*$")
# FACT_RE = re.compile(r"^-+\s*(.+?)\s*$")  # linie starting with "- "

# def load_facts_txt(path: str):
#     docs = []
#     current_entity = None

#     with open(path, "r", encoding="utf-8") as f:
#         for line in f:
#             line = line.strip()
#             if not line:
#                 continue

#             m_ent = ENTITY_RE.match(line)
#             if m_ent:
#                 current_entity = m_ent.group(1).strip()
#                 continue

#             m_fact = FACT_RE.match(line)
#             if m_fact and current_entity:
#                 sentence = m_fact.group(1).strip()
#                 docs.append(Document(
#                     page_content=sentence,
#                     metadata={"head_label": current_entity}
#                 ))
#     return docs


In [9]:
# docs = load_facts_txt("rag_documents.txt")  # Twoja funkcja parsera

## 2. Retrieval-Augmented Generation (RAG)

In [10]:
import os
import re
import torch
from typing import List, Tuple
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings

USER_AGENT environment variable not set, consider setting it to identify your requests.


In [11]:
# # Create an embedding model for similarity search/ retrieval.

# import chromadb

# embeddings = SentenceTransformerEmbeddings(model_name="nomic-ai/nomic-embed-text-v1", model_kwargs={"trust_remote_code":True})

# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
# template = """
# Question: \n{question}
# Context: \n{context}
# """

# def create_retriever(docs, k=8):
#     if not docs:
#         return None
#     #splits = text_splitter.split_documents(docs)
#     vectorstore = Chroma.from_documents(
#         documents=docs,
#         embedding=embeddings,
#         collection_name="smallpedia_facts",
#         persist_directory="./chroma_smallpedia"
#     )
#     vectorstore.persist()

#     return vectorstore.as_retriever(search_kwargs={"k": k})

In [12]:
# retriever = create_retriever(docs, k=24)

In [13]:
import pandas as pd

df = pd.read_csv("tkgl-smallpedia_edgelist.csv")
df["ts"] = df["ts"].astype(int)

train_df = df[df["ts"] < 2008].copy()
test_df  = df[df["ts"] >= 2008].copy()

print(len(train_df), len(test_df))

468790 81586


In [14]:
from collections import defaultdict

def build_hr_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        ts = int(row["ts"])
        idx[(h, r)].append((ts, t))
    # sort by ts
    for k in idx:
        idx[k].sort(key=lambda x: x[0])
    return idx

hr_index = build_hr_index(train_df)


In [15]:
def build_head_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        ts = int(row["ts"])
        idx[h].append((ts, r, t))
    # sort by ts
    for k in idx:
        idx[k].sort(key=lambda x: x[0])
    return idx

head_index = build_head_index(train_df)

In [16]:
def build_entity_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        ts = int(row["ts"])

        idx[h].append((ts, h, r, t))
        idx[t].append((ts, h, r, t))

    for e in idx:
        idx[e].sort(key=lambda x: x[0])
    return idx

entity_index = build_entity_index(train_df)


In [17]:
def build_rel_index(df):
    idx = defaultdict(list)
    for _, row in df.iterrows():
        ts = int(row["ts"])
        h = str(row["head"])
        r = str(row["relation_type"])
        t = str(row["tail"])
        idx[r].append((ts, h, t))
    for r in idx:
        idx[r].sort(key=lambda x: x[0])
    return idx

rel_index = build_rel_index(train_df)

In [18]:
def retrieve_relation_examples(rel_id, ts, k=12):
    facts_r = rel_index.get(rel_id, [])
    if not facts_r:
        return []

    ranked = sorted(facts_r, key=lambda x: abs(x[0] - ts))[:k]

    txt = []
    for y, h, t in ranked:
        txt.append(
            f"In {y}, {qid_to_label.get(h,h)} {pid_to_label.get(rel_id,rel_id)} {qid_to_label.get(t,t)}."
        )
    return txt


In [19]:
def retrieve_facts_new(head_id, rel_id, ts, k=12):
    out = []

    # 1) exact match: (head, rel) -> (ts, h, r, t)
    facts_hr = hr_index.get((head_id, rel_id), [])  # list[(ts, tail_id)]
    if facts_hr:
        ranked = sorted(facts_hr, key=lambda x: abs(x[0] - ts))
        out.extend([(y, head_id, rel_id, tail_id) for (y, tail_id) in ranked])

    # 2) fallback: (head, *) -> already (ts, r, t) OR (ts, r, tail_id)
    if len(out) < k:
        facts_h = head_index.get(head_id, [])  # list[(ts, r, tail_id)]
        if facts_h:
            ranked_h = sorted(facts_h, key=lambda x: abs(x[0] - ts))
            out.extend([(y, head_id, r, tail_id) for (y, r, tail_id) in ranked_h])

    # 3) fallback: facts about entity (as head OR tail)
    if len(out) < k:
        facts_e = entity_index.get(head_id, [])  # list[(ts, h, r, t)]
        if facts_e:
            ranked_e = sorted(facts_e, key=lambda x: abs(x[0] - ts))
            # add only the number missing
            need = k - len(out)
            out.extend(ranked_e[:need])

    # 4) relation-only examples (global)
    if len(out) < k:
        facts_r = rel_index.get(rel_id, [])  # [(ts, h, t)]
        ranked_r = sorted(facts_r, key=lambda x: abs(x[0] - ts))
        need = k - len(out)
        out.extend([(y, h, rel_id, t) for (y, h, t) in ranked_r[:need]])

    out = out[:k]

    # change to text
    facts_txt = []
    for y, h, r, t in out:
        h_label = qid_to_label.get(h, h)
        r_label = pid_to_label.get(r, r)
        t_label = qid_to_label.get(t, t)
        facts_txt.append(f"In {y}, {h_label} {r_label} {t_label}.")

    return facts_txt


In [20]:
# def retrieve_facts_hr(head_id, rel_id, ts, k=12):
#     facts = hr_index.get((head_id, rel_id), [])
#     if len(facts)<k:
#         facts.append(hr_index.get(head_id, []))
#     if not facts:
#         return []

#     # choose k closest |year - ts|
#     ranked = sorted(facts, key=lambda x: abs(x[0] - ts))
#     ranked = ranked[:k]

#     # change to text for prompt
#     head_label = qid_to_label.get(head_id, head_id)
#     rel_label  = pid_to_label.get(rel_id, rel_id)
#     out = []
#     for y, tail_id in ranked:
#         tail_label = qid_to_label.get(tail_id, tail_id)
#         out.append(f"In {y}, {head_label} {rel_label} {tail_label}.")
#     return out


In [21]:
# def build_query_text(head_id, rel_id, ts):
#     head = qid_to_label.get(head_id, head_id)
#     rel  = pid_to_label.get(rel_id, rel_id)
#     return f"Query: In {ts}, what is the tail for ({head}, {rel}, ?) ?"

def normalize_label(s: str) -> str:
    s = (s or "").strip().lower()
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"\.$", "", s)
    return s

def tail_label_to_qids(tail_label: str):
    n = normalize_label(tail_label)
    return norm_label_to_qids.get(n, [])


In [22]:
# # Define the function to call the qwen model
# def qwen_llm(prompt):
#     # 1. Instruction prompt
#     system_prompt = (
#       "You are solving Temporal Knowledge Graph Completion.\n"
#       "Given a query (head, relation, timestamp) and the retrieved facts, "
#       "select the most likely tail entity.\n"
#         "Rules:\n"
#         "1) Choose the tail entity ONLY from the retrieved facts.\n"
#         "2) Copy the tail label EXACTLY as it appears in the facts.\n"
#         '3) Return ONLY JSON: {"tail_label": "..."}.\n'
#         '4) If no fact matches, return {"tail_label": null}.'
#       # "Choose the tail entity ONLY from the retrieved facts (copy the exact tail string). Do not invent new entities.\n"
#       # "Return ONLY JSON in this format:\n"
#       # '{"tail_label": "..."}\n'
#       # "Do not add any extra text."
#     )


#     messages = [
#         {"role": "system", "content": system_prompt},
#         {'role': 'user', 'content': prompt}
#     ]

#     inputs = tokenizer.apply_chat_template(
#         messages,
#         add_generation_prompt=True,
#         return_tensors="pt",
#         return_dict=True
#         ).to(llm.device)

#     outputs = llm.generate(**inputs, max_new_tokens=64, do_sample=False, temperature=0.0)
#     raw = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
#     return raw

In [23]:
# def extract_tail_label(raw: str):
#     """
#     Wyciąga tail_label z odpowiedzi. Działa nawet jeśli LLM doda śmieci dookoła.
#     """
#     # spróbuj znaleźć pierwszego JSON-a w tekście
#     m = re.search(r"\{.*\}", raw, flags=re.DOTALL)
#     if not m:
#         return None
#     try:
#         obj = json.loads(m.group(0))
#         return obj.get("tail_label")
#     except json.JSONDecodeError:
#         return None

In [24]:
# YEAR_RE = re.compile(r"\b(1[0-9]{3}|20[0-9]{2})\b")

# def extract_year(text: str):
#     m = YEAR_RE.search(text)
#     return int(m.group(0)) if m else None

In [25]:
# def build_user_prompt(head_label: str, rel_label: str, ts: int, retrieved_facts: list[str]):
#     facts_block = "\n".join([f"- {f}" for f in retrieved_facts])
#     return (
#         f"Query:\n"
#         f"Time: {ts}\n"
#         f"Head: {head_label}\n"
#         f"Relation: {rel_label}\n\n"
#         f"Retrieved facts:\n{facts_block}\n\n"
#         f"Task: Return the most likely tail entity label as JSON."
#     )


# # def rag_answer(head_id: str, rel_id: str, ts: int):
# #     head_label = qid_to_label.get(head_id, head_id)
# #     rel_label  = pid_to_label.get(rel_id, rel_id)

# #     query_text = f"In {ts}, {head_label} {rel_label}"
# #     retrieved_docs = retriever.invoke(query_text)
# #     retrieved_facts = [d.page_content for d in retrieved_docs]

# #     user_prompt = build_user_prompt(head_label, rel_label, ts, retrieved_facts)
# #     raw = qwen_llm(user_prompt)
# #     tail_label = extract_tail_label(raw)
# #     return tail_label, retrieved_facts, raw


In [26]:
# def rag_answer(head_id: str, rel_id: str, ts: int, k_llm=12, k_retrieve=80):
#     head_label = qid_to_label.get(head_id, head_id)
#     rel_label  = pid_to_label.get(rel_id, rel_id)

#     # query pod podobieństwo tekstowe
#     query_text = f"In {ts}, {head_label} {rel_label}"

#     # 1) Load many candidates
#     retrieved_docs = retriever.invoke(query_text)
#     #retrieved_facts = [d.page_content for d in retrieved_docs]
#     retrieved_facts = [f for f in retrieved_facts if rel_label in f]


#     # 2) reranking by timestamp
#     scored = []
#     for fact in retrieved_facts:
#         y = extract_year(fact)
#         # no year -> give huge penalty
#         time_dist = abs(y - ts) if y is not None else 10**9
#         scored.append((time_dist, fact))

#     scored.sort(key=lambda x: x[0])
#     reranked_facts = [f for _, f in scored][:k_llm]

#     user_prompt = build_user_prompt(head_label, rel_label, ts, reranked_facts)
#     raw = qwen_llm(user_prompt)
#     tail_label = extract_tail_label(raw)
#     return tail_label, reranked_facts, raw


In [27]:
# def rag_answer_heuristic(head_id: str, rel_id: str, ts: int, k=12):
#     head_label = qid_to_label.get(head_id, head_id)
#     rel_label  = pid_to_label.get(rel_id, rel_id)

#     retrieved_facts = retrieve_facts_new(head_id, rel_id, ts, k=k)

#     user_prompt = build_user_prompt(head_label, rel_label, ts, retrieved_facts)
#     raw = qwen_llm(user_prompt)
#     tail_label = extract_tail_label(raw)
#     return tail_label, retrieved_facts, raw


## 3. Experiments

In [28]:
# import re

# def normalize_label(s: str) -> str:
#     s = (s or "").strip().lower()
#     s = re.sub(r"\s+", " ", s)
#     s = re.sub(r"\.$", "", s)
#     return s

# def tail_label_to_qids(tail_label: str):
#     if not tail_label:
#         return []
#     return norm_label_to_qids.get(normalize_label(tail_label), [])

# def gold_in_retrieved(gold_tail_id: str, retrieved_facts: list[str]) -> bool:
#     gold_label = qid_to_label.get(gold_tail_id, gold_tail_id)
#     gl = gold_label.lower()
#     return any(gl in f.lower() for f in retrieved_facts)

# train_tails_set = set(train_df["tail"].astype(str))

# def gold_seen_in_train(gold_tail_id: str) -> bool:
#     return gold_tail_id in train_tails_set


# def eval_hits_at_1(df, max_examples=None, log_every=50):
#     correct = 0
#     total = 0
#     n_no_json = 0
#     n_no_map = 0
#     recall_hits = 0
#     gold_tail_train = 0

#     N = len(df) if max_examples is None else min(len(df), max_examples)

#     for i in range(N):
#         row = df.iloc[i]
#         ts = int(row["ts"])
#         head_id = str(row["head"])
#         rel_id  = str(row["relation_type"])
#         gold_tail = str(row["tail"])

#         pred_label, retrieved_facts, raw = rag_answer_heuristic(head_id, rel_id, ts)

#         if gold_in_retrieved(gold_tail, retrieved_facts):
#             recall_hits += 1

#         if pred_label is None:
#             n_no_json += 1
#             total += 1
#             continue

#         pred_qids = tail_label_to_qids(pred_label)
#         if not pred_qids:
#             n_no_map += 1

#         if gold_tail in pred_qids:
#             correct += 1

#         if gold_seen_in_train(gold_tail):
#             gold_tail_train += 1

#         total += 1

#         if log_every and (i+1) % log_every == 0:
#             print(f"{i+1}/{N}  Hits@1={correct/total:.3f}  no_json={n_no_json}  no_map={n_no_map}")

#     return {
#         "hits@1": correct / max(total, 1),
#         "recall@k": recall_hits / max(total, 1),
#         "total": total,
#         "no_json": n_no_json,
#         "no_map": n_no_map,
#         "gold_seen_in_train": gold_tail_train
#     }


In [29]:
# print("retriever:", retriever is not None)
# print("qid_to_label size:", len(qid_to_label))
# print("pid_to_label size:", len(pid_to_label))
# print("norm_label_to_qids size:", len(norm_label_to_qids))


In [30]:
# row = test_df.iloc[0]

# ts = int(row["ts"])
# head_id = str(row["head"])
# rel_id  = str(row["relation_type"])
# gold_tail = str(row["tail"])

# print("GOLD:")
# print("ts:", ts)
# print("head:", head_id, "->", qid_to_label.get(head_id, head_id))
# print("rel :", rel_id, "->", pid_to_label.get(rel_id, rel_id))
# print("tail:", gold_tail, "->", qid_to_label.get(gold_tail, gold_tail))


In [31]:
# pred_label, retrieved_facts, raw = rag_answer_heuristic(head_id, rel_id, ts)

# print("\nLLM raw output:")
# print(raw)

# print("\nPredicted tail_label:", pred_label)

# print("\nTop retrieved facts:")
# for i, f in enumerate(retrieved_facts[:10]):
#     print(f"{i+1}. {f}")


In [32]:
# def eval_with_nomap_examples(df, max_examples=100, keep=20, show_k=10, log_every=20):
#     correct = 0
#     total = 0
#     n_no_json = 0
#     n_no_map = 0
#     recall_hits = 0
#     seen_in_train = 0
#     no_retrieval = 0

#     nomap_examples = []   # <-- tu zbieramy przypadki

#     N = min(len(df), max_examples)

#     for i in range(N):
#         row = df.iloc[i]
#         ts = int(row["ts"])
#         head_id = str(row["head"])
#         rel_id  = str(row["relation_type"])
#         gold_tail = str(row["tail"])

#         if gold_tail in train_tails_set:
#             seen_in_train += 1

#         pred_label, retrieved_facts, raw = rag_answer_candidates(head_id, rel_id, ts, k=12)

#         if len(retrieved_facts) == 0:
#             no_retrieval += 1

#         # recall@k
#         if gold_in_retrieved(gold_tail, retrieved_facts):
#             recall_hits += 1

#         if pred_label is None:
#             n_no_json += 1
#             # też warto logować brak JSON
#             if len(nomap_examples) < keep:
#                 nomap_examples.append({
#                     "reason": "no_json_or_null",
#                     "ts": ts,
#                     "head_id": head_id,
#                     "head_label": qid_to_label.get(head_id, head_id),
#                     "rel_id": rel_id,
#                     "rel_label": pid_to_label.get(rel_id, rel_id),
#                     "gold_tail_id": gold_tail,
#                     "gold_tail_label": qid_to_label.get(gold_tail, gold_tail),
#                     "pred_label": pred_label,
#                     "raw": raw,
#                     "retrieved_facts": retrieved_facts[:show_k],
#                 })
#             total += 1
#             continue

#         pred_qids = tail_label_to_qids(pred_label)

#         if not pred_qids:
#             n_no_map += 1
#             if len(nomap_examples) < keep:
#                 nomap_examples.append({
#                     "reason": "no_map",
#                     "ts": ts,
#                     "head_id": head_id,
#                     "head_label": qid_to_label.get(head_id, head_id),
#                     "rel_id": rel_id,
#                     "rel_label": pid_to_label.get(rel_id, rel_id),
#                     "gold_tail_id": gold_tail,
#                     "gold_tail_label": qid_to_label.get(gold_tail, gold_tail),
#                     "pred_label": pred_label,
#                     "raw": raw,
#                     "retrieved_facts": retrieved_facts[:show_k],
#                 })

#         if gold_tail in pred_qids:
#             correct += 1

#         total += 1

#         if log_every and (i+1) % log_every == 0:
#             print(
#                 f"{i+1}/{N}  Hits@1={correct/total:.3f}  "
#                 f"Recall@k={recall_hits/total:.3f}  no_map={n_no_map}  no_json={n_no_json}"
#             )

#     metrics = {
#         "hits@1": correct / max(total, 1),
#         "recall@k": recall_hits / max(total, 1),
#         "total": total,
#         "no_json": n_no_json,
#         "no_map": n_no_map,
#         "gold_seen_in_train": seen_in_train,
#         "No retrieved facts": no_retrieval
#     }
#     return metrics, nomap_examples


In [33]:
# metrics, examples = eval_with_nomap_examples(test_df, max_examples=100)
# metrics

In [34]:
# for j, ex in enumerate(examples[:5], 1):
#     print("="*80)
#     print(f"Example {j} | reason={ex['reason']}")
#     print(f"ts={ex['ts']}")
#     print(f"head: {ex['head_id']} -> {ex['head_label']}")
#     print(f"rel : {ex['rel_id']} -> {ex['rel_label']}")
#     print(f"gold tail: {ex['gold_tail_id']} -> {ex['gold_tail_label']}")
#     print(f"pred tail_label: {ex['pred_label']}")
#     print("\nTop retrieved facts:")
#     for i, f in enumerate(ex["retrieved_facts"], 1):
#         print(f"{i}. {f}")
#     print("\nRaw LLM:")
#     print(ex["raw"])


In [35]:
# res = eval_hits_at_1(test_df, max_examples=100, log_every=20)
# res

In [36]:
# hits_when_recalled = res['hits@1'] / res['recall@k']
# hits_when_recalled


In [37]:
def extract_tail_candidates(retrieved_facts: list[str], head_label: str, rel_label: str):
    """
    Return unique tail labels (strings) in format:
      In YEAR, HEAD REL TAIL.
    """
    candidates = []
    key = f"{head_label} {rel_label} "

    for f in retrieved_facts:
        if key not in f:
            continue

        tail = f.split(key, 1)[1].strip()

        tail = re.sub(r"\.\s*$", "", tail)

        tail = tail.strip()

        if tail:
            candidates.append(tail)

    seen = set()
    uniq = []
    for c in candidates:
        if c not in seen:
            seen.add(c)
            uniq.append(c)
    return uniq


In [38]:
def build_user_prompt_candidates(head_label: str, rel_label: str, ts: int, retrieved_facts: list[str], candidates: list[str]):
    facts_block = "\n".join([f"- {f}" for f in retrieved_facts])

    cand_block = "\n".join([f"{i+1}. {c}" for i, c in enumerate(candidates)])

    return (
        f"Query:\n"
        f"Time: {ts}\n"
        f"Head: {head_label}\n"
        f"Relation: {rel_label}\n\n"
        f"Retrieved facts:\n{facts_block}\n\n"
        f"Candidates (choose exactly one from this list):\n{cand_block}\n\n"
        "Rules:\n"
        "1) You MUST pick exactly one candidate from the list above, or null if none match.\n"
        "2) Return ONLY valid JSON in one line.\n"
        'Output format: {"tail_label": "<EXACT candidate string>"} or {"tail_label": null}\n'
        "3) Do not add any other text."
    )


In [39]:
# Define the function to call the qwen model
def qwen_llm_candidates(prompt):
    # 1. Instruction prompt
    system_prompt = (
        "You solve Temporal Knowledge Graph Completion (TKGC).\n"
        "You will be given a query (head, relation, timestamp), retrieved facts, and a candidate list.\n"
        "Your task is to select the most likely tail entity from the candidate list.\n\n"
        "Rules:\n"
        "1) Choose up to 3 candidates ONLY from the provided candidate list.\n"
        "2) Order them from best to worst.\n"
        "3) Return ONLY valid JSON in ONE line, and nothing else.\n"
        '4) Output format: {"tail_labels": ["...", "...", "..."]}\n'
        '   If no candidate matches, return {"tail_labels": []}.\n'
    )


    messages = [
        {"role": "system", "content": system_prompt},
        {'role': 'user', 'content': prompt}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
        ).to(llm.device)

    outputs = llm.generate(**inputs, max_new_tokens=64, do_sample=False, temperature=0.0)
    raw = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
    return raw

In [40]:
# def rag_answer_candidates(head_id: str, rel_id: str, ts: int, k=12):
#     head_label = qid_to_label.get(head_id, head_id)
#     rel_label  = pid_to_label.get(rel_id, rel_id)
#
#     retrieved_facts = retrieve_facts_new(head_id, rel_id, ts, k=k)
#
#     if not retrieved_facts:
#         return None, [], '{"tail_label": null}'
#
#     candidates = extract_tail_candidates(retrieved_facts, head_label, rel_label)
#
#     # if there's no tail, don't ask LLM (or fallback to a different prompt)
#     if not candidates:
#         return None, retrieved_facts, '{"tail_label": null}'
#
#     user_prompt = build_user_prompt_candidates(head_label, rel_label, ts, retrieved_facts, candidates)
#
#     raw = qwen_llm_candidates(user_prompt)
#     tail_label = extract_tail_label(raw)
#     return tail_label, retrieved_facts, raw


In [41]:
def build_user_prompt_candidates_top3(head_label, rel_label, ts, retrieved_facts, candidates):
    facts_block = "\n".join([f"- {f}" for f in retrieved_facts])
    cand_block = "\n".join([f"{i+1}. {c}" for i, c in enumerate(candidates)])

    return (
        f"Query:\nTime: {ts}\nHead: {head_label}\nRelation: {rel_label}\n\n"
        f"Retrieved facts:\n{facts_block}\n\n"
        f"Candidates (pick up to 3):\n{cand_block}\n\n"
        "Rules:\n"
        "1) Choose up to 3 candidates ONLY from the candidate list.\n"
        "2) Order them best→worst.\n"
        '3) Return ONLY JSON: {"tail_labels": ["...","...","..."]}.\n'
        '4) If none match, return {"tail_labels": []}.'
    )

def rag_answer_candidates_top3(head_id: str, rel_id: str, ts: int, k=12):
    head_label = qid_to_label.get(head_id, head_id)
    rel_label  = pid_to_label.get(rel_id, rel_id)

    retrieved_facts = retrieve_facts_new(head_id, rel_id, ts, k=k)
    if not retrieved_facts:
        return [], [], '{"tail_labels": []}'

    candidates = extract_tail_candidates(retrieved_facts, head_label, rel_label)
    if not candidates:
        return [], retrieved_facts, '{"tail_labels": []}'

    user_prompt = build_user_prompt_candidates_top3(head_label, rel_label, ts, retrieved_facts, candidates)
    raw = qwen_llm_candidates(user_prompt)
    pred_labels = extract_tail_labels_topk(raw, k=3) or []
    return pred_labels, retrieved_facts, raw


In [42]:
import json, re
JSON_OBJ_RE = re.compile(r"\{.*?\}", re.DOTALL)

def extract_tail_labels_topk(raw: str, k=3):
    if not raw:
        return None
    m = JSON_OBJ_RE.search(raw)
    if not m:
        return None
    try:
        obj = json.loads(m.group(0))
    except json.JSONDecodeError:
        return None

    labels = obj.get("tail_labels", None)
    if labels is None:
        return None

    labels = [x for x in labels if isinstance(x, str) and x.strip()]
    return labels[:k]


In [49]:
def gold_in_retrieved(gold_tail_id: str, retrieved_facts: list[str]) -> bool:
    gold_label = qid_to_label.get(gold_tail_id, gold_tail_id)
    gl = gold_label.lower()
    return any(gl in f.lower() for f in retrieved_facts)

In [43]:
def eval_hits_1_3(df, max_examples=100, log_every=20):
    h1 = 0
    h3 = 0
    total = 0
    parse_fail = 0
    null_count = 0
    debug_printed = 0

    recall_hits = 0

    for i in range(min(len(df), max_examples)):
        row = df.iloc[i]
        ts = int(row["ts"])
        head_id = str(row["head"])
        rel_id  = str(row["relation_type"])
        gold_tail = str(row["tail"])

        pred_label, retrieved_facts, raw = rag_answer_candidates_top3(head_id, rel_id, ts, k=24)
        pred_labels = extract_tail_labels_topk(raw, k=3)

        # recall@k
        if gold_in_retrieved(gold_tail, retrieved_facts):
            recall_hits += 1

        if pred_labels is None and debug_printed < 3:
            print("RAW:\n", raw)
            debug_printed += 1
            
        if pred_labels is None:
            parse_fail += 1
            total += 1
            continue

        if len(pred_labels) == 0:
            null_count += 1
            total += 1
            continue

        # map label -> qids i check gold
        pred_qids_1 = tail_label_to_qids(pred_labels[0])
        pred_qids_3 = set()
        for lab in pred_labels[:3]:
            pred_qids_3.update(tail_label_to_qids(lab))

        if gold_tail in pred_qids_1:
            h1 += 1
        if gold_tail in pred_qids_3:
            h3 += 1

        total += 1
        if log_every and (i+1) % log_every == 0:
            print(f"{i+1}/{max_examples}  H@1={h1/total:.3f}  H@3={h3/total:.3f}  parse_fail={parse_fail}  null={null_count}")

    return {"hits@1": h1/total, "hits@3": h3/total, "total": total, "parse_fail": parse_fail, "null": null_count,
           "recall@k": recall_hits / total}


In [44]:
train_df_filter = df[df["ts"] < 1998].copy()

train_entities = set(pd.concat([train_df_filter["head"], train_df_filter["tail"]]).astype(str).unique())

test_df_filtered = test_df[
    (test_df["head"].astype(str).isin(train_entities)) &
    (test_df["tail"].astype(str).isin(train_entities))
].copy()


In [45]:
test_df_filtered.drop_duplicates()
len(test_df_filtered)

18564

In [47]:
sample_df = test_df_filtered.sample(n=500, random_state=42).reset_index(drop=True)

In [50]:
results = eval_hits_1_3(sample_df, max_examples=100, log_every=20)
results

20/100  H@1=0.600  H@3=0.800  parse_fail=0  null=1
40/100  H@1=0.575  H@3=0.675  parse_fail=0  null=2
60/100  H@1=0.617  H@3=0.683  parse_fail=0  null=5
80/100  H@1=0.625  H@3=0.713  parse_fail=0  null=6
100/100  H@1=0.590  H@3=0.690  parse_fail=0  null=8


{'hits@1': 0.59,
 'hits@3': 0.69,
 'total': 100,
 'parse_fail': 0,
 'null': 8,
 'recall@k': 0.73}

In [69]:
res = eval_hits_1_3(test_df_filtered, max_examples=100)

20/100  H@1=0.200  H@3=0.200  parse_fail=0  null=3
40/100  H@1=0.200  H@3=0.225  parse_fail=0  null=7
100/100  H@1=0.170  H@3=0.240  parse_fail=0  null=19


In [44]:
# metrics, examples = eval_with_nomap_examples(test_df_filtered, max_examples=100)
# metrics

20/100  Hits@1=0.200  Recall@k=0.250  no_map=0  no_json=3
40/100  Hits@1=0.200  Recall@k=0.275  no_map=0  no_json=7
100/100  Hits@1=0.170  Recall@k=0.290  no_map=0  no_json=19


{'hits@1': 0.17,
 'recall@k': 0.29,
 'total': 100,
 'no_json': 19,
 'no_map': 0,
 'gold_seen_in_train': 98,
 'No retrieved facts': 0}

In [45]:
# for j, ex in enumerate(examples[:5], 1):
#     print("="*80)
#     print(f"Example {j} | reason={ex['reason']}")
#     print(f"ts={ex['ts']}")
#     print(f"head: {ex['head_id']} -> {ex['head_label']}")
#     print(f"rel : {ex['rel_id']} -> {ex['rel_label']}")
#     print(f"gold tail: {ex['gold_tail_id']} -> {ex['gold_tail_label']}")
#     print(f"pred tail_label: {ex['pred_label']}")
#     print("\nTop retrieved facts:")
#     for i, f in enumerate(ex["retrieved_facts"], 1):
#         print(f"{i}. {f}")
#     print("\nRaw LLM:")
#     print(ex["raw"])


Example 1 | reason=no_json_or_null
ts=2008
head: Q16019 -> Wolfgang Schäuble
rel : P166 -> award received
gold tail: Q445673 -> Order of Merit of Baden-Württemberg
pred tail_label: None

Top retrieved facts:
1. In 1989, Wolfgang Schäuble position held Federal Minister for Special Affairs of Germany.
2. In 1988, Wolfgang Schäuble position held Federal Minister for Special Affairs of Germany.
3. In 1987, Wolfgang Schäuble position held Federal Minister for Special Affairs of Germany.
4. In 1986, Wolfgang Schäuble position held Federal Minister for Special Affairs of Germany.
5. In 1985, Wolfgang Schäuble position held Federal Minister for Special Affairs of Germany.
6. In 1984, Wolfgang Schäuble position held Federal Minister for Special Affairs of Germany.
7. In 2007, Federal Ministry of the Interior chairperson Wolfgang Schäuble.
8. In 2006, Federal Ministry of the Interior chairperson Wolfgang Schäuble.
9. In 2005, Federal Ministry of the Interior chairperson Wolfgang Schäuble.
10. In

## Cached evaluation on full test split

In [44]:
import os, json, time
from pathlib import Path

def count_lines(path):
    if not os.path.exists(path):
        return 0
    with open(path, "r", encoding="utf-8") as f:
        return sum(1 for _ in f)

def append_jsonl(path, obj):
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

def eval_hits_1_3_checkpoint(
    df,
    out_jsonl="eval_results.jsonl",
    k_retrieval=24,
    max_examples=None,
    start_from="auto",   # "auto" or int
    log_every=100,
    debug_first_n=3,
):
    Path(os.path.dirname(out_jsonl) or ".").mkdir(parents=True, exist_ok=True)

    N = len(df) if max_examples is None else min(len(df), max_examples)

    # Resume
    if start_from == "auto":
        start_i = count_lines(out_jsonl)
    else:
        start_i = int(start_from)

    print(f"Will process [{start_i} .. {N-1}] (N={N}). Output -> {out_jsonl}")

    debug_printed = 0
    t0 = time.time()

    for i in range(start_i, N):
        row = df.iloc[i]
        ts = int(row["ts"])
        head_id = str(row["head"])
        rel_id  = str(row["relation_type"])
        gold_tail = str(row["tail"])

        # --- model call ---
        pred_label, retrieved_facts, raw = rag_answer_candidates_top3(head_id, rel_id, ts, k=k_retrieval)
        pred_labels = extract_tail_labels_topk(raw, k=3)

        if pred_labels is None and debug_printed < debug_first_n:
            print("RAW (parse_fail example):\n", raw)
            debug_printed += 1

        # save the results (one line = one example)
        rec = {
            "i": i,
            "ts": ts,
            "head": head_id,
            "rel": rel_id,
            "gold_tail": gold_tail,
            "pred_labels": pred_labels,   # None / [] / ["a","b","c"]
            "raw": raw,                   # optional
            "k": k_retrieval,
        }
        append_jsonl(out_jsonl, rec)

        if log_every and (i + 1) % log_every == 0:
            dt = time.time() - t0
            speed = (i + 1 - start_i) / max(dt, 1e-9)
            eta = (N - (i + 1)) / max(speed, 1e-9)
            print(f"{i+1}/{N} saved | speed={speed:.3f} ex/s | ETA~{eta/60:.1f} min")

    print("Done.")


In [None]:
eval_hits_1_3_checkpoint(
    test_df_filtered,
    out_jsonl="runs/eval_k24.jsonl",
    k_retrieval=24,
    max_examples=len(test_df_filtered),
    start_from="auto",
    log_every=100,
)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Will process [14606 .. 18563] (N=18564). Output -> runs/eval_k24.jsonl
14700/18564 saved | speed=0.512 ex/s | ETA~125.7 min
14800/18564 saved | speed=0.550 ex/s | ETA~114.0 min
14900/18564 saved | speed=0.553 ex/s | ETA~110.4 min
15000/18564 saved | speed=0.546 ex/s | ETA~108.7 min
15100/18564 saved | speed=0.536 ex/s | ETA~107.7 min
15200/18564 saved | speed=0.530 ex/s | ETA~105.9 min
15300/18564 saved | speed=0.558 ex/s | ETA~97.4 min
15400/18564 saved | speed=0.562 ex/s | ETA~93.9 min
15500/18564 saved | speed=0.563 ex/s | ETA~90.7 min
15600/18564 saved | speed=0.562 ex/s | ETA~87.9 min
15700/18564 saved | speed=0.559 ex/s | ETA~85.3 min
15800/18564 saved | speed=0.554 ex/s | ETA~83.2 min
15900/18564 saved | speed=0.551 ex/s | ETA~80.6 min
16000/18564 saved | speed=0.567 ex/s | ETA~75.4 min
16100/18564 saved | speed=0.581 ex/s | ETA~70.7 min
16200/18564 saved | speed=0.579 ex/s | ETA~68.0 min
16300/18564 saved | speed=0.577 ex/s | ETA~65.4 min
16500/18564 saved | speed=0.568 ex/s | 

In [None]:
def summarize_jsonl(out_jsonl):
    h1 = 0
    h3 = 0
    total = 0
    parse_fail = 0
    null_count = 0

    with open(out_jsonl, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            gold_tail = str(rec["gold_tail"])
            pred_labels = rec.get("pred_labels", None)

            total += 1

            if pred_labels is None:
                parse_fail += 1
                continue

            if len(pred_labels) == 0:
                null_count += 1
                continue

            pred_qids_1 = tail_label_to_qids(pred_labels[0])
            pred_qids_3 = set()
            for lab in pred_labels[:3]:
                pred_qids_3.update(tail_label_to_qids(lab))

            if gold_tail in pred_qids_1:
                h1 += 1
            if gold_tail in pred_qids_3:
                h3 += 1

    return {
        "hits@1": h1 / max(total, 1),
        "hits@3": h3 / max(total, 1),
        "total": total,
        "parse_fail": parse_fail,
        "null": null_count,
    }

summary = summarize_jsonl("runs/eval_k24.jsonl")
summary
