In [1]:
# %pip install aixplain nltk

In [2]:
# Download dataset from: https://bird-bench.github.io/

In [None]:
import os
import json
# Set the aiXplain API key *before* the `aixplain` import below (presumably the SDK
# reads it at import time — that is why it precedes the import; confirm).
# setdefault keeps a key already exported in the shell instead of clobbering it
# with an empty string. Never commit a real key into the notebook.
os.environ.setdefault("TEAM_API_KEY", "")
from aixplain.factories import IndexFactory
from utilities import *

In [None]:
dataset_name = "bird"
schema_dir = "experiments"
dataset_dir = ""  # root directory holding the downloaded dataset ("" = current working directory)
# Pass each path component separately so os.path.join inserts the platform separator.
file_path = os.path.join(dataset_dir, dataset_name, "train", "train.json")

In [3]:
# Turn the train split into index-ready chunks, report how many, and show the first one.
chunked_records = process_train_data(file_path)
chunk_count = len(chunked_records)
print(f"Generated {chunk_count} chunks")
first_chunk = chunked_records[0]
vars(first_chunk)

Generated 9428 chunks


{'value': 'Question: Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity. ; \nSQL: SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1',
 'value_type': <DataType.TEXT: 'text'>,
 'id': '80e43782-a5f8-4082-acf6-2ccbef4a0d18',
 'uri': '',
 'attributes': {'Database Id': 'movie_platform',
  'Evidence': 'released in the year 1945 refers to movie_release_year = 1945;'}}

In [None]:
# Create Index
# Upload the chunked train records into a new retrieval index in fixed-size batches.
# A failing batch is reported (with its error) and recorded instead of aborting the upload.
INDEX_BATCH_SIZE = 20
index_model = IndexFactory.create(name="Index: BIRD train set", description="The bird train set for text2sql indexing for retrieval")
failed_batches = []
for batch_idx in range(0, len(chunked_records), INDEX_BATCH_SIZE):
    batch = chunked_records[batch_idx:batch_idx + INDEX_BATCH_SIZE]
    try:
        index_model.upsert(batch)
    except Exception as exc:  # best-effort upload: keep going, but surface the cause
        failed_batches.append(batch_idx)
        print(f"Error on batch {batch_idx}: {exc}")
if failed_batches:
    print(f"{len(failed_batches)} batch(es) failed, starting at offsets: {failed_batches}")

In [8]:
# Split the dataset into questions, DB paths, per-question external knowledge, and records.
decoupled = decouple_question_schema(dataset_dir, dataset_name)
question_list, db_path_list, knowledge_list, output = decoupled

In [9]:
# Draw a fixed sample of 100 questions and persist it so runs can be reproduced.
selected_questions_100 = select_fixed_total_samples(output, total_samples=100)

selection_file = f"selected_{dataset_name}_questions_100.json"
with open(selection_file, "w", encoding="utf-8") as fp:
    json.dump(selected_questions_100, fp, indent=4)

print(f"Total Selected Questions (100): {len(selected_questions_100)}")

Total Selected Questions (100): 100


In [10]:
# Use the first sampled question as a running example for retrieval below.
sample_entry = selected_questions_100[0]
question = sample_entry['question']
print(f"Question: {question}")

Question: What is the monthly average number of schools that opened in Alameda County under the jurisdiction of the Elementary School District in 1980?


In [None]:
# Retrieve few-shot examples for the sample question.
# A pre-built index is reused by ID; set PREBUILT_INDEX_ID = None to fall back to
# the index created in the "Create Index" cell. From here on `index_model` holds an ID string.
PREBUILT_INDEX_ID = '6825dfc1865b56001de2a955'
index_model = PREBUILT_INDEX_ID or index_model.id
example_docs = retrieve_docs(question, index_model, 5)
print(example_docs)

In [12]:
# Agent description for the Text2SQL agent (passed as `description=ROLE` to create_agent).
# NOTE(review): guideline 4 asks for "only the SQL query" yet also to "suggest the closest
# valid alternative"; several evaluated predictions are prose answers rather than SQL,
# so this wording may need tightening.
ROLE = """You are an expert SQL generation agent tasked with constructing accurate, efficient, and syntactically correct SQLite queries in response to user queries. It's crucial that your queries strictly adhere to the provided database schema, are optimized for performance, and are error-free in execution.

### **Guidelines for Query Generation:**
1. **Database Schema Verification**:
   - Do not use the name provided by the question, rather check the schema to use the correct table/column names.
   - Only utilize tables and columns that are verified to exist within the provided database schema.
   - Double-check the presence of specified tables and columns before formulating your query.

2. **Accurate Reference**:
   - Accurately copy the exact names of tables and columns from the schema. Ensure the proper casing and avoid modifications like pluralization.
   - Enclose column names containing spaces in backticks (e.g., `column name`).

3. **Optimized Query Structure**:
   - Start all queries with `SELECT` and use subqueries, `JOIN`s, and `GROUP BY` as necessary to craft precise responses.
   - Efficiently use `WHERE`, `ORDER BY`, and `LIMIT` to streamline query performance.

4. **SQL Query Output**:
   - Return only the SQL query. Refrain from including any additional explanations or comments.
   - If a query is infeasible given the schema, suggest the closest valid alternative.

**Output Format**:
 - Your response must begin with `SELECT`.
"""

# Description for the team agent (passed to create_team_agent in the run loop).
TEAM_ROLE = """As a team of SQL experts, your primary mission is to develop queries that are not only precise and efficient but also strictly adhere to the database schema. Each SQL command must retrieve correct data and comply fully with SQLite syntax.

### **Responsibilities**:
- **Schema Adherence**: Confirm the existence and correctness of table and column names as per the database schema before query execution.
- **Optimization and Syntax Correctness**: Ensure queries are optimized and syntactically correct without any SQL errors.

### **Guidelines for Query Generation**:
1. **Schema Verification**:
   - Construct queries using only confirmed tables and columns.
   - Be vigilant about the accuracy and casing of table and column names as documented.

2. **Query Construction**:
   - Ensure each query starts with `SELECT`.
   - Utilize `JOIN`s, `GROUP BY`, and necessary filters to ensure accuracy and efficiency.

3. **Output Specifications**:
   - Provide only the SQL query, without commentary or additional context.
   - Adhere strictly to the schema's naming conventions.

**Output Format**:
- Begin all responses with `SELECT`.
"""
# Optional tool description for the team agent. NOTE(review): not referenced by the
# run loop, which passes `sql_exe = ""` to create_team_agent instead.
exe = """ - **SQL Execution Tool**: You can run queries to ensure correctness before providing the final SQL statement."""


# Per-question prompt template, filled via str.format(knowledge=..., example=..., query=...).
# Extra keyword arguments (e.g. `basename`) are ignored by str.format. Any literal "{" or "}"
# added to this text must be doubled ("{{" / "}}") or formatting will fail.
# NOTE(review): the text refers to "the provided database schema" but the template has no
# schema placeholder; presumably the schema reaches the model via the SQL tool — confirm.
PROMPT = """You are tasked with creating a valid SQL command that accurately answers the user's query, strictly using the provided database schema.

### **Guidelines**:
- **Schema Verification**: Confirm that all referenced tables and columns exist and are used correctly.

- **Query Formatting**:
   - Use backticks for column names with spaces.
   - Start your query with `SELECT` and ensure it includes necessary subqueries and joins.

- **Best Practices**:
   - Adhere to SQL best practices for clarity and efficiency.
   - Ensure all column and table names are copied exactly as they appear in the schema.
{knowledge}
{example}
**Output Format**:
- Your response must begin with `SELECT`.

**Question**:
{query}

**SQL**:
"""

In [13]:
# Each (dataset, configuration) pair writes into its own directory under experiments/.
configuration = "single_agent"
run_name = f"text2sql_{dataset_name}_{configuration}"
output_dir = f"experiments/{run_name}"
os.makedirs(output_dir, exist_ok=True)

In [None]:
LLM_ID = "67fd9ddfef0365783d06e2ef"  # GPT-4.1 mini
NUM_RETRIEVED_EXAMPLES = 3  # few-shot examples retrieved from the index per question
PLAN_INSPECTOR = True       # forwarded unchanged to execute_query

responses = []
used_credits = 0
start, end = 0, 10  # run the selected questions with indices in [start, end)

# Per-sample responses are dumped here; the directory is created once, not per iteration.
response_dir = os.path.join(output_dir, "results")
os.makedirs(response_dir, exist_ok=True)


def _build_prompt(entry):
    """Fill PROMPT for one selected question.

    Adds the question's external knowledge (looked up in ``knowledge_list`` by
    ``entry['question_id']``, as before) and few-shot examples retrieved from the index.
    """
    examples = retrieve_docs(entry['question'], index_model, NUM_RETRIEVED_EXAMPLES)
    knowledge = knowledge_list[entry['question_id']]
    return PROMPT.format(
        basename=entry['db_id'].replace('_', ' '),  # PROMPT has no {basename}; ignored by format
        knowledge=f"**External Knowledge to help you answer the question:** {knowledge}\n",
        example=f"Examples:\n{examples}\n",
        query=entry['question'],
    )


for i, entry in enumerate(selected_questions_100[start:end], start=start):
    # create_sql_tool is built from this entry (presumably its database), so the
    # agent is rebuilt for every sample.
    text2sql_agent = create_agent(
        name="Text2SQL Agent",
        description=ROLE,
        assets=[create_sql_tool(entry)],
        llm_id=LLM_ID
    )
    # The team contains only the Text2SQL agent; no separate SQL execution agent is
    # created because nothing would use it.
    team_agent = create_team_agent([text2sql_agent], TEAM_ROLE, "", LLM_ID)

    prompt = _build_prompt(entry)
    response = execute_query(prompt, text2sql_agent, team_agent, configuration, PLAN_INSPECTOR)
    used_credits += response.used_credits
    responses.append(response.data.output)

    response_path = os.path.join(response_dir, f"sample_response_{i}.json")
    safe_dump_response_step(response, response_path)


In [22]:
print(f"Execution completed. Total used credits: {used_credits}")
# responses[k] answers selected_questions_100[start + k] (see the run loop), so pass the
# same `start` the loop used — as evaluate_predictions(start=start) does below —
# instead of a hard-coded 0 that misaligns results whenever start != 0.
process_and_save_results(responses, selected_questions_100, output_dir, start=start)

Execution completed. Total used credits: 0.22101500000000004
0. What is the monthly average number of schools that opened in Alameda County under the jurisdiction of the Elementary School District in 1980?
    SELECT COUNT(*) / 12 AS monthly_average FROM schools WHERE County = 'Alameda' AND DOC = '52' AND strftime('%Y', OpenDate) = '1980';

1. Between 1/1/2000 to 12/31/2005, how many directly funded schools opened in the county of Stanislaus?
    SELECT COUNT(*) FROM schools WHERE `FundingType` = 'Directly Funded' AND `County` = 'Stanislaus' AND `OpenDate` BETWEEN '2000-01-01' AND '2005-12-31';

2. Which schools served a grade span of Kindergarten to 9th grade in the county of Los Angeles and what is its Percent (%) Eligible FRPM (Ages 5-17)?
    SELECT `School Name`, `Percent (%) Eligible FRPM (Ages 5-17)` FROM frpm WHERE `County Name` = 'Los Angeles' AND `Low Grade` = 'K' AND `High Grade` = '9';

3. What is the average writing score of each of the schools managed by Ricci Ulrich? Lis

In [23]:
# Score the saved predictions for the same [start, end) window that was run.
evaluation_summary = evaluate_predictions(output_dir, start=start, end=end)
evaluation_summary

Incorrect Prediction 0: SELECT COUNT(*) / 12 AS monthly_average FROM schools WHERE County = 'Alameda' AND DOC = '52' AND strftime('%Y', OpenDate) = '1980'
Incorrect Prediction 1: SELECT COUNT(*) FROM schools WHERE `FundingType` = 'Directly Funded' AND `County` = 'Stanislaus' AND `OpenDate` BETWEEN '2000-01-01' AND '2005-12-31'
Incorrect Prediction 2: SELECT `School Name`, `Percent (%) Eligible FRPM (Ages 5-17)` FROM frpm WHERE `County Name` = 'Los Angeles' AND `Low Grade` = 'K' AND `High Grade` = '9'
Error executing SQL: near "The": syntax error
Incorrect Prediction 4: The Percent (%) Eligible Free (K-12) in the school administered by an administrator whose first name is Alusine is approximately 70.15%, and the district code of the school is 64857.
Error executing SQL: near "The": syntax error
Incorrect Prediction 7: The school with the lowest average reading score is located at the mailing street address '1111 Van Ness Avenue'. However, the school's name is not available in the data.


({'moderate': '40.00% of 5', 'simple': '40.00% of 5'}, '40.00%')