### DBの構造をRAGで入れてSPARQLを直接生成

In [None]:
import os
import time
import json
from openai import OpenAI
from functions.SPARQL_executer import execute_query
from functions.results_evaluater import evaluate_jaccard
from dotenv import load_dotenv
# Load settings through python-dotenv before os.environ is read below
# (presumably from a local .env file — confirm where it lives).
load_dotenv()
# Short database name; it selects the endpoint variable and the file names below.
db = "rhea"
# Endpoint passed to execute_query; read from ENDPOINT_<DB> (here ENDPOINT_RHEA).
# Raises KeyError when that environment variable is not set.
endpoint = os.environ[f"ENDPOINT_{db.upper()}"]

# File the questions are written to after the LLM-generated queries have been run.
save_path = f'data/questions/easy_question_augmented_{db}_baseline_PT.json'
# File the questions are written to after their reference queries have been run.
save_path_with_results = f'data/questions/easy_question_augmented_with_results_{db}_baseline_PT.json'

In [2]:
# SPARQL generation helpers
class OpenAIChat:
    """Minimal wrapper around the OpenAI chat-completions endpoint."""

    # Model used when the caller does not pick one explicitly.
    DEFAULT_MODEL_NAME = "gpt-4-1106-preview"

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Create the API client.

        Args:
            model_name: Name of the chat model to query. Defaults to the
                model the notebook was originally written against.
        """
        self.client = OpenAI()
        self.model_name = model_name

    def chat(self, prompt):
        """Send a single message and return the text of the first choice.

        Args:
            prompt: One chat message dict, e.g.
                ``{"role": "user", "content": "..."}``.

        Returns:
            The ``content`` of the first completion choice.
        """
        # Use the configured model rather than a second hard-coded copy of
        # its name, so that changing ``self.model_name`` actually takes effect.
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=[prompt],
        )
        return completion.choices[0].message.content


class SPARQLQueryManager:
    """Builds a prompt from a database schema and extracts the SPARQL query
    from the language model's answer."""

    def __init__(self):
        # Neither history list is read or written by the methods of this class.
        self.query_history = []
        self.output_history = []
        self.llm = OpenAIChat()

    def extract_sparql_query(self, gpt_output):
        """Return the first ```sparql fenced block found in *gpt_output*.

        Args:
            gpt_output: Text returned by the language model. May be ``None``
                or empty when the model produced no text.

        Returns:
            The stripped contents of the first ```sparql block, or ``None``
            when there is no such block.
        """
        # The model can return no content at all; treat that as "no query"
        # instead of failing on None.split(...).
        if not gpt_output:
            return None

        # Everything before the first fence is explanatory text and is ignored.
        _, fence, remainder = gpt_output.partition("```sparql")
        if not fence:
            return None

        # The query ends at the next fence (or at the end of the text).
        return remainder.split("```", 1)[0].strip()

    def check_sparql_result(self, result):
        """Tell whether a query result contains any rows.

        Args:
            result: Object exposing an ``empty`` attribute (the original
                docstring names ``pd.DataFrame``), or ``None``.

        Returns:
            ``False`` when *result* is ``None`` or empty, ``True`` otherwise.
        """
        return result is not None and not result.empty

    def generate_query_for_data(self, user_question, model_yaml, prefix_str):
        """Ask the language model for a SPARQL query answering *user_question*.

        Args:
            user_question: Natural-language question to answer.
            model_yaml: Database structure description, inserted verbatim
                into the prompt.
            prefix_str: Prefix declarations, inserted verbatim into the prompt.

        Returns:
            The first SPARQL query found in the model's answer, or ``None``
            when the answer contains no ```sparql block.
        """
        # NOTE(review): "[INSTURCTIONS]" is misspelled, but the text is left
        # untouched because changing the prompt would change model output.
        prompt = f"""
Create a SPARQL query to retrieve values from the database for the user question provided below:
User Question: '{user_question}'

Note: Database structure information is provided in YAML format. Use this information effectively and accurately to ensure the query is feasible.

Prefixes:
{prefix_str}

Database Structure (YAML):
{model_yaml}

[INSTURCTIONS]:
- Create a simple SPARQL query using the YAML database structure provided to answer the user question.
- Use only URI reference that are explicitly listed in the YAML schema. Do not introduce any elements that are not part of the actual schema.
- Apply the FILTER clause judiciously to refine your results. Ensure that your query effectively uses the attributes and relationships detailed in the YAML without omitting valid data.
- Do not modify the values from database using BIND, STRAFTER or other operations, such as trimming text. Because the SPARQL output for the question should directly return values from the database.

---
Now Create a SPARQL query to retrieve values from the database for the user question.
First, explain about the content of the SPARQL query you need to write, 
Second, check if it follows the [INSTURCTIONS].
Finally, create the SPARQL query.
"""
        message = {"role": "user", "content": prompt}
        llm_output = self.llm.chat(message)
        return self.extract_sparql_query(llm_output)

In [None]:
# Load the question list for the selected database.
# JSON text is UTF-8, so the encoding is stated explicitly instead of
# relying on the platform default.
with open(f"questions/json_format/{db}.json", 'r', encoding="utf-8") as f:
    questions = json.load(f)

In [6]:
# Run the query selected by the "sparql" argument for every question
# (presumably the reference query stored under that key — confirm in execute_query).
query_results = []
for question in questions:
    time.sleep(0.5)  # short pause between endpoint requests
    query_results.append(execute_query(question, endpoint, "sparql", 10000, ""))
    print(question["id"], len(query_results[-1][0]))

# Attach each result to every question carrying the matching id.
# The questions are indexed by id once so this is a single pass instead of
# rescanning the whole list for every result.
questions_by_id = {}
for question in questions:
    questions_by_id.setdefault(question["id"], []).append(question)
for result, question_id in query_results:
    for question in questions_by_id.get(question_id, ()):
        question["results"] = result

with open(save_path_with_results, 'w') as f:
    json.dump(questions, f, indent=4)

Q1-1-1 1
Q1-2-1 1
Q1-3-1 1
Q1-4-1 1
Q1-5-1 1
Q1-6-1 1
Q2-1-1 3
Q2-2-1 3
Q2-3-1 3
Q2-4-1 3
Q2-5-1 3
Q2-6-1 3
Q3-1-1 3
Q3-2-1 3
Q3-3-1 3
Q3-4-1 3
Q3-5-1 3
Q3-6-1 3
Q3-1-2 9
Q3-2-2 9
Q3-3-2 9
Q3-4-2 9
Q3-5-2 9
Q3-6-2 9
Q4-1-1 1
Q4-2-1 1
Q4-3-1 1
Q4-4-1 1
Q4-5-1 1
Q4-6-1 1
Q5-1-1 1
Q5-2-1 1
Q5-3-1 1
Q5-4-1 1
Q5-5-1 1
Q5-6-1 1
Q6-1-1 1
Q6-2-1 1
Q6-3-1 1
Q6-4-1 1
Q6-5-1 1
Q6-6-1 1
Q7-1-1 1
Q7-2-1 1
Q7-3-1 1
Q7-4-1 1
Q7-5-1 1
Q7-6-1 1
Q8-1-1 1
Q8-2-1 1
Q8-3-1 1
Q8-4-1 1
Q8-5-1 1
Q8-6-1 1
Q9-1-1 20
Q9-2-1 20
Q9-3-1 20
Q9-4-1 20
Q9-5-1 20
Q9-6-1 20
Q10-1-1 1
Q10-2-1 1
Q10-3-1 1
Q10-4-1 1
Q10-5-1 1
Q10-6-1 1
Q11-1-1 10
Q11-2-1 10
Q11-3-1 10
Q11-4-1 10
Q11-5-1 10
Q11-6-1 10
Q12-1-1 5
Q12-2-1 5
Q12-3-1 5
Q12-4-1 5
Q12-5-1 5
Q12-6-1 5
Q13-1-1 1
Q13-2-1 1
Q13-3-1 1
Q13-4-1 1
Q13-5-1 1
Q13-6-1 1
Q14-1-1 1
Q14-2-1 1
Q14-3-1 1
Q14-4-1 1
Q14-5-1 1
Q14-6-1 1
Q15-1-1 2
Q15-2-1 2
Q15-3-1 2
Q15-4-1 2
Q15-5-1 2
Q15-6-1 2
Q16-1-1 3
Q16-2-1 3
Q16-3-1 3
Q16-4-1 3
Q16-5-1 3
Q16-6-1 3
Q16-1-2 9
Q16-2-2 9
Q16-3-2 

In [None]:
# Locate the prefix and model files of the selected database under the
# directory given by PATH_RDF_CONFIG (KeyError when the variable is unset).
path_rdf_config = os.environ["PATH_RDF_CONFIG"]
prefix_path = f"{path_rdf_config}/{db}/prefix.yaml"
model_path = f"{path_rdf_config}/{db}/model.yaml"

# Read both files as raw text; the context managers close the handles,
# which the bare open(...).read() calls left open.
with open(prefix_path, encoding="utf-8") as prefix_file:
    prefix_str = prefix_file.read()
with open(model_path, encoding="utf-8") as model_file:
    model_yaml = model_file.read()

In [9]:
# Query generator used below; constructing it also creates the OpenAI client.
manager = SPARQLQueryManager()

In [10]:
# Display how many questions were loaded.
len(questions)

168

In [11]:
# Generate a query only for questions that do not have one yet, so the cell
# can be re-run without repeating finished LLM calls.
for q in questions:
    if "llm_sparql" in q:
        continue
    llm_sparql = manager.generate_query_for_data(q["user_question"], model_yaml, prefix_str)
    q["llm_sparql"] = llm_sparql

In [None]:
# Send each question's "llm_sparql" query to the endpoint and collect the results.
query_results = []
for question in questions:
    query_results.append(execute_query(question, endpoint, "llm_sparql", 10000, ""))
    print(question["id"], len(query_results[-1][0]))

# Attach each result to every question carrying the matching id.
# The questions are indexed by id once so this is a single pass instead of
# rescanning the whole list for every result.
questions_by_id = {}
for question in questions:
    questions_by_id.setdefault(question["id"], []).append(question)
for result, question_id in query_results:
    for question in questions_by_id.get(question_id, ()):
        question["results"] = result

Q1-1-1 1
Execute Error: can only concatenate str (not "NoneType") to str
Q1-2-1
Q1-2-1 0
Q1-3-1 0
Q1-4-1 1
Q1-5-1 0
Q1-6-1 0
Q2-1-1 3
Q2-2-1 3
Q2-3-1 3
Q2-4-1 3


Q2-5-1 3
Q2-6-1 3
Q3-1-1 0
Q3-2-1 24
Q3-3-1 24
Q3-4-1 0
Q3-5-1 24
Q3-6-1 0
Q3-1-2 0
Q3-2-2 0
Q3-3-2 0
Q3-4-2 0
Q3-5-2 10
Q3-6-2 0
Q4-1-1 0
Q4-2-1 0
Q4-3-1 0
Q4-4-1 0
Q4-5-1 0
Q4-6-1 0
Q5-1-1 1
Q5-2-1 1
Q5-3-1 1
Q5-4-1 1
Q5-5-1 1
Q5-6-1 1
Q6-1-1 1
Q6-2-1 1
Q6-3-1 1
Q6-4-1 1
Q6-5-1 1
Q6-6-1 0
Q7-1-1 1
Q7-2-1 1
Q7-3-1 0
Q7-4-1 1
Q7-5-1 0
Q7-6-1 1
Q8-1-1 1
Q8-2-1 1
Q8-3-1 1
Q8-4-1 1
Q8-5-1 1
Q8-6-1 1
Q9-1-1 20
Q9-2-1 20
Q9-3-1 20
Q9-4-1 0
Q9-5-1 20
Q9-6-1 20
Execute Error: QueryBadFormed: A bad request has been sent to the endpoint: probably the SPARQL query is badly formed. 

Response:
b'<!DOCTYPE html SYSTEM "about:legacy-compat">\n<html xmlns="http://www.w3.org/1999/xhtml" lang="en" xml:lang="en"><head><title>Rhea</title><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><link href="/" rel="home"/><link href="/base.css" type="text/css" rel="stylesheet"/><link type="image/vnd.microsoft.icon" href="https://www.rhea-db.org//favicon.ico" rel="shortcut icon"/><link href="/rh

In [13]:
# Write the questions (now holding the LLM queries and their results) to save_path.
with open(save_path, "w") as f:
    json.dump(questions, f, indent=2)

In [6]:
# Reload both saved files for the comparison against the reference output:
# `questions` comes from save_path, `answers` from save_path_with_results.
with open(save_path) as f:
    questions = json.load(f)
with open(save_path_with_results) as f:
    answers = json.load(f)

In [7]:
# Compare the LLM query results with the reference output.
score = evaluate_jaccard(questions, answers)

Evaluating:   0%|          | 0/102 [00:00<?, ?it/s]

Skipping Q1-1-1 due to empty columns.
Skipping Q1-2-1 due to missing results.
Skipping Q1-3-1 due to empty columns.
Skipping Q1-4-1 due to missing results.
Skipping Q1-5-1 due to missing results.
Skipping Q1-6-1 due to missing results.
Skipping Q2-1-1 due to empty columns.


Evaluating:   8%|▊         | 8/102 [00:01<00:13,  6.95it/s]

Skipping Q2-3-1 due to missing results.
Skipping Q2-4-1 due to empty columns.


Evaluating:  42%|████▏     | 43/102 [00:02<00:02, 26.83it/s]

Skipping Q2-6-1 due to empty columns.
Skipping Q3-1-1 due to empty columns.
Skipping Q3-2-1 due to empty columns.
Skipping Q3-3-1 due to empty columns.
Skipping Q3-4-1 due to missing results.
Skipping Q3-5-1 due to empty columns.
Skipping Q3-6-1 due to empty columns.
Skipping Q4-1-1 due to empty columns.
Skipping Q4-2-1 due to empty columns.
Skipping Q4-3-1 due to missing results.
Skipping Q4-4-1 due to missing results.
Skipping Q4-5-1 due to empty columns.
Skipping Q4-6-1 due to empty columns.
Skipping Q5-1-1 due to missing results.
Skipping Q5-2-1 due to missing results.
Skipping Q5-3-1 due to missing results.
Skipping Q5-4-1 due to empty columns.
Skipping Q5-5-1 due to empty columns.
Skipping Q5-6-1 due to empty columns.
Skipping Q6-5-1 due to empty columns.
Skipping Q7-1-1 due to empty columns.
Skipping Q7-2-1 due to missing results.
Skipping Q7-3-1 due to missing results.
Skipping Q7-4-1 due to missing results.
Skipping Q7-5-1 due to missing results.
Skipping Q7-6-1 due to missing

Evaluating: 100%|██████████| 102/102 [00:04<00:00, 24.67it/s]

Skipping Q10-1-1 due to missing results.
Skipping Q10-2-1 due to empty columns.
Skipping Q10-3-1 due to missing results.
Skipping Q10-4-1 due to missing results.
Skipping Q10-5-1 due to empty columns.
Skipping Q10-6-1 due to empty columns.
Skipping Q11-1-1 due to missing results.
Skipping Q11-2-1 due to missing results.
Skipping Q11-3-1 due to empty columns.
Skipping Q11-4-1 due to empty columns.
Skipping Q11-5-1 due to missing results.
Skipping Q11-6-1 due to missing results.
Skipping Q12-1-1 due to empty columns.
Skipping Q12-2-1 due to empty columns.
Skipping Q12-3-1 due to empty columns.
Skipping Q12-4-1 due to empty columns.
Skipping Q12-5-1 due to empty columns.
Skipping Q12-6-1 due to empty columns.
Skipping Q13-1-1 due to empty columns.
Skipping Q13-2-1 due to missing results.
Skipping Q13-3-1 due to empty columns.
Skipping Q13-4-1 due to missing results.
Skipping Q13-5-1 due to empty columns.
Skipping Q13-6-1 due to empty columns.
Skipping Q14-1-1 due to empty columns.
Skippin




In [9]:
# Display the value returned by evaluate_jaccard.
score

{'Q1-1-1': {'jaccard_score': 0},
 'Q1-2-1': {'jaccard_score': 0},
 'Q1-3-1': {'jaccard_score': 0},
 'Q1-4-1': {'jaccard_score': 0},
 'Q1-5-1': {'jaccard_score': 0},
 'Q1-6-1': {'jaccard_score': 0},
 'Q2-1-1': {'jaccard_score': 0},
 'Q2-2-1': {'jaccard_score': 0.0},
 'Q2-3-1': {'jaccard_score': 0},
 'Q2-4-1': {'jaccard_score': 0},
 'Q2-5-1': {'jaccard_score': 0.0},
 'Q2-6-1': {'jaccard_score': 0},
 'Q3-1-1': {'jaccard_score': 0},
 'Q3-2-1': {'jaccard_score': 0},
 'Q3-3-1': {'jaccard_score': 0},
 'Q3-4-1': {'jaccard_score': 0},
 'Q3-5-1': {'jaccard_score': 0},
 'Q3-6-1': {'jaccard_score': 0},
 'Q4-1-1': {'jaccard_score': 0},
 'Q4-2-1': {'jaccard_score': 0},
 'Q4-3-1': {'jaccard_score': 0},
 'Q4-4-1': {'jaccard_score': 0},
 'Q4-5-1': {'jaccard_score': 0},
 'Q4-6-1': {'jaccard_score': 0},
 'Q5-1-1': {'jaccard_score': 0},
 'Q5-2-1': {'jaccard_score': 0},
 'Q5-3-1': {'jaccard_score': 0},
 'Q5-4-1': {'jaccard_score': 0},
 'Q5-5-1': {'jaccard_score': 0},
 'Q5-6-1': {'jaccard_score': 0},
 'Q6-1