In [31]:
import sys
sys.path.append('..')
from grammar.db_tool import DBTool
from grammar.llm import AnyOpenAILLM
from grammar.sql_template_generator import SQLTemplateGenerator
from grammar.text_template_generator import TextTemplateGenerator
from grammar.qa_generator import QADataGenerator


# --- Notebook-wide configuration ---
# Name of the dataset folder; interpolated into the database URL below.
database_name = 'spider'
root_dir = "."

# LLM wrapper passed to the SQL and text template generators.
llm = AnyOpenAILLM(model_name="gpt4-short")

# NOTE(review): looks like a SQLAlchemy-style SQLite URL with a relative path,
# which would make it depend on the kernel's working directory -- confirm.
connection_string = f'sqlite:///{database_name}/rel_database/company_employee.sqlite'
db_tool = DBTool(connection_string)

In [32]:
# Step 1: Generate SQL Query Templates
# The generator is loaded from an existing JSON file; with override=False the
# recorded output shows that cached generations are reused.
file_path = "spider/SQLTemplateGenerator/sql_templates.json"
sql_template_generator = SQLTemplateGenerator.from_file(
    file_path,
    sql_connection=connection_string,
    llm=llm,
)
entities_to_sql_templates = sql_template_generator.generate_batch(
    [('company',), ('people',), ('company', 'people')],
    override=False,
    verbose=True,
)

# Flatten the per-entity template groups into a single list, keeping the
# iteration order of the mapping.
sql_templates = [
    template
    for _entities, templates in entities_to_sql_templates.items()
    for template in templates
]
# Uncomment to persist generations back to the same file:
# sql_template_generator.save(file_path=file_path)

The 2 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 2 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.


In [33]:
# Step 2: Generate Text Query Templates
linguistic_attr = "short"
file_path = f'spider/TextTemplateGenerator/{linguistic_attr}.json'

# Load existing generations from disk to avoid re-generation.
text_template_generator = TextTemplateGenerator.from_file(
    file_path=file_path,
    verbalize_attrs=linguistic_attr,
    llm=llm,
)

# Request 10 text generations for each SQL template from Step 1.
sql_to_text_templates = text_template_generator.generate_batch(
    sql_templates,
    verbose=True,
    num_generations=10,
    override=False,
)
# Uncomment to persist generations back to the same file:
# text_template_generator.save(file_path=file_path, override=True)

The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.


In [34]:
# This cell generates SQL templates with IDs, used only to evaluate the
# effectiveness of the framework. It should not be used when applying the
# framework in the real world.
sql_template_with_id_generator = SQLTemplateGenerator.from_file(
    "spider/SQLTemplateGenerator/sql_templates_with_ids.json",
    sql_connection=connection_string,
    llm=llm,
)
entities_to_sql_templates_with_ids = sql_template_with_id_generator.generate_batch(
    [('company',), ('people',), ('company', 'people')],
    override=False,
    verbose=True,
)
sql_templates_with_ids = [
    tpl
    for _entities, tpls in entities_to_sql_templates_with_ids.items()
    for tpl in tpls
]

# The remapping below pairs the plain templates with their ID-bearing
# counterparts purely by position. zip() stops at the shorter input, so a
# length mismatch would silently drop templates from the evaluation set.
# Fail loudly instead.
if len(sql_templates) != len(sql_templates_with_ids):
    raise ValueError(
        "Cannot align SQL templates with their ID variants: "
        f"{len(sql_templates)} templates vs "
        f"{len(sql_templates_with_ids)} templates with IDs."
    )

# Re-key the text templates by the ID-bearing SQL template at the same position.
sql_to_text_templates_with_ids = {}
for t, t_id in zip(sql_templates, sql_templates_with_ids):
    sql_to_text_templates_with_ids[t_id] = sql_to_text_templates[t]


The 2 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 2 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.


In [35]:
# Step 3: Generate Evaluation Data (Text Queries and Answers)
qa_generator = QADataGenerator(db_tool)
all_answers_to_text_queries = qa_generator.generate(sql_to_text_templates_with_ids)
qa_generator.print_query_stats(all_answers_to_text_queries)
file_name = f"{linguistic_attr}_with_ids.json"
qa_generator.save(all_answers_to_text_queries, database_name, file_name, overwrite=True)

The number of generated SQL queries:  57
The number of generated text queries:  570


In [25]:
# How About Using Existing Queries?
# import json
# file_name = "spider-database/database/company_employee/company_employee.json"
# with open(f"{file_name}", 'r') as f:
#     filtered_data = json.load(f)
# questions = [example['question'] for example in filtered_data]
# sql_queries = [example['query'] for example in filtered_data]
# questions