In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import pandas as pd
import torch
# GPU selection for this kernel.
# set to "cuda:1" for running in parallel on both GPUs
# NOTE(review): presumably a second kernel/notebook is pointed at the other GPU -- confirm.
device = torch.device("cuda:0")
# Make `device` the current CUDA device ...
torch.cuda.set_device(device)
# ... and the default device on which new tensors are created.
torch.set_default_device(device)
import gc
import Stemmer
import re
from tqdm import tqdm
import os
import csv
import gc
from utils.graph import KGraphPreproc
from utils.graph.chain import GraphChain
from utils.llm.mistral import MistralLLM

In [3]:
# Load the FreebaseQA-eval.json file into a DataFrame. get_fbqa_data (below)
# reads each row's `Questions` value through .get(), i.e. treats it as dict-like.
fbqa = pd.read_json("/datasets/FreebaseQA/FreebaseQA-eval.json")
def get_fbqa_data(question_row):
    """
    Extract the question text and the answers of its first parse from a row.

    Args:
        question_row: A dataset row whose ``Questions`` attribute is a
            dict-like object with ``RawQuestion`` and ``Parses`` entries.

    Returns:
        A ``(question, answer)`` tuple. ``question`` is the ``RawQuestion``
        value (``None`` if the key is missing). ``answer`` is whatever is
        stored under ``Answers`` in the first parse; it is ``None`` when the
        row has no usable parse, in which case a message is printed.
    """
    question = question_row.Questions.get("RawQuestion", None)
    # `or [None]` also covers a "Parses" key that is present but empty or
    # None; the plain .get() default only handled a missing key, so `[0]`
    # used to raise for those rows.
    parses = question_row.Questions.get("Parses") or [None]
    parse = parses[0]
    if not parse:
        print(f"error in question: {question}")
        return question, None
    answer = parse.get("Answers")
    return question, answer

####### load the graph
# The resulting object is later passed to GraphChain.from_llm(graph=...).
fbkb_graph = KGraphPreproc.get_fbkb_graph()

In [4]:
def get_response(prompt):
    """
    Send ``prompt`` through the module-level ``chain`` and return its answer.

    A garbage-collection pass and a CUDA cache flush run before every call.

    Args:
        prompt: The input handed to ``chain.invoke``.

    Returns:
        The ``"result"`` entry of the value returned by ``chain.invoke``.
    """
    # Reclaim Python garbage first, then release cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
    return chain.invoke(prompt)["result"]

def save_results(fpath, data_rows):
    """
    Write model outputs to a single-column CSV file, replacing any old file.

    The file gets a ``Model`` header row followed by one row per item of
    ``data_rows``; every item is converted with ``str()``.

    Args:
        fpath: Destination path of the CSV file.
        data_rows: Iterable of results to store, in order.
    """
    # newline="" is required by the csv module so that its own "\r\n" row
    # terminator and newlines embedded in (quoted) responses are written
    # untouched. UTF-8 is stated explicitly because the file is read back
    # with pd.read_csv, which decodes as UTF-8 by default.
    with open(fpath, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["Model"])
        writer.writerows([str(r)] for r in data_rows)

In [5]:
# Instantiate the LLM wrapper once; it is handed to GraphChain.from_llm(llm=...)
# when the chain is built.
mistral = MistralLLM()

In [80]:
from utils.prompt import ENTITY_PROMPT, GRAPH_QA_PROMPT

# Assemble the graph-QA chain from the already-created LLM and knowledge graph.
chain = GraphChain.from_llm(
    graph=fbkb_graph,
    llm=mistral,
    entity_prompt=ENTITY_PROMPT,
    qa_prompt=GRAPH_QA_PROMPT,
    verbose=False,
)
# Same file path that the cache-export cell (`cache_path`) writes to.
chain.sbert_cache_path = "/datasets/FB15k-237/cache/sbert.csv"

In [None]:
# Run every FreebaseQA question through the chain once per exploration depth.
# Results are checkpointed to CSV so that an interrupted run can be resumed
# from the first unanswered question instead of starting over.

# load q's (loop-invariant, so read once instead of once per depth)
fbqa = pd.read_json("/datasets/FreebaseQA/FreebaseQA-eval.json")

for depth in [3]:
    print(f"depth: {depth}")
    # set the depth
    chain.exploration_depth = depth
    # init experiment
    experiment_name = f"sbert-kb{depth}"
    res_path = f"/datasets/FreebaseQA/results/{experiment_name}.csv"
    results = []
    # load if preinit'ed
    if os.path.isfile(res_path):
        # dtype=str + keep_default_na=False keep every stored response exactly
        # as written. With the defaults, responses such as "", "None", "N/A"
        # or "null" come back as NaN and would be re-saved as "nan".
        r_df = pd.read_csv(res_path, dtype=str, keep_default_na=False)
        results = r_df["Model"].tolist()
    n_done = len(results)
    # Rows before n_done already have a stored response; skip them by position.
    remaining = fbqa.iloc[n_done:]
    try:
        # run through; `initial` starts the bar at the resume point so the
        # rate/ETA only reflect questions that are actually processed.
        for _, row in tqdm(remaining.iterrows(), total=len(fbqa), initial=n_done):
            question, _answer = get_fbqa_data(row)
            results.append(get_response(question))
            # backup every 10 qs
            if len(results) % 10 == 0:
                save_results(res_path, results)
    finally:
        # Runs on normal completion and on error/interrupt alike, so the
        # responses gathered since the last backup are never lost.
        save_results(res_path, results)

depth: 3


 29%|██▉       | 1153/3996 [5:51:23<23:21:29, 29.58s/it]

In [48]:
# Show one edge as (u, v, attribute dict). Taking the first item from the
# iterator avoids copying every edge of the graph into a list just to index [0].
next(iter(fbkb_graph._graph.edges(data=True)))

('/m/027rn',
 '/m/06cx9',
 {'relation': '/location/country/form_of_government',
  'embedding': '[ 6.66760579e-02, -3.34311128e-02, -3.30068283e-02, -6.52218470e-03, 5.99559359e-02, -3.25985216e-02, -5.75441048e-02, 6.32693768e-02, -6.85760332e-03, 3.24052875e-03, 8.03972110e-02, -3.03435866e-02, -2.48321630e-02, -7.18592703e-02, 8.65350850e-03, 3.96894943e-03, -7.39242882e-02, 2.74431258e-02, 6.91422969e-02, 4.06056233e-02, 9.83152539e-02, -5.27970977e-02, -3.25965695e-02, 4.13841102e-03, -2.69966782e-03, 1.98496226e-02, 7.19338134e-02, 3.07799038e-02, -7.80925900e-03, -5.32830767e-02, 8.66608229e-03, 5.83224297e-02, 1.87647864e-02, 1.56855173e-02, 3.83375101e-02, -2.22949795e-02, -1.00411817e-01, -6.11528344e-02, 1.09584071e-01, -1.53645836e-02, -1.01085864e-02, -6.29659966e-02, 9.56393704e-02, 2.75516999e-04, -8.36374760e-02, -2.04921532e-02, 1.38285868e-02, 1.07522964e-01, 1.74823459e-02, 5.19003533e-03, 4.82774638e-02, -1.07315909e-02, -3.85233313e-02, 3.29222418e-02, -2.40317434e-

In [41]:
import csv
# Export each edge's stored "embedding" attribute to a two-column CSV
# ("key", "embedding"). The path is the same one assigned to
# chain.sbert_cache_path.
cache_path = "/datasets/FB15k-237/cache/sbert.csv"
# newline="" is required by the csv module so its row terminator is written as-is.
with open(cache_path, "w", newline="") as cache_file:
    writer = csv.writer(cache_file)
    writer.writerow(["key", "embedding"])
    # NOTE(review): assumes `_graph` is a networkx graph -- data="embedding"
    # yields (u, v, embedding) with None for edges lacking the attribute,
    # replacing the per-edge `edges[u, v].get(...)` lookup.
    edge_embeddings = fbkb_graph._graph.edges(data="embedding", default=None)
    for u, v, embedding in tqdm(edge_embeddings):
        # The key is the (u, v) tuple; csv writes its str() form.
        writer.writerow([(u, v), embedding])
        

  0%|          | 0/248611 [00:00<?, ?it/s]

100%|██████████| 248611/248611 [07:08<00:00, 580.02it/s]
