# Import Libraries

In [None]:
!pip -q install aixplain

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [None]:
import os
import re
import json
import sqlite3
# Placeholder credential: replace it, or export TEAM_API_KEY before starting the
# notebook. setdefault keeps a key that is already present in the environment
# instead of clobbering it with this placeholder string. It is set before the
# aixplain imports below (presumably the SDK reads the key at import time).
os.environ.setdefault("TEAM_API_KEY", "TEAM_API_KEY")

from aixplain.enums import DataType
from aixplain.modules.model.utility_model import utility_tool, UtilityModelInput
from aixplain.factories import ModelFactory, AgentFactory

# Prompt

In [None]:
# System instructions for the Text2SQL agent (passed as `instructions=` when the
# agent is created). They describe the attached Python-interpreter and SQL tools
# and require the reply to be a single SQLite query beginning with SELECT.
ROLE = """You are an expert SQL generation agent responsible for constructing accurate, efficient, and syntactically correct SQLite queries to answer the user's query. Your task is to strictly adhere to the database schema and generate a query that returns the correct results when executed.

This database contains information about California schools. Your query must be optimized, free of errors, and strictly formatted according to SQLite standards.

### **Your Capabilities:**
- You have access to a **Python interpreter tool** that allows you to validate queries if needed.
- You also have access to a **SQL execution tool**, which you can use to ensure your queries return correct results.

### **Guidelines for Query Generation:**
1. **Use Only Available Tables and Columns**
   - Refer strictly to the database schema provided. Do not assume or infer missing data.
   - If the query requires a specific column or table, ensure it exists in the schema before including it.

2. **Proper Column Formatting**
   - Enclose column names containing spaces in backticks (e.g., `column name`).
   - Maintain the original capitalization of column names as provided in the schema.

3. **Query Structure and Optimization**
   - Your SQL query must always start with `SELECT`.
   - Use subqueries, `JOIN`s, and `GROUP BY` as needed to return accurate results.
   - Ensure efficient filtering using `WHERE` clauses to minimize unnecessary computations.
   - Use `ORDER BY` when sorting is required, and apply `LIMIT` appropriately for queries requesting a specific number of results.

4. **Strictly Return the SQL Query Only**
   - Do not provide explanations, comments, or additional context—only return the SQL statement.

5. **Ensure Query Execution Feasibility**
   - If the query is logically impossible to execute based on the schema, return the closest valid SQL query that aligns with the user’s request.
   - Avoid using undefined functions or SQL clauses unsupported by SQLite.

Note: You have access to a Python interpreter tool for testing and validating your query.

Output Format:
 - Your response must begin with `SELECT`.
"""

# Per-question prompt prefix; the question text and a "**SQL:**" cue are appended
# to it for every agent call. Predictions are scored by comparing their result
# rows exactly with the reference query's rows, so the model is told to select
# only the requested columns (an extra name/ID column makes a correct answer
# score as wrong).
prompt = """You are an expert agent tasked with providing a valid and correct SQL command to answer the user's query. Your response must contain only the SQL command, with no additional explanation.

This database contains information about California schools. You must generate a valid SQLite query that produces the correct result when executed.

**Guidelines:**
 - Use only the tables and columns provided in the Database Schema; do not assume or invent any data.
 - Enclose column names containing spaces in backticks (e.g., `column name`).
 - Follow SQL best practices for clarity and efficiency.
 - Return only the SQL query used to generate the results—no explanations. Do not return any additional text, context or answer to the query.
 - Your SQL query must start with `SELECT` and may include subqueries if necessary.
 - Preserve the original capitalization of column names as specified in the schema.
 - Select only the columns the question asks for; do not add extra columns (such as names or IDs) unless they are requested, because the results are compared exactly with the expected output.

**Output Format:**
- Your response must begin with `SELECT`.
"""


# SQL Agent

In [None]:
def read_data(file):
    """Load *file* as text (UTF-8) and hand back its entire contents as one string."""
    with open(file, encoding="utf-8") as handle:
        contents = handle.read()
    return contents

def read_binary(file):
    """Copy a ``.sqlite`` database file to a sibling ``.db`` file and return the new path.

    Only the extension is rewritten (``data/x.sqlite`` -> ``data/x.db``); directory
    names that happen to contain "sqlite" are left alone. A path without a
    ``.sqlite`` extension gets ``.db`` appended, so the copy never overwrites its
    own source. An existing destination file is overwritten.

    Args:
        file: Path (str or path-like) of the source database file.

    Returns:
        The path of the copy, as a ``str``.
    """
    import shutil  # only this helper needs it

    src = os.fspath(file)
    root, ext = os.path.splitext(src)
    new_file = root + ".db" if ext.lower() == ".sqlite" else src + ".db"
    # copyfile streams the bytes in chunks instead of holding the whole DB in memory.
    shutil.copyfile(src, new_file)
    return new_file

def read_json(file):
    """Parse *file* as JSON and return the resulting Python object.

    The file is decoded explicitly as UTF-8, the encoding JSON text uses, rather
    than the platform's locale default. This keeps parsing of non-ASCII
    questions identical on every OS and matches ``read_data``.

    Args:
        file: Path of the JSON file.

    Returns:
        The decoded JSON value (for ``dev.json``, a list of example dicts).
    """
    with open(file, "r", encoding="utf-8") as f:
        return json.load(f)


# Schema text for the database; it is handed to the SQL tool when the agent is built.
california_schools_schema = read_data("california_schools.schema")
# Writes a copy named california_schools.db next to california_schools.sqlite;
# the returned path is not kept.
read_binary("california_schools.sqlite")
# Benchmark examples (a list of dicts, judging by the record displayed below),
# each holding a question and its reference SQL.
ground_truths = read_json("dev.json")

In [None]:
# Peek at the first dev-set record: it carries the question, an `evidence` hint,
# the reference `SQL`, and a difficulty label.
ground_truths[0]

{'question_id': 0,
 'db_id': 'california_schools',
 'question': 'What is the highest eligible free rate for K-12 students in the schools in Alameda County?',
 'evidence': 'Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`',
 'SQL': "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
 'difficulty': 'simple'}

In [None]:
# Evaluate on the first 10 dev-set examples only; collect each example's
# question and its reference SQL in parallel lists.
ground_truth = ground_truths[:10]
queries, SQLs = [], []
for example in ground_truth:
    queries.append(example["question"])
    SQLs.append(example["SQL"])

In [None]:
# Build the Text2SQL agent with two tools: a Python interpreter and a SQL tool
# bound to the .db copy of the California-schools database.
agent = AgentFactory.create(
    name="Text2SQL Agent",
    description="An agent that converts natural language queries into valid SQLite commands.",
    instructions=ROLE,
    tools=[
        AgentFactory.create_python_interpreter_tool(),
        AgentFactory.create_sql_tool(
            description="You are an SQL expert responsible for querying and retrieving information from the database, while accurately returning the SQL commands used in the process.",
            # The .db copy is written relative to the working directory (the
            # .sqlite file is read by a relative path), so resolve it from there
            # instead of hard-coding Colab's /content folder.
            database=os.path.abspath("california_schools.db"),
            schema=california_schools_schema,
            # NOTE(review): presumably keeps the tool from committing writes — confirm in aixplain docs.
            enable_commit=False,
        ),
    ],
    # aiXplain model ID of the LLM that drives the agent (TODO: record which model this is).
    llm_id="6646261c6eb563165658bbb1",
)


In [None]:
# Ask the agent each question and record its raw reply.
# `results` stays index-aligned with `queries` and `SQLs`: a failed run records
# an empty string (no usable SQL) instead of being skipped, which would shift
# every later prediction onto the wrong reference query during evaluation.
results = []
for i, query in enumerate(queries):
    print(f"{i}. {query}")
    query_ = prompt + "\n\n**Question:**\n" + query + "\n\n**SQL:**"
    response = agent.run(query_)
    if response.status == 'SUCCESS':
        output = response.data["output"]
        print('  ', output)
    else:
        output = ""
        print('   Agent run failed with status:', response.status)
    results.append(output)
    print()

0. What is the highest eligible free rate for K-12 students in the schools in Alameda County?
   SELECT MAX(`Percent (%) Eligible Free (K-12)`) FROM frpm WHERE `County Name` = 'Alameda';

1. Please list the lowest three eligible free rates for students aged 5-17 in continuation schools.
   SELECT `School Name`, `Percent (%) Eligible Free (Ages 5-17)` FROM frpm WHERE `School Type` = 'Continuation' ORDER BY `Percent (%) Eligible Free (Ages 5-17)` ASC LIMIT 3;

2. Please list the zip code of all the charter schools in Fresno County Office of Education.
   SELECT Zip FROM schools WHERE County = 'Fresno' AND Charter = 1 AND DOCType = 'County Office of Education';

3. What is the unabbreviated mailing street address of the school with the highest FRPM count for K-12 students?
   SELECT s.MailStreet FROM frpm f JOIN schools s ON f.CDSCode = s.CDSCode ORDER BY f.`FRPM Count (K-12)` DESC LIMIT 1;

4. Please list the phone numbers of the direct charter-funded schools that are opened after 2000/1

# Result Evaluation

In [None]:
def execute_sql(predicted_sql, ground_truth, db_path):
    """Score one prediction by comparing its result rows with the reference query's.

    Both queries run against the SQLite database at *db_path*. The database is
    opened read-only (SQLite URI ``mode=ro``) for two reasons: model-generated SQL
    cannot modify or drop data, and a mistyped path fails instead of silently
    creating an empty database. Rows are compared as sets, so row order and
    duplicate rows are ignored.

    Args:
        predicted_sql: SQL produced by the agent.
        ground_truth: Reference SQL for the same question.
        db_path: Path of the SQLite database file.

    Returns:
        1 if both queries return the same set of rows, else 0. Any failure while
        opening the database or running either query is printed and scored 0.
    """
    import pathlib
    from contextlib import closing

    db_uri = pathlib.Path(db_path).resolve().as_uri() + "?mode=ro"
    try:
        # closing() guarantees the connection is released even when a query raises.
        with closing(sqlite3.connect(db_uri, uri=True)) as conn:
            cursor = conn.cursor()
            cursor.execute(predicted_sql)
            predicted_res = cursor.fetchall()
            cursor.execute(ground_truth)
            ground_truth_res = cursor.fetchall()
    except Exception as e:
        # Deliberately broad: this is the scoring boundary, where any failure of
        # a prediction simply counts as incorrect.
        print(f"Error executing SQL: {e}")
        return 0
    return 1 if set(predicted_res) == set(ground_truth_res) else 0

In [None]:
# Matches a Markdown code fence (optionally tagged sql/sqlite) and captures its
# body; an unterminated fence captures through the end of the reply.
_CODE_FENCE_RE = re.compile(r"```(?:sqlite|sql)?[ \t]*\n?(.*?)(?:```|\Z)", re.IGNORECASE | re.DOTALL)


def _extract_sql(reply):
    """Return the bare SQL statement from an agent reply.

    Unwraps a Markdown code fence if present and drops trailing semicolons.
    Unlike blanket regex deletion, identifiers or string literals containing
    "sql" or ";" are left intact. Returns "" for an empty or missing reply.
    """
    if not reply:
        return ""
    fenced = _CODE_FENCE_RE.search(reply)
    sql = fenced.group(1) if fenced else reply
    return sql.strip().rstrip(";").strip()


# Score each prediction against its reference query. An empty prediction counts
# as incorrect, and the denominator is the number of questions asked, so agent
# failures lower the accuracy.
total_correct = 0
num_queries = len(queries)
db_path = os.path.abspath("california_schools.sqlite")
for predicted, reference in zip(results, SQLs):
    predicted_sql = _extract_sql(predicted)
    if predicted_sql:
        total_correct += execute_sql(predicted_sql, reference, db_path)

final_result = f"{(total_correct / num_queries * 100):.2f}%" if num_queries > 0 else "No valid files found"
print(f"Final Accuracy: {final_result}")

Final Accuracy: 60.00%
