In [19]:
import os
import sys
sys.path.append('..')
from grammar.db_tool import DBTool
from grammar.llm import AnyOpenAILLM
from grammar.sql_template_generator import SQLTemplateGenerator
from grammar.text_template_generator import TextTemplateGenerator
from grammar.qa_generator import QADataGenerator


llm = AnyOpenAILLM(model_name="gpt4-short")

# Select the database the pipeline runs against. The value of `setup_env` is also
# used by later cells as the directory prefix for the cached template files.
setup_env = "spider"
if setup_env in ("spider", "spider_closed"):
    database_name = 'spider'
    connection_string = f'sqlite:///{database_name}/rel_database/company_employee.sqlite'
    schemas = [('company',), ('people',), ('company', 'people')]
elif setup_env == "aurp":
    database_name = 'Aurp'
    # Credentials must not live in the notebook: read the full SQLAlchemy URL from the
    # environment instead. NOTE: a root password was previously hardcoded here and is
    # in version-control history -- it should be rotated.
    connection_string = os.environ.get("AURP_DB_URL")
    if not connection_string:
        raise RuntimeError(
            "Set the AURP_DB_URL environment variable, e.g. "
            "'mysql+pymysql://<user>:<password>@localhost:3306/Aurp'."
        )
    schemas = [('client',), ('employee',), ('project',)]
else:
    # Fail loudly instead of continuing with an undefined (or stale) connection_string.
    raise ValueError(
        f"Unknown setup_env {setup_env!r}; expected 'spider', 'spider_closed' or 'aurp'."
    )
db_tool = DBTool(connection_string)

In [20]:
# Step 1: Generate SQL query templates for every schema in `schemas`.
# With override=False, generations already present in the loaded file are reused
# (see the "exist in `cache_generations`" messages in the cell output).
file_path = f"{setup_env}/SQLTemplateGenerator/sql_templates.json"
sql_template_generator = SQLTemplateGenerator.from_file(
    file_path,
    sql_connection=connection_string,
    llm=llm,
)
entities_to_sql_templates = sql_template_generator.generate_batch(
    schemas,
    verbose=True,
    override=False,
)
# Flatten {entity: [template, ...]} into a single list, keeping the dict's order.
sql_templates = [tpl for tpls in entities_to_sql_templates.values() for tpl in tpls]
# Uncomment to persist newly generated templates back to `file_path`:
# sql_template_generator.save(file_path=file_path)

The 2 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 2 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 1 generations for the input `k` exist in `cache_generations`! No need to generate more.


In [21]:
# Step 2: Generate natural-language text templates for each SQL template.
# `linguistic_attr` selects the verbalization style and the cache file name.
linguistic_attr = "long"
file_path = f'{setup_env}/TextTemplateGenerator/{linguistic_attr}.json'
# Loading from the existing file avoids regenerating cached text templates.
text_template_generator = TextTemplateGenerator.from_file(
    file_path=file_path,
    verbalize_attrs=linguistic_attr,
    llm=llm,
)
sql_to_text_templates = text_template_generator.generate_batch(
    sql_templates,
    num_generations=10,
    override=False,
    verbose=True,
)
# Uncomment to persist newly generated text templates back to `file_path`:
# text_template_generator.save(file_path=file_path, override=True)

The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.
The 10 generations for the input `k` exist in `cache_generations`! No need to generate more.


In [22]:
# Step 3: Generate Evaluation Data (Text Queries and Answers)
# When `with_ids` is True, each text-template list is re-keyed from the plain SQL
# template to the corresponding "with IDs" SQL template, and the output file name
# gets a "_with_ids" suffix.
with_ids = False
if with_ids:
    # This cell generates sql with IDs for evaluating the effectiveness of our framework since IDs align with the document ID within evaluation retrievals
    # SHOULD NOT be used when you apply the framework in the real world
    # NOTE: this branch is not idempotent -- it replaces `sql_to_text_templates`, so
    # re-run the Step 2 cell before running this cell again.
    sql_template_with_id_generator = SQLTemplateGenerator.from_file(
        f"{setup_env}/SQLTemplateGenerator/sql_templates_with_ids.json",
        sql_connection=connection_string,
        llm=llm,
    )
    entities_to_sql_templates_with_ids = sql_template_with_id_generator.generate_batch(schemas, override=False, verbose=True)
    sql_templates_with_ids = [tpl for tpls in entities_to_sql_templates_with_ids.values() for tpl in tpls]
    # The pairing below is purely positional: the i-th plain template is assumed to
    # correspond to the i-th template with IDs (presumably guaranteed by generating
    # both from the same `schemas` in the same order -- confirm). `zip` would silently
    # drop templates if the lengths differed, so check that explicitly.
    if len(sql_templates_with_ids) != len(sql_templates):
        raise ValueError(
            f"Got {len(sql_templates_with_ids)} SQL templates with IDs but "
            f"{len(sql_templates)} plain SQL templates; cannot pair them positionally."
        )
    sql_to_text_templates_with_ids = {
        t_id: sql_to_text_templates[t]
        for t, t_id in zip(sql_templates, sql_templates_with_ids)
    }
    sql_to_text_templates = sql_to_text_templates_with_ids
    save_file = f"{linguistic_attr}_with_ids.json"
else:
    save_file = f"{linguistic_attr}.json"

In [23]:
def unbalance_text_templates(sql_to_text_templates, linguistic_attr):
    """Return a new mapping with fewer text templates for some SQL templates.

    Args:
        sql_to_text_templates: dict mapping each SQL template to its list of text templates.
        linguistic_attr: ``"long"`` keeps only the first text template for SQL templates
            containing ``"FROM Client"`` or ``"FROM Employee"`` and every text template for
            the rest; ``"short"`` keeps the first two text templates of every SQL template.

    Returns:
        A new dict with the same keys. Lists that are kept whole are shared, not copied.

    Raises:
        ValueError: if ``linguistic_attr`` is neither ``"long"`` nor ``"short"``.
    """
    if linguistic_attr == "long":
        print("only use one text template for each SQL template for client and employee tables and all text templates for project table")
        # NOTE(review): the table names match the Aurp schemas (client/employee/project);
        # for other databases presumably no SQL template matches and nothing is dropped -- confirm.
        return {
            sql: [text_templates[0]] if ("FROM Client" in sql or "FROM Employee" in sql) else text_templates
            for sql, text_templates in sql_to_text_templates.items()
        }
    if linguistic_attr == "short":
        print('use the first two text templates for each SQL template')
        return {sql: text_templates[:2] for sql, text_templates in sql_to_text_templates.items()}
    # Previously an unknown value left the result undefined (NameError) or silently
    # reused a stale result from an earlier run.
    raise ValueError(f"Unbalancing is not defined for linguistic_attr={linguistic_attr!r}")


unbalanced = False
if unbalanced:
    print("Number of text templates before unbalancing:", sum(len(tpls) for tpls in sql_to_text_templates.values()))
    sql_to_text_templates_unbalanced = unbalance_text_templates(sql_to_text_templates, linguistic_attr)
    sql_to_text_templates = sql_to_text_templates_unbalanced
    print("Number of text templates after unbalancing:", sum(len(tpls) for tpls in sql_to_text_templates.values()))
    # Insert the suffix before the extension instead of assuming a 5-character ".json".
    save_stem, save_ext = os.path.splitext(save_file)
    save_file = f"{save_stem}_unbalanced{save_ext}"

In [24]:
# Step 4: Build QA data from the text templates, print summary statistics,
# and save the result for `database_name` under `save_file` (overwriting any existing file).
qa_generator = QADataGenerator(db_tool)
all_answers_to_text_queries = qa_generator.generate(sql_to_text_templates)
qa_generator.print_query_stats(all_answers_to_text_queries)
qa_generator.save(
    all_answers_to_text_queries,
    database_name,
    save_file,
    overwrite=True,
)

The number of generated SQL queries:  57
The number of generated text queries:  570


In [80]:
# Optional (disabled): load the existing Spider question/SQL pairs for company_employee instead of generating new ones.
# import json
# file_name = "spider-database/database/company_employee/company_employee.json"
# with open(f"{file_name}", 'r') as f:
#     filtered_data = json.load(f)
# questions = [example['question'] for example in filtered_data]
# sql_queries = [example['query'] for example in filtered_data]
# questions