In [13]:
# Reload edited modules (e.g. the local `utils` package) automatically before
# executing each cell; mode `2` reloads all modules, not only %aimport-ed ones.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
import ast
import gc
import os

import networkx as nx
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from utils.preprocessing import preprocess_text
from utils.graph import KGraphPreproc
from utils.graph.tog_lp import ToGLPChain
from utils.llm.mistral import MistralLLM
from utils.llm.qwen import Qwen2_5
from utils.file import export_results_to_file
from utils.link_prediction import extract_predicted_edges
from utils.prompt import GRAPH_QA_PROMPT, ENTITY_PROMPT, \
    NO_CONTEXT_PROMPT, EVALUATE_CONTEXT_PROMPT, \
    RERANK_TRIPLETS_PROMPT, PREDICT_EDGE_PROMPT


# Pin the session to the first CUDA GPU: it becomes the current device for
# CUDA calls and the default device for newly created tensors.
device = torch.device("cuda", 0)
torch.cuda.set_device(device)
torch.set_default_device(device)

In [29]:
# Free memory between runs. Collect Python garbage first so that tensors kept
# alive only by unreachable reference cycles are released, THEN return the
# caching allocator's now-unused CUDA blocks. The reverse order leaves memory
# freed by gc.collect() sitting in PyTorch's cache.
gc.collect()
torch.cuda.empty_cache()

0

In [30]:
# Qwen 2.5 wrapper from utils.llm.qwen; used as `link_predictor_llm` when the
# ToGLPChain is built below. (Construction presumably loads model weights onto
# the GPU — TODO confirm.)
qwen = Qwen2_5()

In [31]:
# Mistral wrapper from utils.llm.mistral; used as the main `llm` when the
# ToGLPChain is built below.
mistral = MistralLLM()

In [41]:
# On-disk cache file handed to fbkb_graph.embed_triplets for the SBERT
# triplet embeddings.
sbert_cache_path = "/datasets/FB15k-237/cache/sbert.csv"
# Sentence-BERT encoder: embeds graph triplets and is also passed to ToGLPChain.
sbert = SentenceTransformer("all-mpnet-base-v2")

In [33]:
# Load the knowledge graph via the project helper (presumably FB15k-237,
# judging by the /datasets/FB15k-237 path used for its embedding cache).
fbkb_graph = KGraphPreproc.get_fbkb_graph()

In [42]:
# Embed the graph's triplets with SBERT. The lambda looks up `sbert` at call
# time. `cache_path` points embed_triplets at an on-disk cache; the cell output
# shows it checking for and then loading that cache.
fbkb_graph.embed_triplets(
    embedding_function=lambda q: sbert.encode(q),
    cache_path=sbert_cache_path
)

Checking embedding cache
Loading embedding cache


329136it [11:46, 465.60it/s]


In [43]:
# Assemble the ToG-LP chain over fbkb_graph:
# - `mistral` is passed as the main `llm`;
# - `qwen` is passed as `link_predictor_llm`;
# - the SBERT encoder is passed too;
# - verbose=False is forwarded to the chain.
# NOTE(review): the cell output echoes an `LLMChain(...)` line from inside the
# chain code (likely a LangChain deprecation warning) — TODO confirm.
chain = ToGLPChain.from_llm(
    llm=mistral,
    link_predictor_llm=qwen,
    graph=fbkb_graph,
    sbert=sbert,
    verbose=False
)

  qa_chain = LLMChain(llm=llm, prompt=qa_prompt)


In [37]:
def entity_path_len(entities, graph=None):
    """Return the length of the first resolvable shortest path among entities.

    For each ``path`` in ``entities`` (in order), the first element is the
    start node and every following element is a candidate target. The first
    (start, target) pair that is connected in ``graph`` decides the result.

    Args:
        entities: Iterable of sequences of graph node ids. Empty sequences
            are skipped instead of raising IndexError.
        graph: networkx graph to search. Defaults to ``fbkb_graph._graph``.

    Returns:
        The number of *nodes* on that shortest path, i.e.
        ``len(nx.shortest_path(...))``. Returns -1 when no start/target pair
        is present in and connected by the graph.
    """
    if graph is None:
        graph = fbkb_graph._graph
    for path in entities:
        if len(path) == 0:
            continue
        start = path[0]
        for target in path[1:]:
            try:
                # NOTE(review): this counts nodes (= edges + 1), but the column
                # it fills is named "hops" — confirm which is intended.
                return len(nx.shortest_path(graph, start, target))
            except (nx.NodeNotFound, nx.NetworkXNoPath):
                # Unknown node or no connecting path: try the next target.
                continue
    return -1

In [73]:
# FreebaseQA evaluation subset. The `entities` column is stored as a Python
# literal string in the CSV. Parse it back with ast.literal_eval, which is
# safe because it accepts literals only (no arbitrary eval).
fbqa = pd.read_csv("/datasets/FreebaseQA/FbQA-eval-1000.csv", index_col=0)
fbqa["entities"] = fbqa["entities"].map(ast.literal_eval)
# entity_path_len only needs the entities column, so map over that Series
# instead of a row-wise DataFrame.apply(axis=1). Unresolvable rows get -1.
fbqa["hops"] = fbqa["entities"].map(entity_path_len)

In [71]:
# Name of this run; results are checkpointed to <results dir>/<experiment_name>.csv.
experiment_name = "tog-lp-1"
res_path = f"/datasets/FreebaseQA/results/{experiment_name}.csv"

In [62]:
def extract_fbqa_entities(row, mid2name=None):
    """Map a FreebaseQA row's entity mids to human-readable names.

    The first element of every item in ``row.entities`` is treated as a mid.
    Processing rules:

    - Duplicate mids are dropped, keeping the first occurrence, so the output
      order is stable across interpreter runs.
    - Mids missing from ``mid2name`` are skipped.
    - Falsy names (None / "") are skipped.

    Args:
        row: Object with an ``entities`` attribute, e.g. a row yielded by
            ``fbqa.iterrows()``.
        mid2name: Mapping with ``.get`` from mid to name. Defaults to
            ``fbkb_graph.mid2name``.

    Returns:
        List of entity names in first-seen order.
    """
    if mid2name is None:
        mid2name = fbkb_graph.mid2name
    # dict.fromkeys dedupes while preserving order. A set() here made the
    # order depend on string hash randomization, so the topic entities fed to
    # the chain could change between kernel restarts.
    unique_mids = dict.fromkeys(item[0] for item in row.entities)
    return [name for name in map(mid2name.get, unique_mids) if name]

In [78]:
# Inspect previously checkpointed results. This cell sits above the run loop,
# so on a fresh Restart & Run All the results file may not exist yet. Guard the
# read instead of raising FileNotFoundError.
r_df = pd.read_csv(res_path) if os.path.isfile(res_path) else pd.DataFrame()
r_df

Unnamed: 0.1,Unnamed: 0,Model
0,2,"('Terry Gilliam', 3)"
1,5,"('Germany', 1)"
2,6,"('Adolf Hitler and Carl Diem', 3)"
3,9,"('John Steinbeck', 1)"
4,12,"('Dick Fosbury (United States of America)', 1)"
...,...,...
995,3964,('Woodrow Wilson was the president of the USA ...
996,3969,"('Adam Sandler (4 times)', 1)"
997,3985,"('Whitehorse', 1)"
998,3993,"('Figure skating', 1)"


In [75]:
# Run the chain over every FreebaseQA question, resuming from a previous
# partial run. Results are checkpointed every 10 questions and once more when
# the cell exits. The final export also runs if a chain call raises or the
# cell is interrupted, so at most the in-flight question is lost.
results = []
n_done = 0
if os.path.isfile(res_path):
    r_df = pd.read_csv(res_path)
    n_done = len(r_df)
    # NOTE: resumed entries are the CSV's text form of (answer, depth), while
    # entries appended below are tuples.
    results = list(r_df.Model.values)

fbqa_rows = list(fbqa.iterrows())
if n_done > len(fbqa_rows):
    raise ValueError(
        f"{res_path} holds {n_done} results but fbqa has only "
        f"{len(fbqa_rows)} questions; wrong experiment_name?"
    )
# Ids of already-answered questions. Resuming assumes the file's rows follow
# fbqa's row order.
id_list = [i for i, _ in fbqa_rows[:n_done]]

try:
    for c, (i, r) in enumerate(
        tqdm(fbqa_rows[n_done:], initial=n_done, total=len(fbqa_rows)),
        start=n_done,
    ):
        response = chain.invoke(
            input={
                "query": r.RawQuestion,
                "topic_entities": extract_fbqa_entities(r),
            }
        )
        results.append((response["result"]["text"], response["depth"]))
        # Record the id only once its result exists, so results and id_list
        # stay aligned even if chain.invoke raises mid-iteration.
        id_list.append(i)
        if c % 10 == 0:  # periodic checkpoint
            export_results_to_file(res_path, results, id_list)
finally:
    export_results_to_file(res_path, results, id_list)


100%|██████████| 1000/1000 [00:00<00:00, 1953564.97it/s]
