In [None]:
# Download dataset from: https://yale-lily.github.io/spider

In [None]:
# %pip install aixplain nltk

In [None]:
import json
import os

# Read the aiXplain API key from the environment. setdefault only fills in a placeholder when
# TEAM_API_KEY is unset, so a key already exported in the shell is no longer overwritten with "".
# Keep this before the aixplain import (the original ordering suggests the SDK reads it at import time).
os.environ.setdefault("TEAM_API_KEY", "")
from aixplain.factories import IndexFactory
from utilities import *  # project helpers used below (process_train_data, retrieve_docs, create_agent, ...)

In [None]:
# Dataset location: `dataset_dir` is the folder that contains the `spider/` directory.
dataset_name = "spider"
dataset_dir = ""  # The directory where the dataset is stored
schema_dir = "experiments"
train_file = "train_combined.json"
file_path = os.path.join(dataset_dir, dataset_name, train_file)

In [3]:
chunked_records = process_train_data(file_path)
chunk_count = len(chunked_records)
print(f"Generated {chunk_count} chunks")
# Show the fields of the first chunk to sanity-check the record format.
first_chunk = chunked_records[0]
first_chunk.__dict__

Generated 8651 chunks


{'value': 'Question: what is the biggest city in wyoming; \nSQL: SELECT city_name FROM city WHERE population  =  ( SELECT MAX ( population ) FROM city WHERE state_name  =  "wyoming" ) AND state_name  =  "wyoming";',
 'value_type': <DataType.TEXT: 'text'>,
 'id': '33460b7d-1273-40e0-bd1b-805f27bfa09c',
 'uri': '',
 'attributes': {'Database Id': 'geo', 'Evidence': ''}}

In [None]:
# Create Index
# Upload the chunked training records to a new aiXplain index in fixed-size batches.
INDEX_BATCH_SIZE = 20  # records per upsert call
index_model = IndexFactory.create(name="Index: Spider train set", description="The spider train set for text2sql indexing for retrieval")
failed_batches = []  # start offsets of batches whose upsert raised
for batch_idx in range(0, len(chunked_records), INDEX_BATCH_SIZE):
    batch = chunked_records[batch_idx:batch_idx + INDEX_BATCH_SIZE]
    try:
        index_model.upsert(batch)
    except Exception as exc:
        # Best-effort: keep indexing the remaining batches, but report why this one failed.
        # Catching Exception (not a bare except) lets a kernel interrupt still stop the loop.
        failed_batches.append(batch_idx)
        print(f"Error on batch {batch_idx}: {exc!r}")
if failed_batches:
    print(f"{len(failed_batches)} batch(es) failed to upsert, starting at offsets: {failed_batches}")

In [5]:
# Split the dataset into parallel lists; `output` feeds the sample selection in the next cell.
(question_list, db_path_list, knowledge_list, output) = decouple_question_schema(
    dataset_dir,
    dataset_name,
)

In [6]:
selected_questions_100 = select_fixed_total_samples(output, total_samples=100)

# Persist the selection so the same 100 questions can be reused across runs.
selection_path = f"selected_{dataset_name}_questions_100.json"
with open(selection_path, mode="w", encoding="utf-8") as fh:
    json.dump(selected_questions_100, fh, indent=4)

selected_count = len(selected_questions_100)
print(f"Total Selected Questions (100): {selected_count}")

Total Selected Questions (100): 100


In [7]:
# Use the first selected sample as a quick retrieval sanity check.
first_sample = selected_questions_100[0]
question = first_sample['question']
print(f"Question: {question}")

Question: Show the stadium name and capacity with most number of concerts in year 2014 or after.


In [None]:
# Id of a pre-built Spider index. Set to None to query the index created in the "Create Index" cell instead.
PREBUILT_INDEX_ID = '6826051f6cb8fb001df525d0'
if PREBUILT_INDEX_ID is not None:
    index_model = PREBUILT_INDEX_ID
else:
    # getattr keeps re-runs working: after the first run index_model already holds the id string.
    index_model = getattr(index_model, "id", index_model)
example_docs = retrieve_docs(question, index_model, 5)
print(example_docs)

In [10]:
# System description for the single "Text2SQL Agent" created in the run loop (passed as `description=`).
# NOTE: all strings in this cell are sent verbatim to the LLM — editing them changes experiment results.
ROLE = """You are an expert SQL generation agent tasked with constructing accurate, efficient, and syntactically correct SQLite queries in response to user queries. It's crucial that your queries strictly adhere to the provided database schema, are optimized for performance, and are error-free in execution.

### **Guidelines for Query Generation:**
1. **Database Schema Verification**:
   - Do not use the name provided by the question, rather check the schema to use the correct table/column names.
   - Only utilize tables and columns that are verified to exist within the provided database schema.
   - Double-check the presence of specified tables and columns before formulating your query.

2. **Accurate Reference**:
   - Accurately copy the exact names of tables and columns from the schema. Ensure the proper casing and avoid modifications like pluralization.
   - Enclose column names containing spaces in backticks (e.g., `column name`).

3. **Optimized Query Structure**:
   - Start all queries with `SELECT` and use subqueries, `JOIN`s, and `GROUP BY` as necessary to craft precise responses.
   - Efficiently use `WHERE`, `ORDER BY`, and `LIMIT` to streamline query performance.

4. **SQL Query Output**:
   - Return only the SQL query. Refrain from including any additional explanations or comments.
   - If a query is infeasible given the schema, suggest the closest valid alternative.

**Output Format**:
 - Your response must begin with `SELECT`.
"""

# Instructions for the team agent; passed as the second argument of create_team_agent in the run loop.
TEAM_ROLE = """As a team of SQL experts, your primary mission is to develop queries that are not only precise and efficient but also strictly adhere to the database schema. Each SQL command must retrieve correct data and comply fully with SQLite syntax.

### **Responsibilities**:
- **Schema Adherence**: Confirm the existence and correctness of table and column names as per the database schema before query execution.
- **Optimization and Syntax Correctness**: Ensure queries are optimized and syntactically correct without any SQL errors.

### **Guidelines for Query Generation**:
1. **Schema Verification**:
   - Construct queries using only confirmed tables and columns.
   - Be vigilant about the accuracy and casing of table and column names as documented.

2. **Query Construction**:
   - Ensure each query starts with `SELECT`.
   - Utilize `JOIN`s, `GROUP BY`, and necessary filters to ensure accuracy and efficiency.

3. **Output Specifications**:
   - Provide only the SQL query, without commentary or additional context.
   - Adhere strictly to the schema's naming conventions.

**Output Format**:
- Begin all responses with `SELECT`.
"""
# Optional capability line describing an SQL execution tool. It is not referenced by the run loop,
# which passes sql_exe = "" to create_team_agent; presumably meant to be passed there instead — confirm.
exe = """ - **SQL Execution Tool**: You can run queries to ensure correctness before providing the final SQL statement."""


# Per-question user prompt, filled with str.format(knowledge=..., example=..., query=...).
# Any literal { or } added to this template must be doubled ({{ / }}) or str.format will fail.
# NOTE(review): the run loop also passes basename=..., but there is no {basename} placeholder here,
# so str.format ignores it and the database name is not included in the prompt.
PROMPT = """You are tasked with creating a valid SQL command that accurately answers the user's query, strictly using the provided database schema.

### **Guidelines**:
- **Schema Verification**: Confirm that all referenced tables and columns exist and are used correctly.

- **Query Formatting**:
   - Use backticks for column names with spaces.
   - Start your query with `SELECT` and ensure it includes necessary subqueries and joins.

- **Best Practices**:
   - Adhere to SQL best practices for clarity and efficiency.
   - Ensure all column and table names are copied exactly as they appear in the schema.
{knowledge}
{example}
**Output Format**:
- Your response must begin with `SELECT`.

**Question**:
{query}

**SQL**:
"""

In [11]:
configuration = "single_agent"
# All artifacts for this dataset/configuration pair are written under this directory.
output_dir = "experiments/text2sql_{}_{}".format(dataset_name, configuration)
os.makedirs(output_dir, exist_ok=True)

In [None]:
LLM_ID = "67fd9ddfef0365783d06e2ef"  # GPT-4.1 mini

# Run the Text2SQL agent on selected_questions_100[start:end] and save each raw response.
responses = []      # model outputs, in the same order as the questions that were run
used_credits = 0    # running total of credits reported by each response
start, end = 0, 10  # half-open slice of sample indices to run; widen `end` to run more samples

# Loop-invariant settings (previously recomputed on every iteration).
sql_exe = ""  # third argument of create_team_agent; presumably the optional `exe` tool text — empty disables it (confirm)
plan_inspector = True
response_dir = os.path.join(output_dir, "results")
os.makedirs(response_dir, exist_ok=True)

for i, entry in enumerate(selected_questions_100[start:end], start=start):
    # The SQL tool is built from the current entry, so the agent is rebuilt for each question.
    # (An unused "SQL Execution Agent" used to be created here on every iteration; it was never
    # passed to the team or executed, so it has been removed.)
    text2sql_agent = create_agent(
        name="Text2SQL Agent",
        description=ROLE,
        assets=[create_sql_tool(entry), create_python_tool()],
        llm_id=LLM_ID
    )
    team_agent = create_team_agent([text2sql_agent], TEAM_ROLE, sql_exe, LLM_ID)

    question_text = entry['question']
    basename = entry['db_id'].replace('_', ' ')
    example = retrieve_docs(question_text, index_model, 3)
    prompt = PROMPT.format(
        basename=basename,  # NOTE(review): PROMPT has no {basename} placeholder, so this is currently ignored
        knowledge='',
        example=f"Examples:\n{example}\n",
        query=question_text,
    )

    response = execute_query(prompt, text2sql_agent, team_agent, configuration, plan_inspector)
    used_credits += response.used_credits

    responses.append(response.data.output)
    print(response.data.output)

    response_path = os.path.join(response_dir, f"sample_response_{i}.json")
    safe_dump_response_step(response, response_path)

In [13]:
print(f"Execution completed. Total used credits: {used_credits}")
# Pass the same `start` the run loop used (it was hard-coded to 0), so saved results line up with the
# questions that were actually run when start != 0. This matches the evaluation cell, which also uses start=start.
# NOTE(review): assumes process_and_save_results pairs responses[k] with selected_questions_100[start + k] — confirm in utilities.
process_and_save_results(responses, selected_questions_100, output_dir, start=start)

Execution completed. Total used credits: 0.084275
0. Show the stadium name and capacity with most number of concerts in year 2014 or after.
    SELECT s.Name, s.Capacity FROM stadium s JOIN concert c ON s.Stadium_ID = c.Stadium_ID WHERE c.Year >= '2014' GROUP BY s.Stadium_ID ORDER BY COUNT(c.concert_ID) DESC LIMIT 1

1. What is the average and maximum capacities for all stadiums ?
    SELECT AVG(Capacity) AS average_capacity, MAX(Capacity) AS maximum_capacity FROM stadium;

2. What is the maximum capacity and the average of all stadiums ?
    SELECT MAX(Capacity), AVG(Capacity) FROM stadium

3. Show the name and the release year of the song by the youngest singer.
    SELECT Name, Song_release_year FROM singer ORDER BY Age ASC LIMIT 1

4. Show the name and theme for all concerts and the number of singers in each concert.
    SELECT concert.concert_Name, concert.Theme, COUNT(singer_in_concert.Singer_ID) AS number_of_singers FROM concert JOIN singer_in_concert ON concert.concert_ID = sin

In [14]:
# Score the saved predictions for the same start/end range used by the run loop.
evaluate_sql_predictions(output_dir, end=end, start=start)

[('Somerset Park', 11998)] === [('Somerset Park', 11998)]
[(10621.666666666666, 52500)] === [(10621.666666666666, 52500)]
[(52500, 10621.666666666666)] === [(52500, 730)]
Incorrect Prediction 2: SELECT MAX(Capacity), AVG(Capacity) FROM stadium
[('Tribal King', '2016')] === [('Love', '2016')]
Incorrect Prediction 3: SELECT Name, Song_release_year FROM singer ORDER BY Age ASC LIMIT 1
[('Auditions', 'Free choice', 3), ('Super bootcamp', 'Free choice 2', 2), ('Home Visits', 'Bleeding Love', 1), ('Week 1', 'Wide Awake', 1), ('Week 1', 'Happy Tonight', 2), ('Week 2', 'Party All Night', 1)] === [('Auditions', 'Free choice', 3), ('Super bootcamp', 'Free choice 2', 2), ('Home Visits', 'Bleeding Love', 1), ('Week 1', 'Wide Awake', 1), ('Week 1', 'Happy Tonight', 2), ('Week 2', 'Party All Night', 1)]
[('cat', 12.0), ('dog', 13.4)] === [(12.0, 'cat'), (13.4, 'dog')]
Incorrect Prediction 5: SELECT PetType, MAX(weight) FROM Pets GROUP BY PetType
[(2001,)] === [(2001,)]
[(2,)] === [(2,)]
[] === []
[(

'70.00%'