In [None]:
import os
import json
from uuid import uuid4
from collections import defaultdict

# The key is placed in the environment before the aiXplain/agentification
# imports below (presumably they read it at import time -- TODO confirm).
# setdefault() keeps a key that is already exported in the shell instead of
# overwriting it with this blank placeholder; never commit a real key here.
os.environ.setdefault("TEAM_API_KEY", "")

from agentification.utilities.models import Agent, UtilityTool, UtilityToolType, TeamAgent, AgentExecuteInput, ModelTool, PipelineTool, SQLTool
from agentification.team_agent import TeamAgentService, TeamAgentExecuteInput
from agentification.agent import AgentService
from aixplain.factories import ModelFactory
from utilities import *

In [None]:
# Benchmark name and the directory that holds it; both are handed to
# decouple_question_schema() in the next cell.
dataset_name, dataset_dir = "bird", "text2sql"
schema_dir = "experiments"
# Third argument of decouple_question_schema().
lowercase = False

In [None]:
# Load the benchmark; the helper returns four values, unpacked here in order.
(
    question_list,
    db_path_list,
    knowledge_list,
    output,
) = decouple_question_schema(dataset_dir, dataset_name, lowercase)

In [None]:
# Draw the 100-question evaluation subset from the decoupled benchmark records.
selected_questions_100 = select_fixed_total_samples(output, total_samples=100)

# Snapshot the subset as JSON in the current working directory.
with open(f"selected_{dataset_name}_questions_100.json", mode="w", encoding="utf-8") as out_file:
    out_file.write(json.dumps(selected_questions_100, indent=4))

print(f"Total Selected Questions (100): {len(selected_questions_100)}")
# Bare last expression so the notebook displays the first record.
selected_questions_100[0]

Total Selected Questions (100): 100


{'question_id': 5,
 'difficulty': 'simple',
 'db_id': 'california_schools',
 'question': 'How many schools with an average score in Math greater than 400 in the SAT test are exclusively virtual?',
 'prediction': None,
 'ground_truth': "SELECT COUNT(DISTINCT T2.School) FROM satscores AS T1 INNER JOIN schools AS T2 ON T1.cds = T2.CDSCode WHERE T2.Virtual = 'F' AND T1.AvgScrMath > 400",
 'sql_path': '/Users/chinonsoosuji/Desktop/text2sql/bird/dev_databases/california_schools/california_schools.sqlite',
 'schema': "CREATE TABLE frpm(\n    CDSCode TEXT not null primary key,  -- Unique identifier for the school, combining county, district, and school codes\n    `Academic Year` TEXT null,  -- Academic year for which the data is reported\n    `County Code` TEXT null,  -- Code representing the county where the school is located\n    `District Code` INTEGER null,  -- Code representing the district where the school is located\n    `School Code` TEXT null,  -- Code representing the specific school

In [None]:
# Model used by retrieve_docs() to fetch example documents for a question.
model_id = '67c702d7c7a029001d4348d1'
# Number of documents to retrieve. This name was not defined anywhere in the
# visible notebook, so the cell raised NameError on a fresh kernel; 3 matches
# the literal passed to retrieve_docs() in the experiment loop.
num_results = 3

# Smoke-test retrieval on the first selected question.
question = selected_questions_100[0]['question']
print(f"Question: {question}")
example_docs = retrieve_docs(question, model_id, num_results)
print(example_docs)

In [12]:
# Instruction text for the single Text2SQL agent; it is passed as the
# `description` argument of create_agent() in the experiment loop.
ROLE = """You are an expert SQL generation agent tasked with constructing accurate, efficient, and syntactically correct SQLite queries in response to user queries. It's crucial that your queries strictly adhere to the provided database schema, are optimized for performance, and are error-free in execution.

### **Guidelines for Query Generation:**
1. **Database Schema Verification**:
   - Do not use the name provided by the question, rather check the schema to use the correct table/column names.
   - Only utilize tables and columns that are verified to exist within the provided database schema.
   - Double-check the presence of specified tables and columns before formulating your query.

2. **Accurate Reference**:
   - Accurately copy the exact names of tables and columns from the schema. Ensure the proper casing and avoid modifications like pluralization.
   - Enclose column names containing spaces in backticks (e.g., `column name`).

3. **Optimized Query Structure**:
   - Start all queries with `SELECT` and use subqueries, `JOIN`s, and `GROUP BY` as necessary to craft precise responses.
   - Efficiently use `WHERE`, `ORDER BY`, and `LIMIT` to streamline query performance.

4. **SQL Query Output**:
   - Return only the SQL query. Refrain from including any additional explanations or comments.
   - If a query is infeasible given the schema, suggest the closest valid alternative.

**Output Format**:
 - Your response must begin with `SELECT`.
"""

# Instruction text for the team agent; it is passed as the second argument of
# create_team_agent() in the experiment loop.
TEAM_ROLE = """As a team of SQL experts, your primary mission is to develop queries that are not only precise and efficient but also strictly adhere to the database schema. Each SQL command must retrieve correct data and comply fully with SQLite syntax.

### **Responsibilities**:
- **Schema Adherence**: Confirm the existence and correctness of table and column names as per the database schema before query execution.
- **Optimization and Syntax Correctness**: Ensure queries are optimized and syntactically correct without any SQL errors.

### **Guidelines for Query Generation**:
1. **Schema Verification**:
   - Construct queries using only confirmed tables and columns.
   - Be vigilant about the accuracy and casing of table and column names as documented.

2. **Query Construction**:
   - Ensure each query starts with `SELECT`.
   - Utilize `JOIN`s, `GROUP BY`, and necessary filters to ensure accuracy and efficiency.

3. **Output Specifications**:
   - Provide only the SQL query, without commentary or additional context.
   - Adhere strictly to the schema's naming conventions.

**Output Format**:
- Begin all responses with `SELECT`.
"""
# Prompt bullet describing an SQL execution tool. NOTE(review): not referenced
# by any other visible cell (the loop passes an empty `sql_exe` instead).
exe = """ - **SQL Execution Tool**: You can run queries to ensure correctness before providing the final SQL statement."""


# Per-question prompt template, filled with str.format() in the experiment
# loop. Placeholders: {knowledge}, {example} and {query}.
PROMPT = """You are tasked with creating a valid SQL command that accurately answers the user's query, strictly using the provided database schema.

### **Guidelines**:
- **Schema Verification**: Confirm that all referenced tables and columns exist and are used correctly.

- **Query Formatting**:
   - Use backticks for column names with spaces.
   - Start your query with `SELECT` and ensure it includes necessary subqueries and joins.

- **Best Practices**:
   - Adhere to SQL best practices for clarity and efficiency.
   - Ensure all column and table names are copied exactly as they appear in the schema.
{knowledge}
{example}
**Output Format**:
- Your response must begin with `SELECT`.

**Question**:
{query}

**SQL**:
"""

In [None]:
# Label for this experiment run; it is also passed to execute_query() later.
configuration = "single_agent_claude3.5"

# Every artefact of the run goes under a per-dataset, per-configuration folder.
output_dir = f"experiments/text2sql_{dataset_name}_{configuration}"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Identifier of the LLM backing every agent created below.
LLM_ID = "669a63646eb56306647e1091"

responses = []          # response.output for each processed question, in order
used_credits = 0        # running total of response.usedCredits
start, end = 0, 100     # slice of selected_questions_100 processed by this run
session_id = str(uuid4())  # one session id shared by every execute_query() call

# The results folder does not depend on the question, so create it once.
response_dir = os.path.join(output_dir, "results")
os.makedirs(response_dir, exist_ok=True)

# enumerate(..., start=start) keeps `i` aligned with the index into
# selected_questions_100, so `query` is exactly selected_questions_100[i].
for i, query in enumerate(selected_questions_100[start:end], start=start):
    # Agent that writes the SQL, given an SQL tool built from this question's record.
    text2sql_agent = create_agent(
        name="Text2SQL Agent",
        description=ROLE,
        assets=[create_sql_tool(query)],
        llm_id=LLM_ID
    )

    # NOTE: no separate SQL-execution agent is built here because the team
    # below contains only the Text2SQL agent.
    sql_exe = ""
    team_agent = create_team_agent([text2sql_agent], TEAM_ROLE, sql_exe, LLM_ID)

    # Retrieve 3 example documents for the question and fill the template.
    # PROMPT only has the {knowledge}, {example} and {query} fields.
    example = retrieve_docs(query['question'], model_id, 3)
    prompt = PROMPT.format(
        # knowledge_list is indexed by the record's question_id.
        knowledge=f"**External Knowledge to help you answer the question:** {knowledge_list[query['question_id']]}\n",
        example=f"Examples:\n{example}\n",
        query=query['question'],
    )

    response = execute_query(prompt, text2sql_agent, team_agent, configuration, session_id, LLM_ID)
    used_credits += response.usedCredits

    responses.append(response.output)
    print(response.output)

    # Persist the full response per sample; explicit UTF-8 matches the
    # encoding used when the question subset was written.
    response_path = os.path.join(response_dir, f"sample_response_{i}.json")
    with open(response_path, "w", encoding="utf-8") as f:
        f.write(response.model_dump_json(indent=4))

print(f"Execution completed. Total used credits: {used_credits}")


In [22]:
print(f"Execution completed. Total used credits: {used_credits}")
# Pass the same `start` offset the experiment loop used instead of a literal 0,
# so the call follows the loop if the processed slice is changed.
# (presumably `start` aligns responses with the question list -- verify
# against process_and_save_results in utilities)
process_and_save_results(responses, selected_questions_100, output_dir, start=start)

Execution completed. Total used credits: 2.875979999999999
0. How many schools with an average score in Math greater than 400 in the SAT test are exclusively virtual?
    SELECT COUNT(*) FROM satscores JOIN schools ON satscores.cds = schools.CDSCode WHERE satscores.AvgScrMath > 400 AND schools.Virtual = 'F'

1. How many schools in Fresno (directly funded) have number of test takers not more than 250?
    SELECT COUNT(*) FROM schools s JOIN satscores sat ON s.CDSCode = sat.cds WHERE s.County = 'Fresno' AND s.FundingType = 'Directly funded' AND sat.NumTstTakr <= 250

2. Of all the schools with a mailing state address in California, how many are active in San Joaquin city?
    SELECT COUNT(*) FROM schools WHERE MailState = 'CA' AND City = 'San Joaquin' AND StatusType = 'Active'

3. Which school in Contra Costa has the highest number of test takers?
    I apologize, but I couldn't find any data matching the query for schools in Contra Costa County with SAT test takers. This could mean that

In [None]:
# Evaluate the saved predictions for the same [start, end) window the run used.
evaluate_predictions(
    output_dir,
    start=start,
    end=end,
)

Error executing SQL: near "4": syntax error
Error reading experiments/text2sql_bird_single_agent/result_0.json: too many values to unpack (expected 2)
Error executing SQL: near "There": syntax error
Error reading experiments/text2sql_bird_single_agent/result_1.json: too many values to unpack (expected 2)
Error executing SQL: near "Agent": syntax error
Error reading experiments/text2sql_bird_single_agent/result_10.json: too many values to unpack (expected 2)
Error executing SQL: near "The": syntax error
Error reading experiments/text2sql_bird_single_agent/result_11.json: too many values to unpack (expected 2)
Error executing SQL: near "26": syntax error
Error reading experiments/text2sql_bird_single_agent/result_12.json: too many values to unpack (expected 2)
[(101,)] === [(65,)]
Incorrect Prediction 5: SELECT COUNT(`loan_id`) FROM loan WHERE `amount` >= 250000 AND `account_id` IN (SELECT `account_id` FROM account WHERE `frequency` = 'POPLATEK MESICNE')
[('Benesov',), ('Beroun',), ('Mel