In [54]:
import torch
# Pin this kernel to a single GPU. Set CUDA_INDEX to 1 (i.e. "cuda:1") to run a
# second copy of the notebook on the other GPU in parallel.
CUDA_INDEX = 0
device = torch.device("cuda", CUDA_INDEX)
# Make it both the current CUDA device and the default device for new tensors.
torch.cuda.set_device(device)
torch.set_default_device(device)
import pandas as pd
from utils.graph import KGraphPreproc
from utils.llm.mistral import MistralLLM
from utils.prompt import GRAPH_QA_PROMPT
from tqdm import tqdm
from utils.file import export_results_to_file
import os
import networkx as nx
import ast

In [3]:
# Load the Freebase KB graph wrapper once for the whole notebook.
# Later cells pass `fbkb_graph._graph` to networkx path search, and
# get_triplet_path labels path nodes via `fbkb_graph.mid2name`
# (presumably MID -> entity name; verify against KGraphPreproc).
fbkb_graph = KGraphPreproc.get_fbkb_graph()

In [63]:
def get_triplet_path(graph, start, target, mid2name=None):
    """Verbalize the shortest path between two KB nodes as "head-relation-tail" strings.

    Args:
        graph: networkx graph to search. For each hop, the relation label is read
            from the edge attribute dict ``graph[u][v]`` under the key "relation".
        start: source node id (a Freebase MID such as '/m/07j6w' in this notebook).
        target: destination node id.
        mid2name: mapping from node id to a display name. Defaults to the
            notebook-global ``fbkb_graph.mid2name``, so existing calls behave as before.

    Returns:
        A list with one "{head}-{relation}-{tail}" string per hop of
        ``nx.shortest_path(graph, start, target)``. Nodes without an entry in
        ``mid2name`` are shown by their raw id rather than the string "None".
        Returns [] when either node is missing from the graph or no path exists.
    """
    try:
        path = nx.shortest_path(graph, start, target)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return []

    if mid2name is None:
        mid2name = fbkb_graph.mid2name

    triplets = []
    for head_id, tail_id in zip(path, path[1:]):
        # Fall back to the raw id so unnamed entities still carry information
        # into the LLM context instead of rendering as "None".
        head = mid2name.get(head_id, head_id)
        tail = mid2name.get(tail_id, tail_id)
        # NOTE(review): assumes a simple (Di)Graph; on a Multi(Di)Graph,
        # graph[u][v] is keyed by edge key and this would yield None.
        rel = graph[head_id][tail_id].get("relation")
        triplets.append(f"{head}-{rel}-{tail}")
    return triplets

In [8]:
FBQA_CSV = "/datasets/FreebaseQA/FbQA-eval-1000.csv"
# The first CSV column is the row index (previewed as "Unnamed: 0").
fbqa = pd.read_csv(FBQA_CSV, index_col=0)
fbqa.head(1)

Unnamed: 0,Question-ID,RawQuestion,ProcessedQuestion,Parses,entities
2,FreebaseQA-eval-2,Who directed the films; The Fisher King (1991)...,who directed the films; the fisher king (1991)...,"[{'Parse-Id': 'FreebaseQA-eval-2.P0', 'Potenti...","[['/m/07j6w', '/m/07h5d'], ['/m/04z257', '/m/0..."


In [52]:
# Mistral LLM client. The experiment cells send each formatted GRAPH_QA_PROMPT
# through `mistral.get_response(prompt)` and store the returned answer.
mistral = MistralLLM()

In [64]:
# "kb-path" experiment on FreebaseQA: for each question, collect the shortest
# KB paths between its entity pairs, verbalize them as triplets and ask Mistral
# to answer the question with those triplets as context.
# Answers are checkpointed to `res_path`; re-running this cell resumes after
# the rows already saved there.
experiment_name = "kb-path"
res_path = f"/datasets/FreebaseQA/results/{experiment_name}.csv"
results = []
id_list = []
n_done = 0
# Resume from an earlier (possibly partial) run.
if os.path.isfile(res_path):
    r_df = pd.read_csv(res_path)
    n_done = len(r_df)
    results = list(r_df.Model.values)
for c, (i, r) in enumerate(tqdm(list(fbqa.iterrows()))):
    # Record every id, skipped ones included, so id_list stays aligned with results.
    id_list.append(i)
    if c < n_done:
        continue
    # `entities` holds the repr of a list of [start, target] node-id pairs
    # (Freebase MIDs such as '/m/07j6w').
    entity_pairs = ast.literal_eval(r.entities)
    context = []
    for pair in entity_pairs:
        context.extend(get_triplet_path(fbkb_graph._graph, *pair))
    prompt = GRAPH_QA_PROMPT.format(
        context=";".join(context),
        question=r.RawQuestion,
    )
    results.append(mistral.get_response(prompt))
    # Checkpoint every 10 questions (c = 0, 10, 20, ...).
    if c % 10 == 0:
        export_results_to_file(res_path, results, id_list)
# Final export passes id_list exactly like the checkpoint calls above, so the
# completed file is not written without the ids.
export_results_to_file(res_path, results, id_list)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [05:46<00:00,  2.88it/s]


In [65]:
CWQ_CSV = "/datasets/CWQ/cwq-1000.csv"
# Column 0 becomes the row index; those index values are the ids recorded with each result.
cwq = pd.read_csv(CWQ_CSV, index_col=0)

In [91]:
# "kb-path" experiment on CWQ: for every (topic id, answer id) pair, collect
# the shortest KB path, verbalize it as triplets and ask Mistral to answer the
# question with those triplets as context.
# NOTE(review): paths end at the rows' answer_ids, so the context presumably
# contains the gold answer (an oracle / upper-bound setting) -- confirm intent.
# Answers are checkpointed to `res_path`; re-running this cell resumes after
# the rows already saved there.
experiment_name = "kb-path"
res_path = f"/datasets/CWQ/results/{experiment_name}.csv"
results = []
id_list = []
n_done = 0
# Resume from an earlier (possibly partial) run.
if os.path.isfile(res_path):
    r_df = pd.read_csv(res_path)
    n_done = len(r_df)
    results = list(r_df.Model.values)
for c, (i, r) in enumerate(tqdm(list(cwq.iterrows()))):
    # Record every id, skipped ones included, so id_list stays aligned with results.
    id_list.append(i)
    if c < n_done:
        continue
    # Deduplicate with dict.fromkeys rather than set(): iteration order of a set
    # of strings depends on per-process hash randomization (PYTHONHASHSEED),
    # which would reorder the triplets in the prompt between kernel restarts.
    topic_ids = list(dict.fromkeys(ast.literal_eval(r["topic_ids"])))
    answer_ids = list(dict.fromkeys(ast.literal_eval(r["answer_ids"])))
    entity_pairs = [(start, target) for start in topic_ids for target in answer_ids]
    context = []
    for pair in entity_pairs:
        context.extend(get_triplet_path(fbkb_graph._graph, *pair))
    prompt = GRAPH_QA_PROMPT.format(
        context=";".join(context),
        question=r.question,
    )
    results.append(mistral.get_response(prompt))
    # Checkpoint every 10 questions (c = 0, 10, 20, ...).
    if c % 10 == 0:
        export_results_to_file(res_path, results, id_list)
# Final export passes id_list exactly like the checkpoint calls above, so the
# completed file is not written without the ids.
export_results_to_file(res_path, results, id_list)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [10:02<00:00,  1.66it/s]
