In [2]:
import torch
# Pin this notebook to the GPU with index 1 ("cuda:1"). Per the original note,
# this is meant to let it run in parallel with work on the other GPU.
device = torch.device("cuda:1")
# Make that GPU both the current CUDA device and the default device, so the
# objects created in later cells do not have to pass a device explicitly.
torch.cuda.set_device(device)
torch.set_default_device(device)
import pandas as pd
from utils.graph import KGraphPreproc
from utils.llm.mistral import MistralLLM
from utils.prompt import GRAPH_QA_PROMPT
from tqdm import tqdm
from utils.file import export_results_to_file
import os
import networkx as nx
import ast

In [3]:
# Instantiate the project's Mistral LLM wrapper; the experiment loop below
# queries it through `mistral.get_response(prompt)`.
# NOTE(review): presumably this loads the model weights onto the default
# device configured in the first cell -- confirm in utils.llm.mistral.
mistral = MistralLLM()

In [18]:
# Obtain the MetaQA knowledge-graph wrapper. Later cells rely on three of its
# attributes: `_graph` (handed to networkx for shortest-path search),
# `mid2name` (node id -> name) and `name2mid` (name -> node id).
mqa_graph = KGraphPreproc.get_metaqa_graph()

In [26]:
def get_triplet_path(graph, start, target):
    """Return the shortest path between two nodes as a list of triplet strings.

    Args:
        graph: Knowledge-graph wrapper exposing ``_graph`` (the networkx graph
            that is searched) and ``mid2name`` (mapping from node id to name).
        start: Id of the node the path starts from, or ``None`` if the entity
            could not be resolved to an id.
        target: Id of the node the path should reach, or ``None`` if the
            entity could not be resolved to an id.

    Returns:
        One ``"head-relation-tail"`` string per edge on the shortest path, in
        path order. An empty list is returned when either endpoint is ``None``,
        when an endpoint is not a node of the graph, or when no path exists.
    """
    # networkx interprets a None source/target as "all nodes" and then returns
    # a dict of paths rather than a single node list, which would break the
    # pairwise walk below. An unresolved endpoint simply has no path.
    if start is None or target is None:
        return []
    try:
        path = nx.shortest_path(graph._graph, start, target)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return []
    triplets = []
    # Walk consecutive node pairs (edges) along the path.
    for s, t in zip(path, path[1:]):
        head = graph.mid2name.get(s, None)
        rel = graph._graph[s][t].get("relation", None)
        tail = graph.mid2name.get(t, None)
        triplets.append(f'{head}-{rel}-{tail}')
    return triplets

In [16]:
import re
# Captures the text between the first "[" and the next "]" in a question.
# Raw string: "\[" is not a valid escape sequence in a regular string literal.
regex_mqa_topic_entity = re.compile(r"\[(.*?)\]")
def extract_mqa_topic_entity(question):
    """Resolve the bracketed topic entity of a question to its graph node id.

    Args:
        question: Question text in which the topic entity is wrapped in
            square brackets.

    Returns:
        The id stored in ``mqa_graph.name2mid`` for the first bracketed name,
        or ``None`` when the question contains no bracketed span or the name
        is not present in ``mqa_graph.name2mid``.
    """
    match = regex_mqa_topic_entity.search(question)
    if match is None:
        # No "[...]" in the question: report "unresolved" the same way as an
        # unknown entity name instead of raising IndexError.
        return None
    return mqa_graph.name2mid.get(match.group(1), None)

In [31]:
# Run the "kb-path" experiment for 1-, 2- and 3-hop MetaQA questions: for each
# question, the shortest knowledge-graph paths from the topic entity to every
# gold answer are rendered as triplets and given to the LLM as context.
for hop in [1,2,3]:
    print(f"Hops: {hop}")
    experiment_name = "kb-path"
    res_path = f"/datasets/MetaQA/results/{hop}hop/{experiment_name}.csv"
    results = []
    id_list = []
    l = 0  # number of questions already answered in a previous run
    # Resume support: if a results file already exists, reuse its answers and
    # skip that many questions below.
    if os.path.isfile(res_path):
        r_df = pd.read_csv(res_path)
        l = len(r_df)
        results = list(r_df.Model.values)
    # Load the questions; the file has no header row, so name the columns.
    metaqa = pd.read_csv(f"/datasets/MetaQA/{hop}hop/test_1000.txt", header=None, index_col=0)
    metaqa.rename(columns={1: "Question", 2: "Answers"}, inplace=True)
    for c, (i, r) in enumerate(tqdm(list(metaqa.iterrows()))):
        # Ids are collected for every row (including skipped ones) so that
        # id_list stays aligned with the restored + new results.
        id_list.append(i)
        if c < l:
            continue
        # Resolve names to graph ids, dropping anything that cannot be
        # resolved: a None endpoint has no path and must not reach networkx.
        topic_ids = [
            mid for mid in [extract_mqa_topic_entity(r.Question)] if mid is not None
        ]
        answer_ids = [
            mid
            for mid in map(mqa_graph.name2mid.get, r.Answers.split("|"))
            if mid is not None
        ]

        # Every (topic, answer) pair contributes its shortest path as triplets.
        paths = [
            [start, target] for start in topic_ids for target in answer_ids
        ]
        context = []
        for pair in paths:
            context.extend(get_triplet_path(mqa_graph, *pair))
        prompt = GRAPH_QA_PROMPT.format(
            context=";".join(context),
            question=r.Question
        )
        response = mistral.get_response(prompt)
        results.append(response)
        # backup every 10 qs
        if c % 10 == 0:
            export_results_to_file(res_path, results, id_list)
    # Final write so the last (< 10) answers of the hop are persisted too.
    export_results_to_file(res_path, results, id_list)

Hops: 1


100%|██████████| 1000/1000 [00:00<00:00, 1785570.03it/s]




Hops: 2


100%|██████████| 1000/1000 [08:34<00:00,  1.94it/s] 


Hops: 3


100%|██████████| 1000/1000 [40:50<00:00,  2.45s/it] 
