In [8]:
!apt-get update -qq
!apt-get install mysql-server -qq
!pip install mysql-connector-python -q
!service mysql start


print(f'Database started.')

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
 * Starting MySQL database server mysqld
   ...done.
Database started.


In [9]:
import os

# download mysql database file:
if not os.path.exists('./data/bird.sql'):
    !mkdir './data'
    !gdown "https://drive.google.com/uc?id=1nHezEJ-px69Di0Xzxnwwrpo2KyVaf4hO" -O './data/'
    !mv './data/BIRD_dev.sql' './data/bird.sql'

# create bird database in mysql server
!mysql -u root -e 'drop database bird;'
!mysql -u root -e 'create database bird;'
!mysql -u root -e 'show databases;'

# put ./data/bird.sql's content to bird database in mysql server:
!mysql -u root bird < ./data/bird.sql

+--------------------+
| Database           |
+--------------------+
| bird               |
| information_schema |
| mysql              |
| performance_schema |
| sys                |
+--------------------+


In [10]:
# verify bird.sql was uploaded:
!mysql -u root -e 'use bird; show tables;'

+----------------------+
| Tables_in_bird       |
+----------------------+
| Country              |
| Examination          |
| Laboratory           |
| League               |
| Match                |
| Patient              |
| Player               |
| Player_Attributes    |
| Team                 |
| Team_Attributes      |
| account              |
| alignment            |
| atom                 |
| attendance           |
| attribute            |
| badges               |
| bond                 |
| budget               |
| card                 |
| cards                |
| circuits             |
| client               |
| colour               |
| comments             |
| connected            |
| constructorResults   |
| constructorStandings |
| constructors         |
| customers            |
| disp                 |
| district             |
| driverStandings      |
| drivers              |
| event                |
| expense              |
| foreign_data         |
| frpm                 |


In [11]:
# python to database imports:
import mysql.connector
import pandas as pd
from typing import List
import sqlalchemy
from sqlalchemy.engine.base import Engine
from sqlalchemy import text

In [12]:
from datasets import load_dataset

dataset = load_dataset("birdsql/bird_mini_dev")
print(dataset["mini_dev_mysql"][0])

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


README.md: 0.00B [00:00, ?B/s]

mini_dev_mysql-00000-of-00001.json: 0.00B [00:00, ?B/s]

mini_dev_pg-00000-of-00001.json: 0.00B [00:00, ?B/s]

mini_dev_sqlite-00000-of-00001.json: 0.00B [00:00, ?B/s]

Generating mini_dev_mysql split:   0%|          | 0/500 [00:00<?, ? examples/s]

Generating mini_dev_pg split:   0%|          | 0/500 [00:00<?, ? examples/s]

Generating mini_dev_sqlite split:   0%|          | 0/500 [00:00<?, ? examples/s]

{'question_id': 1471, 'db_id': 'debit_card_specializing', 'question': 'What is the ratio of customers who pay in EUR against customers who pay in CZK?', 'evidence': "ratio of customers who pay in EUR against customers who pay in CZK = count(Currency = 'EUR') / count(Currency = 'CZK').", 'SQL': "SELECT  CAST(SUM(CASE WHEN `Currency` = 'EUR' THEN 1 ELSE 0 END) AS DOUBLE) / SUM(CASE WHEN `Currency` = 'CZK' THEN 1 ELSE 0 END) FROM `customers`", 'difficulty': 'simple'}


In [13]:
# create an user agent and grant him access permissions:
!mysql -u root -e "create user 'agent'@'localhost' identified by 'agent_pass';"

# grant the user agent all ther permissions to the bird database
!mysql -u root -e "grant select on bird.* to 'agent'@'localhost'; flush privileges;"
!mysql -u root -e 'select user , host , plugin from mysql.user;'

# show permissions for user agent
!mysql -u root -e 'show grants for agent@localhost;'

+------------------+-----------+-----------------------+
| user             | host      | plugin                |
+------------------+-----------+-----------------------+
| agent            | localhost | caching_sha2_password |
| debian-sys-maint | localhost | caching_sha2_password |
| mysql.infoschema | localhost | caching_sha2_password |
| mysql.session    | localhost | caching_sha2_password |
| mysql.sys        | localhost | caching_sha2_password |
| root             | localhost | auth_socket           |
+------------------+-----------+-----------------------+
+-------------------------------------------------+
| Grants for agent@localhost                      |
+-------------------------------------------------+
| GRANT USAGE ON *.* TO `agent`@`localhost`       |
| GRANT SELECT ON `bird`.* TO `agent`@`localhost` |
+-------------------------------------------------+


In [14]:
from sqlalchemy import create_engine

db_engine = create_engine(
    "mysql+mysqlconnector://agent:agent_pass@localhost:3306/bird"
)
with db_engine.connect() as conn:
    print("Connected successfully!")
    

Connected successfully!


In [15]:
# ──────────────────────────────────────────────
# STEP 1: Define the Persona as a dictionary
# ──────────────────────────────────────────────

persona = {
    "role": "Senior SQL Developer and Data Analyst",

    "instructions": (
        "You generate and execute SQL queries against a MySQL database.\n"
        "Follow this workflow for every user question:\n"
        "  1. List all available tables using the provided tool.\n"
        "  2. Get the schema (columns) of the relevant table(s).\n"
        "  3. Write a syntactically correct MySQL SELECT query.\n"
        "  4. Execute the query and return the results.\n"
    ),

    "rules": [
        "ONLY use table and column names confirmed by the schema tool.",
        "NEVER guess or invent table or column names.",
        "NEVER run DELETE, DROP, UPDATE, INSERT, or any data-modifying SQL.",
        "If the question is unclear, state your assumptions first.",
        "If a query fails, analyze the error and retry with a corrected query.",
    ],

    "output_format": (
        "Always respond with:\n"
        "  - The SQL query you generated\n"
        "  - The query result\n"
        "  - A short natural-language summary of the answer"
    ),
}

# ──────────────────────────────────────────────
# STEP 2: Build the system prompt from persona
# ──────────────────────────────────────────────

def build_system_prompt(persona: dict) -> str:
    """Convert a persona dictionary into a system prompt string."""

    rules_text = "\n".join(f"  - {rule}" for rule in persona["rules"])

    prompt = (
        f"You are a {persona['role']}.\n\n"
        f"INSTRUCTIONS:\n{persona['instructions']}\n"
        f"RULES:\n{rules_text}\n\n"
        f"OUTPUT FORMAT:\n{persona['output_format']}"
    )
    return prompt


# ──────────────────────────────────────────────
# STEP 3: Generate and inspect the prompt
# ──────────────────────────────────────────────

system_prompt = build_system_prompt(persona)
print(system_prompt)

You are a Senior SQL Developer and Data Analyst.

INSTRUCTIONS:
You generate and execute SQL queries against a MySQL database.
Follow this workflow for every user question:
  1. List all available tables using the provided tool.
  2. Get the schema (columns) of the relevant table(s).
  3. Write a syntactically correct MySQL SELECT query.
  4. Execute the query and return the results.

RULES:
  - ONLY use table and column names confirmed by the schema tool.
  - NEVER guess or invent table or column names.
  - NEVER run DELETE, DROP, UPDATE, INSERT, or any data-modifying SQL.
  - If the question is unclear, state your assumptions first.
  - If a query fails, analyze the error and retry with a corrected query.

OUTPUT FORMAT:
Always respond with:
  - The SQL query you generated
  - The query result
  - A short natural-language summary of the answer


In [16]:
# user qwen-2.5-7b-instruct to test the system prompt and get an execution sequence from the llm:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'Qwen/Qwen2.5-7B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype = torch.float16,
    device_map = 'auto'
)

def qqwen(system_prompt : str , user_query : str):
    messages = [
        {'role' : 'system' , 'content' : system_prompt},
        {'role' : 'user' , 'content' : user_query}
    ]
    
    
    #apply qwen's chat template:
    text = tokenizer.apply_chat_template(messages , tokenize = False , add_generation_prompt = True)
    #tokenizer the messages:
    tokens = tokenizer(text , return_tensors = 'pt').to(model.device)
    
    print('thinking...')
    #generate output:
    outputs = model.generate(
        **tokens,
        max_new_tokens = 512,
        temperature = 0.1,
        do_sample = False        
    )
    
    response = tokenizer.decode(
        outputs[0],
        skip_special_tokens = True
    )
    
    return response
    
user_query = 'What is the total revenue for each customer ?'
print(qqwen(system_prompt , user_query))

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]



tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/339 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


thinking...
system
You are a Senior SQL Developer and Data Analyst.

INSTRUCTIONS:
You generate and execute SQL queries against a MySQL database.
Follow this workflow for every user question:
  1. List all available tables using the provided tool.
  2. Get the schema (columns) of the relevant table(s).
  3. Write a syntactically correct MySQL SELECT query.
  4. Execute the query and return the results.

RULES:
  - ONLY use table and column names confirmed by the schema tool.
  - NEVER guess or invent table or column names.
  - NEVER run DELETE, DROP, UPDATE, INSERT, or any data-modifying SQL.
  - If the question is unclear, state your assumptions first.
  - If a query fails, analyze the error and retry with a corrected query.

OUTPUT FORMAT:
Always respond with:
  - The SQL query you generated
  - The query result
  - A short natural-language summary of the answer
user
What is the total revenue for each customer ?
assistant
To provide you with the total revenue for each customer, I nee

In [17]:
# generate plan from llm:

def generate_plan(user_question : str) -> str:
    '''
    Ask llm to generate a step by step plan for answering a question.
    '''
    planning_prompt = (
        "Before taking any action, create a step-by-step plan to answer "
        "the user's question. List each step as a numbered action. "
        "Specify which tool you would use at each step.\n\n"
        "Available tools:\n"
        "  - list_tables: returns all table names in the database\n"
        "  - get_schema(table_name): returns column names for a table\n"
        "  - execute_sql(query): runs a SQL query and returns results\n\n"
        "Output ONLY the plan, nothing else."
    )
    
    system_prompt = build_system_prompt(persona)
    
    user_query = f"{planning_prompt} \n \n Question: {user_question}"
    response = qqwen(system_prompt , user_query)
    
    return response

plan = generate_plan("How many customers pay in EUR ?")
print("═" * 50)
print("PLAN FOR: How many customers pay in EUR?")
print("═" * 50)
print(plan)

thinking...
══════════════════════════════════════════════════
PLAN FOR: How many customers pay in EUR?
══════════════════════════════════════════════════
system
You are a Senior SQL Developer and Data Analyst.

INSTRUCTIONS:
You generate and execute SQL queries against a MySQL database.
Follow this workflow for every user question:
  1. List all available tables using the provided tool.
  2. Get the schema (columns) of the relevant table(s).
  3. Write a syntactically correct MySQL SELECT query.
  4. Execute the query and return the results.

RULES:
  - ONLY use table and column names confirmed by the schema tool.
  - NEVER guess or invent table or column names.
  - NEVER run DELETE, DROP, UPDATE, INSERT, or any data-modifying SQL.
  - If the question is unclear, state your assumptions first.
  - If a query fails, analyze the error and retry with a corrected query.

OUTPUT FORMAT:
Always respond with:
  - The SQL query you generated
  - The query result
  - A short natural-language su