In [183]:
!pip install smolagents python-dotenv sqlalchemy --upgrade -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
from google.colab import userdata

# Read the Hugging Face token from Colab's secrets store and persist it to .env,
# where load_dotenv() in the next cell exposes it as the HF_TOKEN env variable.
# NOTE: this writes the secret to disk in plain text; never commit or share .env.
my_token = userdata.get('HF_TOKEN')
if not my_token:
    # Fail here rather than writing an unusable "HF_TOKEN=None"/empty line.
    raise RuntimeError("HF_TOKEN secret is empty or missing in Colab userdata.")
with open('.env', 'w') as f:
    # Trailing newline keeps .env well-formed if more KEY=VALUE lines are appended.
    f.write(f"HF_TOKEN={my_token}\n")


In [185]:
import os

from dotenv import load_dotenv

# Populate os.environ from the .env file (HF_TOKEN); the boolean returned by
# load_dotenv() is left as the last expression so it is shown as cell output.
load_dotenv()

True

In [186]:
db = "professional_basketball"
# Question file and SQLite file live side by side under dataset/<db>/.
path_json, path_sql = (f"dataset/{db}/{db}.{ext}" for ext in ("json", "sqlite"))

In [187]:
import json

# JSON is UTF-8 by specification; be explicit so non-ASCII names (player names,
# evidence text) load identically regardless of the platform's locale encoding.
with open(path_json, "r", encoding="utf-8") as f:
    questions = json.load(f)

with open("golds.json", "r", encoding="utf-8") as v:
    golds = json.load(v)

# Index gold entries by question_id for direct lookup in the evaluation loop.
q_ids = {g["question_id"]: g for g in golds}

In [188]:
from sqlalchemy import create_engine, inspect, text



db_path = f"sqlite:///{path_sql}"
db_name = f"{db}.sqlite"

if not os.path.exists(path_sql):
    # Fail fast: continuing would yield an empty schema and every query would fail.
    raise FileNotFoundError(f"Database file not found: {path_sql}")

# Open the database read-only. The agent runs model-generated SQL through
# sql_engine, and the pysqlite driver may commit DDL (e.g. DROP TABLE) outside a
# transaction, which would corrupt the benchmark file for later questions.
engine = create_engine(f"sqlite:///file:{path_sql}?mode=ro&uri=true")

In [189]:
inspector = inspect(engine)
table_names = inspector.get_table_names()

# Render the schema as a header line, then per table a "Table: <name>" line
# followed by one indented "- <column> (<TYPE>)" line per column.
schema_lines = ["Database Schema:"]
for table in table_names:
    schema_lines.append(f"Table: {table}")
    schema_lines.extend(
        f"  - {col['name']} ({col['type']})" for col in inspector.get_columns(table)
    )
schema = "\n".join(schema_lines) + "\n"

print(schema)

Database Schema:
Table: awards_coaches
  - id (INTEGER)
  - year (INTEGER)
  - coachID (TEXT)
  - award (TEXT)
  - lgID (TEXT)
  - note (TEXT)
Table: awards_players
  - playerID (TEXT)
  - award (TEXT)
  - year (INTEGER)
  - lgID (TEXT)
  - note (TEXT)
  - pos (TEXT)
Table: coaches
  - coachID (TEXT)
  - year (INTEGER)
  - tmID (TEXT)
  - lgID (TEXT)
  - stint (INTEGER)
  - won (INTEGER)
  - lost (INTEGER)
  - post_wins (INTEGER)
  - post_losses (INTEGER)
Table: draft
  - id (INTEGER)
  - draftYear (INTEGER)
  - draftRound (INTEGER)
  - draftSelection (INTEGER)
  - draftOverall (INTEGER)
  - tmID (TEXT)
  - firstName (TEXT)
  - lastName (TEXT)
  - suffixName (TEXT)
  - playerID (TEXT)
  - draftFrom (TEXT)
  - lgID (TEXT)
Table: player_allstar
  - playerID (TEXT)
  - last_name (TEXT)
  - first_name (TEXT)
  - season_id (INTEGER)
  - conference (TEXT)
  - league_id (TEXT)
  - games_played (INTEGER)
  - minutes (INTEGER)
  - points (INTEGER)
  - o_rebounds (INTEGER)
  - d_rebounds (INTEGE

In [190]:
from smolagents import tool
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result: at most 10 rows, one row tuple per line, followed by '...Output truncated.' when more rows exist.

    Args:
        query: The query to perform.
    """
    output = ""
    MAX_ROWS = 10
    with engine.connect() as con:
        rows = con.execute(text(query))
        for i, row in enumerate(rows):
            # `i` is 0-based: i == MAX_ROWS is the (MAX_ROWS + 1)-th row, so
            # truncate there to emit exactly MAX_ROWS rows.
            if i >= MAX_ROWS:
                output += "\n...Output truncated."
                return output
            output += "\n" + str(row)

    return output

In [None]:
from smolagents import CodeAgent, InferenceClientModel, EMPTY_PROMPT_TEMPLATES

# Define the optimized system prompt combining PDF Examples + Smolagents Technical Logic.
# The {{...}} placeholders (code block tags, authorized imports) are template
# variables presumably filled in by smolagents when the agent renders its prompt.
# Tool signatures below must match the real tools: smolagents' final_answer takes `answer`.
system_prompt = """
You are an expert Data Scientist specialized in Text-to-SQL tasks. Your goal is to answer natural language questions by generating valid, executable SQL queries.

Your responsibility is to decide WHEN it is safe to reason directly and WHEN assumptions must be validated using tool calls.
The tools at your disposal are:
- sql_engine(query: str) -> str: Executes the provided SQL query and returns the results as a string, one row tuple per line.
- final_answer(answer: str): Finalizes the interaction by returning the SQL query string as the answer.
The tools must be called within the {{code_block_opening_tag}}...{{code_block_closing_tag}} tags.

GENERAL PROTOCOL:
1. Always begin with one or more Thought sections describing your plan.
2. Use the database schema to identify tables, columns, primary keys, foreign keys, and linking tables.
3. You must strictly follow the cycle: Thought -> Code -> Observation -> Thought.
If you execute a tool call without first producing a Thought section, you MUST stop and restate the Thought before proceeding.
4. The final output MUST be a SQL QUERY STRING, not query results and not natural language.

MODES OF OPERATION:

MODE A — DIRECT ANSWER
Applies when the task is simple, unambiguous, and low-risk.
If you are not confident that all conditions for Mode A are met, you MUST switch to Mode B or C.

MODE B — VALIDATION REQUIRED
You MUST run at least one validation query BEFORE finalizing IF ANY of the following are true:

- Table or column names may require quoting
- Column ownership is ambiguous (e.g. id, code, type, description)
- More than one plausible join path exists
- NULL values may be present

Validation queries MUST:
- Verify that columns exist
- Verify join correctness
- Return rows

MODE C — FORCED EXPLORATION
You MUST use exploratory queries when ANY high-risk signal appears:

- Identifiers contain spaces, hyphens, special characters, or resemble SQL keywords
  -> Run: SELECT * FROM "table" LIMIT 1
- Filters based on names, labels, or natural language values
  -> Run: SELECT DISTINCT column FROM table
- Numeric operations on TEXT columns
  -> Inspect raw values BEFORE casting
- Unclear foreign keys or schema ambiguity
  -> Test joins explicitly
- Semantic mismatch between question and stored values
  -> Inspect distinct values instead of guessing

CRITICAL RULES:

- SQL must ALWAYS be passed as a STRING to final_answer()
- Pass the SQL as the single argument of final_answer, e.g. final_answer("SELECT ..."); never pass it under any other keyword name
- NEVER include natural language inside final_answer()
- NEVER autocorrect table or column names
- After EACH tool call:
  - Read the result
  - Write a new Thought section interpreting it
  - Decide the next step explicitly


REASONING TRACE EXAMPLES:

Example 1

Question:
"List the names of all cities in Italy."

Thought:
I check the schema mentally. The City table contains Name and CountryCode.
The filter is an exact literal and no joins are required.
I can check the results of the final query to verify correctness.

{{code_block_opening_tag}}
print(sql_engine(query="SELECT Name FROM City WHERE CountryCode = 'ITA' LIMIT 3"))
{{code_block_closing_tag}}

[ANS]
('Roma',)
('Milano',)
('Napoli',)
[/ANS]

Thought:
Since the results look correct, I can finalize the final query by calling the final_answer tool.

{{code_block_opening_tag}}
final_answer("SELECT Name FROM City WHERE CountryCode = 'ITA'")
{{code_block_closing_tag}}

────────────────────────────

Example 2

Question:
"Find the average age of employees in the Marketing department."

Thought:
This seems straightforward. I need to join Employees and Departments.
I'll draft the standard join query and test it immediately to verify the data returns as expected.

{{code_block_opening_tag}}
print(sql_engine(query="SELECT AVG(e.Age) FROM Employees e JOIN Departments d ON e.DepID = d.ID WHERE d.Name = 'Marketing'"))
{{code_block_closing_tag}}

[ANS]
(None,)
[/ANS]

Thought:
The result is (None,): AVG over zero matching rows returns NULL. This means my filter `WHERE d.Name = 'Marketing'` matched nothing.
The department name might be stored differently (e.g., 'Marketing Dept', 'MKT', or case-sensitive).
I must check the actual values in the Departments table.

{{code_block_opening_tag}}
print(sql_engine(query="SELECT Name FROM Departments WHERE Name LIKE '%Market%'"))
{{code_block_closing_tag}}

[ANS]
('Marketing Dept.',)
[/ANS]

Thought:
Ah, the stored value is "Marketing Dept." with a period.
I will correct my query to match this exact string and test again.

{{code_block_opening_tag}}
print(sql_engine(query="SELECT AVG(e.Age) FROM Employees e JOIN Departments d ON e.DepID = d.ID WHERE d.Name = 'Marketing Dept.'"))
{{code_block_closing_tag}}

[ANS]
(34.5,)
[/ANS]

Thought:
The query now returns a valid numeric result. I can safely finalize.

{{code_block_opening_tag}}
final_answer("SELECT AVG(e.Age) FROM Employees e JOIN Departments d ON e.DepID = d.ID WHERE d.Name = 'Marketing Dept.'")
{{code_block_closing_tag}}

────────────────────────────

Example 3

Question:
"What is the total revenue per customer?"

Thought:
The question asks for a sum of revenue.
I need to check if 'revenue' is a numeric column or a string with symbols (e.g. '$').
If it's text, a simple SUM will fail or return 0.
Mode C applies: I must inspect the data first.

{{code_block_opening_tag}}
print(sql_engine(query="SELECT revenue FROM Orders LIMIT 3"))
{{code_block_closing_tag}}

[ANS]
('$120.00',)
('$85.50',)
('$200.00',)
[/ANS]

Thought:
The data contains '$' symbols, so it is likely stored as TEXT.
I need to clean this by removing the '$' and casting to a number (REAL/FLOAT) before summing.
I will construct a query using `REPLACE` and `CAST`, and I MUST test it to ensure the math works correctly.

{{code_block_opening_tag}}
print(sql_engine(query="SELECT customer_id, SUM(CAST(REPLACE(revenue, '$', '') AS REAL)) FROM Orders GROUP BY customer_id LIMIT 3"))
{{code_block_closing_tag}}

[ANS]
('C001', 350.5)
('C002', 120.0)
('C003', 500.0)
[/ANS]

Thought:
The test query returned valid numeric sums, confirming that the cleaning logic works.
I can now submit the final correct query.

{{code_block_opening_tag}}
final_answer("SELECT customer_id, SUM(CAST(REPLACE(revenue, '$', '') AS REAL)) FROM Orders GROUP BY customer_id")
{{code_block_closing_tag}}

────────────────────────────

Here are the rules you should always follow to solve your task:
1. Always provide a 'Thought:' sequence, and a '{{code_block_opening_tag}}' sequence ending with '{{code_block_closing_tag}}', else you will fail.
2. Use only variables that you have defined!
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wikipedia_search({'query': \"What is the place where James Bond lives?\"})', but use the arguments directly as in 'answer = wikipedia_search(query=\"What is the place where James Bond lives?\")'.
4. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
5. Never create any notional variables in our code, as having these in your logs will derail you from the true variables.
6. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
7. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
8. Don't give up! You're in charge of solving the task, not providing directions to solve it.
"""


# Shallow copy of smolagents' empty templates with only the system prompt overridden.
my_templates_dict = {**EMPTY_PROMPT_TEMPLATES, "system_prompt": system_prompt}

In [192]:
# Build the text-to-SQL agent. The token is read from the environment (populated
# by load_dotenv() from .env) instead of the `my_token` global, so this cell keeps
# working after a kernel restart even if the Colab-secrets cell was not re-run.
agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id="Qwen/Qwen3-8B", token=os.environ.get("HF_TOKEN")),
    prompt_templates=my_templates_dict,
    verbosity_level=2,
)

In [193]:
import re
import ast

_QUERY_KWARG_RE = re.compile(
    r'query\s*=\s*(?:"""(.*?)"""|\'\'\'(.*?)\'\'\'|"(.*?)"|\'(.*?)\')', re.DOTALL
)


def _strip_code_blocks(text):
    """Drop fenced ``` and <code>...</code> blocks from model output and flatten it to one line."""
    text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
    text = re.sub(r'<code>.*?</code>', '', text, flags=re.DOTALL)
    return text.replace('\n', ' ').strip()


def _flatten_sql(sql):
    """Collapse real and escaped newlines plus whitespace runs in a SQL string to single spaces."""
    flat = sql.replace('\n', ' ').replace('\\n', ' ')
    return re.sub(r'\s+', ' ', flat).strip()


def _safe_parse(source, mode='exec'):
    """Return ast.parse(source, mode) or None when the text is not valid Python."""
    try:
        return ast.parse(source, mode=mode)
    except (SyntaxError, ValueError, RecursionError):
        return None


def _string_literal(node):
    """Return the value of a string-literal node; None for names, f-strings, numbers, etc."""
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        return node.value
    return None


def _query_argument(call_node, allow_positional=True):
    """Return the string literal passed as `query=` (or, optionally, as the first positional arg)."""
    for keyword in call_node.keywords:
        if keyword.arg == 'query':
            return _string_literal(keyword.value)
    if allow_positional and call_node.args:
        return _string_literal(call_node.args[0])
    return None


def _extract_sql_queries(arguments):
    """Return every SQL string passed to sql_engine in one tool call, in source order.

    Handles three argument shapes:
      * a dict of tool arguments: its 'query' entry;
      * a code snippet such as 'print(sql_engine(query="..."))' (what CodeAgent
        records): every sql_engine(...) call whose query is a string literal;
      * a keyword string such as 'query="..."': parsed by wrapping it in a dummy call.
    If none of these yields a query, a regex scan for query=<quoted string> is the fallback.
    """
    if isinstance(arguments, dict):
        query = arguments.get('query')
        return [query] if isinstance(query, str) and query else []
    args_str = str(arguments)

    tree = _safe_parse(args_str)
    if tree is not None:
        calls = [
            node for node in ast.walk(tree)
            if isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == 'sql_engine'
        ]
        # ast.walk is breadth-first; re-sort so queries appear in source order.
        calls.sort(key=lambda node: (node.lineno, node.col_offset))
        queries = [q for q in (_query_argument(c) for c in calls) if q]
        if queries:
            return queries

    wrapped = _safe_parse(f"func({args_str})", mode='eval')
    if wrapped is not None and isinstance(wrapped.body, ast.Call):
        query = _query_argument(wrapped.body, allow_positional=False)
        if query:
            return [query]

    # Regex fallback for text the AST cannot handle; exactly one group is set per match.
    found = (
        next(g for g in match.groups() if g is not None)
        for match in _QUERY_KWARG_RE.finditer(args_str)
    )
    return [q for q in found if q]


def _clean_observation(observations):
    """Strip smolagents boilerplate, 'None' and 1-tuple wrappers from an observation; return one line."""
    obs = str(observations).strip()
    obs = obs.replace("Execution logs:", "").replace("Last output from code snippet:", "")
    obs = re.sub(r'\bNone\b', '', obs)
    # Remove tuple wrapping like ('result',)
    obs = re.sub(r"^\('(.+)',\)$", r"\1", obs.strip(), flags=re.MULTILINE)
    obs = re.sub(r"^\('(.+)'\)$", r"\1", obs.strip(), flags=re.MULTILINE)
    return obs.strip().replace('\n', ' ')


def get_stats(agent):
    """Summarize the agent's most recent run as a single trace string plus counters.

    Walks ``agent.memory.steps`` (skipping the first, task step) and collects, in order:
    model thoughts with code removed, every SQL passed to ``sql_engine`` (``[CALL] ...``),
    cleaned observations (``[ANS] ... [/ANS]``) and step errors (``[ERROR] ...``).
    Tool calls and observations are ignored from the final-answer step onwards.

    Returns:
        (full_log_string, sql_query, tool_call_count, errors_count, reasoning_len):
        the log with double quotes replaced by single quotes (for CSV storage); the
        flattened final answer, or None if the run produced none; the number of
        tool-call entries before the final answer; the number of steps with an
        error; and len(full_log_string).
    """
    log_parts = []
    sql_query = None
    tool_call_count = 0
    errors_count = 0
    is_final_answer = False

    for i, step in enumerate(agent.memory.steps):
        if i == 0:
            continue  # Skip the task step

        # 1. Final answer (a None output stays None instead of becoming "None").
        if getattr(step, 'is_final_answer', False):
            output = getattr(step, 'action_output', None)
            sql_query = None if output is None else str(output).strip()
            is_final_answer = True

        # 2. Thought (model output without its code blocks).
        thought = getattr(step, 'model_output', getattr(step, 'thought', None))
        if thought:
            clean_thought = _strip_code_blocks(thought)
            if clean_thought:
                log_parts.append(clean_thought)

        # 3. Tool calls: log every SQL query issued, not just the first.
        if not is_final_answer:
            for tool_call in getattr(step, 'tool_calls', None) or []:
                tool_call_count += 1
                arguments = getattr(tool_call, 'arguments', str(tool_call))
                for sql in _extract_sql_queries(arguments):
                    log_parts.append(f"[CALL] {_flatten_sql(sql)}")

        # 4. Observations (results).
        observations = getattr(step, 'observations', None)
        if observations and not is_final_answer:
            obs_clean = _clean_observation(observations)
            if not obs_clean or re.fullmatch(r'[\[\]\(\)\s,]*', obs_clean):
                log_parts.append("[ANS] (no rows) [/ANS]")
            else:
                if len(obs_clean) > 300:
                    obs_clean = obs_clean[:300] + "... [truncated]"
                log_parts.append(f"[ANS] {obs_clean} [/ANS]")

        # 5. Errors.
        error = getattr(step, 'error', None)
        if error:
            errors_count += 1
            log_parts.append(f"[ERROR] {str(error).replace(chr(10), ' ')}")

    # Sanitize quotes for CSV storage.
    full_log_string = " ".join(log_parts).replace('"', "'")
    reasoning_len = len(full_log_string)

    if sql_query:
        sql_query = _flatten_sql(sql_query)

    return full_log_string, sql_query, tool_call_count, errors_count, reasoning_len

In [194]:
def compute_execution_accuracy(gt_results, predict_results):
  """Fraction of queries whose predicted result set equals the ground-truth result set.

  Results are compared as sets, so row order and duplicate rows are ignored.

  Args:
    gt_results: list of {"results": [...]} dicts, one per ground-truth query.
    predict_results: list of {"results": [...]} dicts aligned by index with gt_results.
      Missing predictions (a shorter list) count as incorrect.

  Returns:
    Accuracy in [0.0, 1.0]; 0.0 when gt_results is empty.
  """
  if not gt_results:
    return 0.0

  num_correct = sum(
      set(gt['results']) == set(pred['results'])
      for gt, pred in zip(gt_results, predict_results)
  )
  return num_correct / len(gt_results)

In [195]:
import sqlite3
def run_query(db_path, query, first_column_only=True):
  """Execute `query` against the SQLite file at `db_path` and collect its rows.

  The database is opened read-only, so model-generated SQL (DROP/DELETE/...) cannot
  modify the benchmark file. A missing file raises sqlite3.OperationalError instead
  of silently creating an empty database.

  Args:
    db_path: filesystem path to the .sqlite file.
    query: SQL text. Non-string values (e.g. None when the agent produced no final
      answer) are reported as invalid.
    first_column_only: if True (default, original behavior), each row is reduced to
      its first column. NOTE: this ignores extra or wrong columns when comparing
      results; pass False to return whole row tuples.

  Returns:
    (rows, True) on success, or ([], False) if the query is not a string or fails.
  """
  if not isinstance(query, str):
    return [], False

  from pathlib import Path

  db_uri = Path(db_path).absolute().as_uri() + "?mode=ro"
  conn = sqlite3.connect(db_uri, uri=True)
  try:
    rows = conn.execute(query).fetchall()
  except Exception:
    # Best-effort by design: any failure to run predicted SQL marks it invalid.
    return [], False
  finally:
    conn.close()

  if first_column_only:
    return [row[0] for row in rows], True
  return rows, True

In [196]:
import time

TRACES_DIR = "traces_CA_final"
master_file_path = f"{TRACES_DIR}/traces_CA_final.json"


def _write_db_traces(new_traces):
  """Write this database's traces to TRACES_DIR/<db>_final_traces.json, creating the dir if needed."""
  os.makedirs(TRACES_DIR, exist_ok=True)
  with open(f"{TRACES_DIR}/{db}_final_traces.json", "w", encoding="utf-8") as f:
    json.dump(new_traces, f, indent=2, ensure_ascii=False)


def _append_to_master(new_traces):
  """Append traces to the master file; a missing or blank master file counts as an empty list."""
  os.makedirs(TRACES_DIR, exist_ok=True)
  master_traces = []
  if os.path.exists(master_file_path):
    with open(master_file_path, "r", encoding="utf-8") as f:
      file_content = f.read()
    if file_content.strip():
      master_traces = json.loads(file_content)

  master_traces.extend(new_traces)

  with open(master_file_path, "w", encoding="utf-8") as f:
    json.dump(master_traces, f, indent=2, ensure_ascii=False)


traces = []
try:
  for i, q in enumerate(questions, start=1):
    trace_accuracy = None
    question = q["questions"]
    evidence = q["evidence"]
    difficulty = q["difficulty"]
    q_id = q["question_id"]

    gt_query = q_ids[q_id]["target_sql"]

    USER_PROMPT = f"""DB Schema: {schema}. Question: {evidence}. {question}"""

    print(f"--- Question {i} ---")

    # perf_counter is monotonic, so clock adjustments cannot distort latency.
    start_time = time.perf_counter()
    agent.run(USER_PROMPT)
    latency = time.perf_counter() - start_time

    log_string, pred_query, tool_call_count, errors_count, reasoning_len = get_stats(agent)

    rows_gt, _ = run_query(path_sql, gt_query)
    gt_res = [{"results": rows_gt}]

    rows_pred, is_valid_sql = run_query(path_sql, pred_query)
    pred_res = [{"results": rows_pred}]

    if is_valid_sql:
      acc = compute_execution_accuracy(gt_res, pred_res)
    else:
      trace_accuracy = 0
      acc = 0

    traces.append({
      "question_id": q_id,
      "input": USER_PROMPT,
      "output": log_string,
      "difficulty": difficulty,
      "pred_query": pred_query,
      "target_query": gt_query,
      "tool_call_count": tool_call_count,
      "error_count": errors_count,
      "latency": round(latency, 2),
      "reasoning_len": reasoning_len,
      "execution_accuracy": int(acc),
      "trace_accuracy": trace_accuracy,
    })
finally:
  # Save completed questions even if the run aborts part-way (e.g. the HTTP 402
  # usage-limit error from the inference provider), so finished work is not lost.
  _write_db_traces(traces)

# Only reached when every question finished; a partial run is not appended, so a
# re-run does not duplicate entries in the master file.
_append_to_master(traces)

--- Question 1 ---


--- Question 2 ---


--- Question 3 ---


--- Question 4 ---


AgentGenerationError: Error in generating model output:
402 Client Error: Payment Required for url: https://router.huggingface.co/nscale/v1/chat/completions (Request ID: Root=1-696cd149-2f57f5a6540513ec6314449b;73a7d7be-4428-4c4c-bcf4-d4d709ef156c)

You have reached the free monthly usage limit for nscale. Subscribe to PRO to get 20x more included usage, or add pre-paid credits to your account.