In [3]:
import torch
# set to "cuda:1" for running in parallel on both GPUs
device = torch.device("cuda:0")
torch.cuda.set_device(device)
torch.set_default_device(device)
import pandas as pd
from utils.graph import KGraphPreproc
from utils.llm.mistral import MistralLLM
from utils.prompt import GRAPH_QA_PROMPT
from tqdm import tqdm
from utils.file import export_results_to_file
import os
import networkx as nx
import ast

In [2]:
fbkb_graph = KGraphPreproc.get_fbkb_graph()

In [6]:
def get_triplet_path(graph, start, target):
    try:
        path = nx.shortest_path(graph, start, target)
        triplets = []
        for s,t in zip(path, path[1:]):
            head = fbkb_graph.mid2name.get(s, None)
            rel = graph[s][t].get("relation", None)
            tail = fbkb_graph.mid2name.get(t, None)
            if head and rel and tail:
                triplets.append(f'{head}-{rel}-{tail}')
        return triplets[:250]
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return []

In [4]:
fbqa = pd.read_csv("/datasets/FreebaseQA/FbQA-eval-1000.csv", index_col=0)
fbqa.head(1)

Unnamed: 0,Question-ID,RawQuestion,ProcessedQuestion,Parses,entities
2,FreebaseQA-eval-2,Who directed the films; The Fisher King (1991)...,who directed the films; the fisher king (1991)...,"[{'Parse-Id': 'FreebaseQA-eval-2.P0', 'Potenti...","[['/m/07j6w', '/m/07h5d'], ['/m/04z257', '/m/0..."


In [5]:
mistral = MistralLLM()

In [7]:
# init experiment
experiment_name = f"kb-path"
res_path = f"/datasets/FreebaseQA/results/{experiment_name}.csv"
results = []
id_list = []
l = 0
# load if preinit'ed
if os.path.isfile(res_path):
    r_df = pd.read_csv(res_path)
    l = len(r_df)
    results = list(r_df.Model.values)
# run through
for c, (i, r) in enumerate(tqdm(list(fbqa.iterrows()))):
    id_list.append(i)
    if c < l:
        continue
    paths = ast.literal_eval(r.entities)
    context = []
    for pair in paths:
        context.extend(get_triplet_path(fbkb_graph._graph, *pair))
    prompt = GRAPH_QA_PROMPT.format(
        context=";".join(context),
        question=r.RawQuestion
    )
    response = mistral.get_response(prompt)
    results.append(response)
    # backup every 10 qs
    if c % 10 == 0:
        export_results_to_file(res_path, results, id_list)
export_results_to_file(res_path, results)

100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


In [5]:
cwq = pd.read_csv("/datasets/CWQ/cwq-1000.csv", index_col=0)

In [11]:
cwq

Unnamed: 0,ID,compositionality_type,question,answers,topic_ids,answer_ids
894,WebQTest-1311_e920e31a99d6b7dfbeef110668d3103d,comparative,What inspiration of Antoni Gaudi died later th...,"[{'aliases': ['W. Morris'], 'answer': 'William...","['/m/0g84t93', '/m/0g84t93', '/m/0g84t93', '/m...",['/m/08304']
1329,WebQTest-1382_b9b879060be6df6cb7cd937a7996f9d9,comparative,What country borders Argentina and has an army...,"[{'aliases': ['Brazilian ', 'República Federat...","['/m/0g84t93', '/m/0g84t93', '/m/02lw5z', '/m/...",['/m/015fr']
1235,WebQTest-1382_edc19b9010b39a6a7a1e3926399c8522,comparative,What country bordering Argentina has populatio...,"[{'aliases': ['Republic of Chile'], 'answer': ...","['/m/0g84t93', '/m/0g84t93', '/m/02lw5z', '/m/...","['/m/01p1v', '/m/0165v', '/m/05v10', '/m/015fr..."
1778,WebQTrn-3252_34f533dd75026e91fc9acd350e6eeffb,comparative,What countries in which the Niger River flows ...,"[{'aliases': [], 'answer': 'Benin', 'answer_id...","['/m/0g84t93', '/m/0g84t93', '/m/02lw5z', '/m/...",['/m/0164v']
1471,WebQTrn-2177_dec2523c78124e170c353878876cff1e,comparative,"Where did Caroline Kennedy attend university, ...","[{'aliases': ['Harvard University, main campus...","['/m/0g84t93', '/m/0g84t93', '/m/02lw5z', '/m/...","['/m/03ksy', '/m/01mpwj', '/m/01n951']"
...,...,...,...,...,...,...
2898,WebQTest-1382_73e475413d79895e98983bff8b926f21,superlative,Which country bordering Argentina has the lowe...,"[{'aliases': ['Brazilian ', 'República Federat...","['/m/0g84t93', '/m/0g84t93', '/m/02lw5z', '/m/...",['/m/015fr']
3254,WebQTrn-2525_6f63302a5c1425c7a4f31bd93c423f2b,superlative,"What college, that has the largest number of u...","[{'aliases': ['TU', 'Temple', 'Temple Universi...","['/m/0g84t93', '/m/0g84t93', '/m/02lw5z', '/m/...",['/m/01jt2w']
1901,WebQTrn-909_86f42ab931739ed1c6ba88a8db93fd0d,superlative,What location with the smallest GNIS feature I...,"[{'aliases': ['Minneapolis, Minnesota', 'Henne...","['/m/0g84t93', '/m/0g84t93', '/m/0g84t93', '/m...",['/m/0fpzwf']
2507,WebQTrn-3358_887e6cbf6fd62ad83a3815bd45a6b28d,superlative,Which politician who held office most recently...,"[{'aliases': [""Long 'Un"", 'The Flatboat Man', ...","['/m/0g84t93', '/m/0g84t93', '/m/0g84t93', '/m...",['/m/0gzh']


In [91]:
# init experiment
experiment_name = f"kb-path"
res_path = f"/datasets/CWQ/results/{experiment_name}.csv"
results = []
id_list = []
l = 0
# load if preinit'ed
if os.path.isfile(res_path):
    r_df = pd.read_csv(res_path)
    l = len(r_df)
    results = list(r_df.Model.values)
# run through
for c, (i, r) in enumerate(tqdm(list(cwq.iterrows()))):
    id_list.append(i)
    if c < l:
        continue
    topic_ids = set(ast.literal_eval(r["topic_ids"]))
    answer_ids = set(ast.literal_eval(r["answer_ids"]))
    paths = [
        [start, target] for start in topic_ids for target in answer_ids
    ]
    context = []
    for pair in paths:
        context.extend(get_triplet_path(fbkb_graph._graph, *pair))
    prompt = GRAPH_QA_PROMPT.format(
        context=";".join(context),
        question=r.question
    )
    response = mistral.get_response(prompt)
    results.append(response)
    # backup every 10 qs
    if c % 10 == 0:
        export_results_to_file(res_path, results, id_list)
export_results_to_file(res_path, results)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [10:02<00:00,  1.66it/s]
