In [1]:
import sys
sys.path.append('../..')
from grammar.db_tool import DBTool
from grammar.llm import AnyOpenAILLM
from grammar.sql_template_generator import SQLTemplateGenerator
from grammar.text_template_generator import TextTemplateGenerator
from grammar.qa_generator import QADataGenerator

# Directory that SQL/text templates and the generated QA data are saved to and loaded from.
root_dir = "."

# SQLite database (Spider `company_employee`) the SQL templates are generated against.
connection_string = 'sqlite:///spider-database/database/company_employee/company_employee.sqlite'

# LLM handed to the SQL-template (Step 1) and text-template (Step 2) generators.
llm = AnyOpenAILLM(model_name="gpt4-short")

# Database helper handed to QADataGenerator in Step 3.
db_tool = DBTool(connection_string)

In [2]:
# Step 1: Generate SQL query templates
# a. Have the LLM generate SQL query templates for the database at `connection_string`.
#    NOTE(review): save(..., overwrite=True) means re-running this cell (e.g. on
#    Restart & Run All) regenerates the templates and may replace a file that was
#    hand-edited in step b below — consider skipping this cell once templates exist.
sql_template_generator = SQLTemplateGenerator(connection_string, llm)
sql_template_generator.set_system_msg()
sql_templates = sql_template_generator.generate()
sql_template_generator.save(sql_templates, root_dir, overwrite=True)

In [16]:
# b. (Optional) Manually edit the saved SQL template files in `root_dir`.

# c. (Optional) Reload the SQL templates from disk so any manual edits are picked up.
#    `sql_templates` holds the plain templates; `sql_templates_with_ids` holds the
#    matching variants whose SELECT lists also return ID columns. The two lists are
#    expected to be parallel (same order and length) — the pairing cell relies on it.
sql_templates = SQLTemplateGenerator.load(root_dir, file_name='sql_templates.txt')
sql_templates_with_ids = SQLTemplateGenerator.load(root_dir, file_name='sql_templates_with_ids.txt')
print(sql_templates_with_ids)

["SELECT Industry, Company_ID FROM company WHERE Name = '[company.Name]';", "SELECT Headquarters, Company_ID FROM company WHERE Name = '[company.Name]';", "SELECT Nationality, People_ID FROM people WHERE Name = '[people.Name]';", "SELECT Graduation_College, People_ID FROM people WHERE Name = '[people.Name]';", "SELECT e.Year_working, e.Company_ID, e.People_ID FROM employment e JOIN company c ON e.Company_ID = c.Company_ID JOIN people p ON e.People_ID = p.People_ID WHERE c.Name = '[company.Name]' AND p.Name = '[people.Name]';"]


In [10]:
# Step 2: Generate text query templates for each SQL template
# Linguistic attribute passed to the text-template prompt; Step 3 also uses it to
# name its output file.
linguistic_attr = "short"

text_template_generator = TextTemplateGenerator(llm=llm)
text_template_generator.set_system_msg(linguistic_attr=linguistic_attr)
# Load text templates saved by earlier runs so already-covered SQL templates are
# not regenerated (reported as "EXIST!!!" in verbose mode).
# NOTE(review): `from_file` is invoked on the instance and its return value replaces
# `text_template_generator`; if it builds a fresh generator, the system message set
# above may not carry over — confirm against TextTemplateGenerator.from_file.
text_template_generator = text_template_generator.from_file(
    linguistic_attr=linguistic_attr,
    root_dir=root_dir,
)
# Expects `sql_templates` from Step 1 (or as reloaded in step c).
sql_to_text_templates = text_template_generator.generate_batch(sql_templates, verbose=True)
text_template_generator.save(sql_to_text_templates, root_dir, overwrite=True)

Text templates for the SQL template: SELECT Industry FROM company WHERE Name = '[company.Name]';. EXIST!!!
Text templates for the SQL template: SELECT Headquarters FROM company WHERE Name = '[company.Name]';. EXIST!!!
Text templates for the SQL template: SELECT Nationality FROM people WHERE Name = '[people.Name]';. EXIST!!!
Text templates for the SQL template: SELECT Graduation_College FROM people WHERE Name = '[people.Name]';. EXIST!!!
Text templates for the SQL template: SELECT e.Year_working FROM employment e JOIN company c ON e.Company_ID = c.Company_ID JOIN people p ON e.People_ID = p.People_ID WHERE c.Name = '[company.Name]' AND p.Name = '[people.Name]';. EXIST!!!


In [17]:
# Re-key the Step 2 text templates by the ID-returning SQL templates.
# `sql_templates` and `sql_templates_with_ids` must be parallel lists: element i of
# the second is the ID-augmented form of element i of the first. `zip` would silently
# stop at the shorter list and drop the remaining templates, so fail loudly instead.
if len(sql_templates) != len(sql_templates_with_ids):
    raise ValueError(
        f"SQL template lists are not parallel: {len(sql_templates)} plain templates vs "
        f"{len(sql_templates_with_ids)} templates with IDs"
    )
sql_to_text_templates_with_ids = {
    sql_with_ids: sql_to_text_templates[plain_sql]
    for plain_sql, sql_with_ids in zip(sql_templates, sql_templates_with_ids)
}

In [20]:
# Step 3: Generate evaluation data (text queries and their answers)
# Output file is named after the Step 2 linguistic attribute, e.g. "short_with_ids.json".
file_name = f"{linguistic_attr}_with_ids.json"

qa_generator = QADataGenerator(db_tool)
all_answers_to_text_queries = qa_generator.generate(sql_to_text_templates_with_ids)
# Prints the number of generated SQL queries and text queries.
qa_generator.print_query_stats(all_answers_to_text_queries)
qa_generator.save(all_answers_to_text_queries, root_dir, file_name, overwrite=True)

The number of generated SQL queries:  57
The number of generated text queries:  171


In [6]:
# How About Using Existing Queries?
# Optional alternative, currently disabled: load the existing question/SQL pairs from
# spider-database/database/company_employee/company_employee.json instead of generating them.
# NOTE: every line of this cell is commented out, so the list displayed under it is
# stale output from an earlier run in which this code was active.
# The code below assumes the JSON is a list of objects with 'question' and 'query' keys.
# If re-enabled, note that it rebinds `file_name`, which Step 3 also uses for its output file.
# import json
# file_name = "spider-database/database/company_employee/company_employee.json"
# with open(f"{file_name}", 'r') as f:
#     filtered_data = json.load(f)
# questions = [example['question'] for example in filtered_data]
# sql_queries = [example['query'] for example in filtered_data]
# questions

['How many companies are headquartered in the US?',
 'List the names of companies by ascending number of sales.',
 'What are the headquarters and industries of all companies?',
 'Show the names of companies in the banking or retailing industry?',
 'What is the maximum and minimum market value of companies?',
 'What is the headquarter of the company with the largest sales?',
 'Show the different headquarters and number of companies at each headquarter.',
 'Show the most common headquarter for companies.',
 'Show the headquarters that have at least two companies.',
 'Show the headquarters that have both companies in banking industry and companies in oil and gas industry.',
 'Show the names of companies and of employees.',
 'Show names of companies and that of employees in descending order of number of years working for that employee.',
 'Show the names of employees that work for companies with sales bigger than 200.',
 'Show the names of companies and the number of employees they have',
