In [None]:
from tqdm import tqdm
import argparse
from utils import *
from freebase import *
from propagation import *
import random
import concurrent.futures



parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str,
                    default="cwq", help="choose the dataset from {webqsp, cwq}.")
parser.add_argument("--max_length", type=int,
                    default=1024, help="the max length of LLMs output.")
parser.add_argument("--limit", type=int,
                    default=7000, help="the max length of the approximation LLMs input.")      
parser.add_argument("--temperature", type=float,
                    default=0., help="the temperature")
parser.add_argument("--llm", type=str,
                    default="llama-3", help="choose base LLM model from {llama, gpt-3.5-turbo, gpt-4}.")
parser.add_argument("--openai_api_key", type=str,
                    default="", help="if the LLM is gpt-3.5-turbo or gpt-4, you need add your own openai api key.")
parser.add_argument('--verbose', action='store_true', help="print LLM input and output.")
args = parser.parse_args(["--verbose"])
# args = parser.parse_args("")


datas, question_string = prepare_dataset(args.dataset)

In [None]:
data = datas[2943]
question = data[question_string]
topics = data['topic_entity']
paths = {topics[topic]: {} for topic in topics}
print(question)

In [None]:
# for topic in topics:
for topic in topics:
    topic_name = topics[topic]
    # 1-hop propagation
    relations = get_relations(question, topic, topic_name, args, 3)
    entities = get_entities({topic: topic_name}, relations, topic)
    [paths[topic_name].update({r: {"entities": entities[i]}}) for i, r in enumerate(relations)]
    facts = propagate(question, topic_name, relations, paths[topic_name], 1, args)
    [paths[topic_name][r].update({"fact": facts[i]}) for i, r in enumerate(relations)]
    # 2-hop propagation
    relations = get_relations_distant(question, topic, topic_name, relations, paths[topic_name], args, 3)
    entities = get_entities_distant(paths[topic_name], relations, topic)
    [paths[topic_name].update({r: {"entities": entities[i]}}) for i, r in enumerate(relations)]
    facts = propagate(question, topic_name, relations, paths[topic_name], 2, args)
    [paths[topic_name][r].update({"fact": facts[i]}) for i, r in enumerate(relations)]
    # 3-hop propagation
    relations = get_relations_distant(question, topic, topic_name, relations, paths[topic_name], args, 3)
    entities = get_entities_distant(paths[topic_name], relations, topic)
    [paths[topic_name].update({r: {"entities": entities[i]}}) for i, r in enumerate(relations)]
    facts = propagate(question, topic_name, relations, paths[topic_name], 3, args)
    [paths[topic_name][r].update({"fact": facts[i]}) for i, r in enumerate(relations)]
    # # # # clean paths
    [paths[topic_name].update({r: paths[topic_name][r]['fact']}) for r in paths[topic_name]]


In [None]:
facts = construct_facts(paths, topics)
prompt = question_prompt.format(facts, question)
response = run_llm(prompt, args)
output = {"question": question, "result": response, "paths": paths}

In [None]:
save_2_jsonl("lmp_{}_{}_3hop.jsonl".format(args.dataset, args.llm), output)

In [7]:
sparql_name = """
PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?e ?name
WHERE {
?e ns:type.object.name ?name . 
FILTER(CONTAINS(LCASE(?name), "National Football League"))
}
"""

execute_sparql(sparql_name)

[]

In [None]:
sparql_r = """
PREFIX ns: <http://rdf.freebase.com/ns/>
SELECT DISTINCT ?e ?r ?name
WHERE {
?e ?r ns:%s.
?e ns:type.object.name ?name . 
}
"""

execute_sparql(sparql_r % ('m.0cs1bx'))

In [None]:
[{'m.03m5x4': 'The NBA Finals'},
{"m.059yj": "National Football League"},
{'m.06x5s': 'Super Bowl'}]

In [None]:
execute_sparql(sparql_relations % 'm.076ps')

In [None]:
execute_sparql(sparql_entities % ('ns:m.04thp ns:m.0j4b', 'location.statistical_region.part_time_employment_percent', 'm.04thp'))

In [None]:
def get_entities_distant(paths, relations, topic):
    entities = []
    for relation in relations:
        start_entities = {}
        previous_entities = paths[relation.rsplit('->', 1)[0]]['entities']
        for i in previous_entities:
            for j in previous_entities[i]:
                if j not in ['literal', 'typed-literal']:
                    start_entities.update({j: previous_entities[i][j]})
        print(start_entities)
        sparql_output = execute_sparql(sparql_entities % (' '.join(['ns:' + i for i in list(start_entities.keys())]), relation.rsplit('->', 1)[1], topic))
        filtered_entities = filter_entities(start_entities, sparql_output)

        entities.append(filtered_entities)
    
    return entities, sparql_output


get_entities_distant(paths[topic_name], [relations[0]], topic)

In [None]:
relations

In [None]:
execute_sparql(sparql_relations_3hop % ('m.06x5s', 'sports.sports_championship.events', 'sports.sports_championship_event.season', 'm.06x5s'))