In [None]:
import os
from sqlalchemy import create_engine, inspect, text
from smolagents import ToolCallingAgent, tool
from smolagents.models import InferenceClientModel
from smolagents import EMPTY_PROMPT_TEMPLATES
import json
from google.colab import userdata

# Hugging Face token, read through google.colab's `userdata` store under the
# key 'HF_TOKEN'; it is passed as `token=` when the model is constructed.
my_token = userdata.get('HF_TOKEN')


In [22]:
# Name of the database under evaluation; both input paths below derive from it.
db = "address"
# JSON file holding the questions for this database.
path_json = f"dataset/{db}/{db}.json"
# SQLite file that the agent queries and that the evaluation runs against.
path_sql = f"dataset/{db}/{db}.sqlite"

In [23]:
# Placeholder; the evaluation loop rebinds this to each question's text.
question = None

# Questions for the selected database.
with open(path_json, "r") as questions_file:
    questions = json.load(questions_file)

# Gold annotations for every question.
with open("golds.json", 'r') as golds_file:
    golds = json.load(golds_file)

# Index the gold entries by question id for direct lookup.
q_ids = {}
for gold in golds:
    q_ids[gold["question_id"]] = gold

In [24]:
# SQLAlchemy connection URL pointing at the SQLite file.
db_path = f"sqlite:///{path_sql}"
# File name of the database (not referenced in the other visible cells).
db_name = f"{db}.sqlite"

# Only a warning is printed: execution continues and the engine is still created.
# NOTE(review): presumably a missing file then yields an empty schema instead of
# an error — confirm, and consider failing fast here.
if not os.path.exists(path_sql):
    print("WARNING: db not found.")

engine = create_engine(db_path)

In [25]:
# Reflect the database through SQLAlchemy's inspector.
inspector = inspect(engine)
table_names = inspector.get_table_names()

# Collect the pieces of the schema description and join them once at the end:
# a header, then for every table its name followed by one line per column.
schema_parts = ["Database Schema:\n"]
for table in table_names:
    schema_parts.append(f"Table: {table}\n")
    schema_parts.extend(
        f"  - {col['name']} ({col['type']})\n"
        for col in inspector.get_columns(table)
    )
schema = "".join(schema_parts)

print(schema)

Database Schema:
Table: CBSA
  - CBSA (INTEGER)
  - CBSA_name (TEXT)
  - CBSA_type (TEXT)
Table: alias
  - zip_code (INTEGER)
  - alias (TEXT)
Table: area_code
  - zip_code (INTEGER)
  - area_code (INTEGER)
Table: avoid
  - zip_code (INTEGER)
  - bad_alias (TEXT)
Table: congress
  - cognress_rep_id (TEXT)
  - first_name (TEXT)
  - last_name (TEXT)
  - CID (TEXT)
  - party (TEXT)
  - state (TEXT)
  - abbreviation (TEXT)
  - House (TEXT)
  - District (INTEGER)
  - land_area (REAL)
Table: country
  - zip_code (INTEGER)
  - county (TEXT)
  - state (TEXT)
Table: state
  - abbreviation (TEXT)
  - name (TEXT)
Table: zip_congress
  - zip_code (INTEGER)
  - district (TEXT)
Table: zip_data
  - zip_code (INTEGER)
  - city (TEXT)
  - state (TEXT)
  - multi_county (TEXT)
  - type (TEXT)
  - organization (TEXT)
  - time_zone (TEXT)
  - daylight_savings (TEXT)
  - latitude (REAL)
  - longitude (REAL)
  - elevation (INTEGER)
  - state_fips (INTEGER)
  - county_fips (INTEGER)
  - region (TEXT)
  - divi

In [26]:
@tool
def sql_engine(query: str, thought: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.

    Args:
        query: The query to perform.
        thought: Your deductive reasoning or strategy for this specific step.
    """
    # NOTE: the docstring above is left verbatim on purpose — it is presumably
    # what the @tool decorator exposes to the model as the tool description.
    #
    # `thought` is never read here; it only exists so the model records its
    # reasoning in the tool-call arguments, where get_stats picks it up.

    # Maximum number of result rows echoed back to the model.
    MAX_ROWS = 10

    rendered = []
    with engine.connect() as con:
        for i, row in enumerate(con.execute(text(query))):
            # `>=` so that exactly MAX_ROWS rows are shown; the previous `>`
            # comparison let MAX_ROWS + 1 rows through before truncating.
            if i >= MAX_ROWS:
                rendered.append("...Output truncated.")
                break
            rendered.append(str(row))

    # Same layout as before: every entry is preceded by a newline, and an
    # empty result set yields the empty string.
    return "".join(f"\n{entry}" for entry in rendered)

@tool
def final_answer(answer: str, thought: str) -> str:
    """
    Returns the final answer to the user. MUST include a thought explaining the conclusion.
    
    Args:
        thought: The final reasoning explaining why this query answers the user's question.
        answer: The final result (the SQL query string).
    """
    # `thought` is accepted but not used in the body; get_stats reads it back
    # from the recorded tool-call arguments. `answer` is returned unchanged.
    return answer

In [None]:
system_prompt = """
You are an expert Data Scientist specialized in Text-to-SQL tasks. Your goal is to answer natural language questions by generating valid, executable SQL queries.
You will be given a task to solve as best you can.
To do so, you have been given access to the tool 'sql_engine', which executes SQL queries on the database and returns structured observations.
To solve the task, you must plan forward and proceed in a series of steps, following a cycle of Thought, Tool Call and Observation.

PROTOCOL:
1. Reasoning Trace: Explicitly state your plan in the 'Thought' section.
2. Schema Understanding: Use the provided database schema to understand:
   - table names
   - column names
   - primary/foreign key relationships
   - bridge tables
3. Test Your Logic: Do not generate the final answer yet.
   - You must run a "Test Query" to verify your hypothesis.
   - Check if your filters exist.
   - Check if your JOINs return rows.
   - If the test returns 0 rows or an error, you must revise your query and test again.
4. After every exploratory or validation query executed through sql_engine, you MUST:
   - Read the result.
   - Produce a new explicit `Thought:` section interpreting the output.
   - Decide the next step based on that interpretation before running another code block.
   
CRITICAL RULES (YOU MUST FOLLOW):
1. TOOL USAGE ONLY: You may interact with the database ONLY by calling the `sql_engine` tool.
2. SQL DISCIPLINE: Always pass a single valid SQL query string to the tool. Do NOT manipulate or filter data outside SQL.
3. HANDLING JOINS (NO SHORTCUTS):
    - Check Foreign Keys: Do NOT join on `id` unless you are sure it is a foreign key.
    - Bridge Tables: If Table A and Table C have no direct link, look for a middle table B:
      Example: Works -> Chapters -> Paragraphs -> Characters. You cannot jump from Works to Characters.
4. FINAL ANSWER FORMAT: NO BLIND SUBMISSIONS:
   - You must see the output of a successful Test Query in the logs before you are allowed to submit.
   - Your final goal is the SQL QUERY STRING, not the data result.
   - DO NOT EXECUTE the final query inside `final_answer`.

WORKED EXAMPLE TRACE:

Task:
"Return the names of companies in the ‘AI’ sector that were founded before 2015 and have at least two offices in France."

Tool Call:
{
  "name": "sql_engine",
  "arguments": {
    "thought": "I need to verify the schema for the 'Company' table to confirm column names for sector and foundation year before applying filters.",
    "query": "SELECT * FROM Company LIMIT 3"
  }
}

Observation:
{
  "row_count": 3,
  "sample_rows": [
    {"cid": 1, "name": "EURECOM", "sector": "Education", "founded_year": 1990},
    {"cid": 2, "name": "Acme AI", "sector": "AI", "founded_year": 2012}
  ]
}

Tool Call:
{
  "name": "sql_engine",
  "arguments": {
    "thought": "The 'Company' table has the necessary columns, but no location info. I must now check the 'Office' table to find the link to companies.",
    "query": "SELECT * FROM Office LIMIT 3"
  }
}

Observation:
{
  "row_count": 3,
  "sample_rows": [
    {"oid": 10, "cid": 2, "country": "France", "city": "Paris"},
    {"oid": 11, "cid": 2, "country": "France", "city": "Lyon"}
  ]
}

Tool Call:
{
  "name": "sql_engine",
  "arguments": {
    "thought": "Since 'Office' links to 'Company' via `cid`, I can now join them to count French offices per company and filter by the criteria.",
    "query": "SELECT c.cid, c.name, COUNT(*) AS fr_offices FROM Company c JOIN Office o ON o.cid = c.cid WHERE c.sector = 'AI' AND c.founded_year < 2015 AND o.country = 'France' GROUP BY c.cid, c.name HAVING COUNT(*) >= 2 LIMIT 3"
  }
}

Observation:
{
  "row_count": 1,
  "sample_rows": [
    {"cid": 2, "name": "Acme AI", "fr_offices": 2}
  ]
}

Tool Call:
{
  "name": "final_answer",
  "arguments": {
    "thought": "The validation query successfully returned 'Acme AI', confirming the logic is correct. I will now generate the final SQL.",
    "answer": "SELECT c.name FROM Company c JOIN Office o ON o.cid = c.cid WHERE c.sector = 'AI' AND c.founded_year < 2015 AND o.country = 'France' GROUP BY c.cid, c.name HAVING COUNT(*) >= 2"
  }
}

Other example:

Task:
"How many scenes are there in Act 1 in Twelfth Night?"

Tool Call:
{
  "name": "sql_engine",
  "arguments": {
    "thought": "I analyze first the chapters table, to get a better understanding of its format.",
    "query": "SELECT * FROM chapters LIMIT 3"
  }
}

Observation:
{
  "row_count": 3,
  "sample_rows": [
    {"id": 18704, "Act": 1, "Scene": 1, "Description": "DUKE ORSINO’s palace.", "work_id": 1},
    {"id": 18705, "Act": 1, "Scene": 2, "Description": "The sea-coast.", "work_id": 1},
    {"id": 18706, "Act": 1, "Scene": 3, "Description": "OLIVIA’S house.", "work_id": 1}
  ]
}

Tool Call:
{
  "name": "sql_engine",
  "arguments": {
    "thought": "I can use the work_id foreign key to perform a join between the works and chapters table. I need to consider only the work 'Twelfth Night', so a filter is needed.",
    "query": "SELECT w.Title, c.Act, c.Scene FROM works w JOIN chapters c ON c.work_id = w.id WHERE w.Title = 'Twelfth Night' LIMIT 3"
  }
}

Observation:
{
  "row_count": 3,
  "sample_rows": [
    {"Title": "Twelfth Night", "Act": 1, "Scene": 1},
    {"Title": "Twelfth Night", "Act": 1, "Scene": 2},
    {"Title": "Twelfth Night", "Act": 1, "Scene": 3}
  ]
}

Tool Call:
{
  "name": "sql_engine",
  "arguments": {
    "thought": "Since I need to count the number of scenes in Act 1, a further filter for c.Act = 1 is required.",
    "query": "SELECT w.Title, c.Act, c.Scene FROM works w JOIN chapters c ON c.work_id = w.id WHERE w.Title = 'Twelfth Night' AND c.Act = 1 LIMIT 3"
  }
}

Observation:
{
  "row_count": 3,
  "sample_rows": [
    {"Title": "Twelfth Night", "Act": 1, "Scene": 1},
    {"Title": "Twelfth Night", "Act": 1, "Scene": 2},
    {"Title": "Twelfth Night", "Act": 1, "Scene": 3}
  ]
}

Tool Call:
{
  "name": "sql_engine",
  "arguments": {
    "thought": "Now that I have only the instances related to Act 1, I can proceed with the final query in which the number of scenes is counted. The column is renamed for better understanding.",
    "query": "SELECT COUNT(*) as n_Scenes FROM works w JOIN chapters c ON c.work_id = w.id WHERE w.Title = 'Twelfth Night' AND c.Act = 1"
  }
}

Observation:
{
  "row_count": 1,
  "sample_rows": [
    {"n_Scenes": 5}
  ]
}

Tool Call:
{
  "name": "final_answer",
  "arguments": {
    "thought": "The number of scenes in Act 1 in Twelfth Night is correctly retrieved, I can proceed with returning the final query.",
    "answer": "SELECT COUNT(*) as n_Scenes FROM works w JOIN chapters c ON c.work_id = w.id WHERE w.Title = 'Twelfth Night' AND c.Act = 1"
  }
}

---

You only have access to these tools:
- sql_engine: Allows you to perform SQL queries on the table. Returns a string representation of the result.
    Takes inputs: {'thought': 'Your deductive reasoning or strategy for this specific step.', 'query': 'The query to perform.', }
- final_answer: Provides a final answer to the given problem.
    Takes inputs: {'thought': 'Your description for this specific step.', 'answer': 'The final answer to the problem'}

Here are the rules you should always follow to solve your task:
1. ALWAYS provide a tool call, else you will fail.
2. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. If no tool call is needed, use final_answer tool to return your answer.
4. Never re-do a tool call that you previously did with the exact same parameters.


Now Begin!
"""

# Start from a copy of smolagents' EMPTY_PROMPT_TEMPLATES and override only the
# "system_prompt" entry with the prompt defined above.
# NOTE(review): rule 3 of the prompt tells the model not to call "the search
# agent", but the agent is built with only sql_engine and final_answer —
# confirm and reword.
# NOTE(review): the example Observations in the prompt are JSON objects with
# row_count/sample_rows, whereas sql_engine returns newline-separated str(row)
# text — confirm this mismatch is intended.
my_templates_dict = EMPTY_PROMPT_TEMPLATES.copy()
my_templates_dict["system_prompt"] = system_prompt


In [28]:
# Chat model accessed through smolagents' InferenceClientModel, authenticated
# with the token read at the top of the notebook.
model = InferenceClientModel(
    model_id="Qwen/Qwen3-8B",
    token=my_token,
    tool_choice="auto"
)

# Tool-calling agent limited to the two tools defined above and driven by the
# custom prompt templates; max_steps=10 is the step budget passed to the agent.
agent = ToolCallingAgent(
    tools=[sql_engine, final_answer],
    model=model,
    max_steps=10,
    prompt_templates=my_templates_dict
)

In [29]:
import re
from smolagents import ActionStep

def get_stats(agent):
    """
    Flatten the agent's recorded steps into a one-line trace plus a few counters.

    Walks ``agent.memory.steps`` (skipping index 0) and, for every ``ActionStep``,
    appends tagged entries to the log:

    * ``[THOUGHT]`` – the ``thought`` tool argument and/or the cleaned ``model_output``
    * ``[CALL]``    – the SQL passed to a tool call
    * ``[ANS]``     – the step's observations (``(no rows)`` when empty)
    * ``[ERROR]``   – the step's error, if any

    Tool calls and observations are only logged and counted while no final-answer
    step has been seen, so the ``final_answer`` call itself is not included in
    ``tool_call_count``.

    Args:
        agent: An agent exposing ``memory.steps``, read after a run.

    Returns:
        A 5-tuple ``(full_log_string, sql_query, tool_call_count, errors_count,
        reasoning_len)`` where ``full_log_string`` is the entries joined by
        ``" | "`` (backslashes removed, double quotes turned into single quotes),
        ``sql_query`` is the ``answer`` argument of the final-answer call
        (``""`` when none was captured) and ``reasoning_len`` is
        ``len(full_log_string)``.
    """
    log_parts = []
    sql_query = None
    reasoning_len = 0
    tool_call_count = 0
    errors_count = 0
    is_final_answer = False

    for i, step in enumerate(agent.memory.steps):
        if i == 0: continue # Skip the task step

        if isinstance(step, ActionStep):
            if step.is_final_answer:
                # If the final-answer step carries tool calls, take 'answer' from the first call's arguments
                if isinstance(step.tool_calls, list) and step.tool_calls:
                    final_args = step.tool_calls[0].arguments
                    if isinstance(final_args, dict):
                        if 'answer' in final_args:
                            sql_query = final_args['answer'].strip()  # <-- the final query is captured here
                        if 'thought' in final_args:
                            log_parts.append(f"[THOUGHT] {final_args['thought']}")
                # Sticky flag: once set, sections 2 and 3 below are skipped.
                is_final_answer = True

            # --- 1. CHECK MODEL OUTPUT (Standard Agent) ---
            # We keep this as a fallback if the model talks before calling the tool
            thought_text = getattr(step, 'model_output', "")
            if thought_text:
                # Basic cleanup: drop triple-backtick spans and the "Thought:" label
                clean = re.sub(r'```.*?```', '', thought_text, flags=re.DOTALL)
                clean = clean.replace("Thought:", "").strip()
                if clean:
                    log_parts.append(f"[THOUGHT] {clean}")

            # --- 2. CHECK TOOL ARGUMENTS (Your Custom Agent) ---
            if hasattr(step, 'tool_calls') and step.tool_calls and not is_final_answer:
                for tool_call in step.tool_calls:
                    tool_call_count += 1

                    # Safely access arguments (smolagents usually stores them as a dict)
                    args = tool_call.arguments

                    # A. EXTRACT THOUGHT FROM ARGUMENTS
                    # The reasoning is carried in the tool's `thought` argument
                    if isinstance(args, dict):
                        arg_thought = args.get('thought')
                        if arg_thought:
                            log_parts.append(f"[THOUGHT] {arg_thought}")

                    # B. EXTRACT SQL FROM ARGUMENTS
                    # Convert to string for regex searching just in case
                    args_str = str(args)

                    match_triple = re.search(r'=\s*"""(.*?)"""', args_str, re.DOTALL)
                    match_direct = re.search(r"query['\"]?\s*:\s*['\"](.*?)['\"]", args_str, re.DOTALL)
                    # Preferred source: the 'query' key when the arguments are a dict
                    dict_query = args.get('query') if isinstance(args, dict) else None

                    # Precedence: dict key, then triple-quoted assignment, then key/value regex
                    found_sql = None
                    if dict_query:
                        found_sql = dict_query
                    elif match_triple:
                        found_sql = match_triple.group(1)
                    elif match_direct:
                        found_sql = match_direct.group(1)

                    if found_sql:
                        flat_sql = found_sql.replace('\n', ' ').strip()
                        log_parts.append(f"[CALL] {flat_sql}")

            # --- 3. OBSERVATIONS ---
            if hasattr(step, 'observations') and step.observations and not is_final_answer:
                obs = str(step.observations).strip()
                obs = obs.replace("Execution logs:", "").replace("Last output from code snippet:", "")
                obs_clean = obs.replace('\n', ' ').strip()

                # Text made only of brackets, parentheses, commas and whitespace counts as an empty result
                if not obs_clean or re.fullmatch(r'[\[\]\(\)\s,]*', obs_clean):
                    log_parts.append("[ANS] (no rows) [/ANS]")
                else:
                    log_parts.append(f"[ANS] {obs_clean} [/ANS]")

            # --- 4. ERRORS ---
            # Errors are recorded for every ActionStep, including the final-answer step
            if hasattr(step, 'error') and step.error:
                err_clean = str(step.error).replace('\n', ' ')
                errors_count += 1
                log_parts.append(f"[ERROR] {err_clean}")

    # Formatting

    full_log_string = " | ".join(log_parts) # Pipe separator is cleaner for one-liners
    full_log_string = full_log_string.replace("\\", "").replace("\"", "'")

    reasoning_len = len(full_log_string)

    if sql_query:
        # Remove literal "\n" sequences, turn real newlines into spaces, unescape \'
        sql_query = sql_query.replace("\\n", "").replace("\n", " ").replace("\\'", "'").strip()
    else:
        sql_query = ""

    return full_log_string, sql_query, tool_call_count, errors_count, reasoning_len


In [30]:
def compute_execution_accuracy(gt_results, predict_results):
    """
    Fraction of queries whose predicted result set equals the gold result set.

    Results are compared as sets, so row order and duplicate values are ignored.

    Args:
        gt_results: List of dicts, each with a 'results' iterable of hashable
            values (the gold query output).
        predict_results: List of dicts in the same format and order, holding
            the predicted query output; must be at least as long as
            ``gt_results``.

    Returns:
        A float in [0, 1]; ``0.0`` when ``gt_results`` is empty (previously this
        raised ZeroDivisionError).

    Raises:
        IndexError: If ``predict_results`` is shorter than ``gt_results``.
    """
    num_queries = len(gt_results)
    if num_queries == 0:
        # Nothing to score: report zero instead of dividing by zero.
        return 0.0

    num_correct = 0
    for i, result in enumerate(gt_results):
        if set(result['results']) == set(predict_results[i]['results']):
            num_correct += 1

    return num_correct / num_queries

In [31]:
import sqlite3
def run_query(db_path, query):
    """
    Execute ``query`` against the SQLite database at ``db_path``.

    Args:
        db_path: Path of the SQLite database file (passed to ``sqlite3.connect``).
        query: SQL statement to execute.

    Returns:
        A pair ``(values, ok)``. On success ``values`` is a list holding the
        FIRST column of every returned row and ``ok`` is True; if executing or
        fetching fails, the pair is ``([], False)``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()

        # Flatten results: only the first column of each row is kept.
        # NOTE(review): multi-column queries are therefore compared on their
        # first column only — confirm this is the intended metric.
        return [row[0] for row in rows], True
    except Exception:
        # Best-effort by design: any failure marks the query as invalid.
        # `Exception` (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate so a long evaluation can still be interrupted.
        return [], False
    finally:
        # Close on both paths; previously a failing query leaked the connection.
        conn.close()

In [32]:
import time

# Evaluation driver: run the agent on every question, score the predicted SQL
# against the gold SQL by execution, and persist one trace record per question.
os.makedirs("traces_TCA", exist_ok=True)

traces = []
for i, q in enumerate(questions, start=1):
    trace_accuracy = None
    question = q["questions"]
    evidence = q["evidence"]
    difficulty = q["difficulty"]
    q_id = q["question_id"]

    gt_query = q_ids[q_id]["target_sql"]

    USER_PROMPT = f"""DB Schema: {schema}. Question: {evidence}. {question}"""

    print(f"--- Question {i} ---")

    # perf_counter is monotonic, so the latency cannot be skewed by wall-clock
    # adjustments during the run.
    start_time = time.perf_counter()
    agent.run(USER_PROMPT)
    end_time = time.perf_counter()

    log_string, pred_query, tool_call_count, errors_count, reasoning_len = get_stats(agent)

    rows_gt, _ = run_query(path_sql, gt_query)
    gt_res = [{"results": rows_gt}]

    if pred_query:
        rows_pred, is_valid_sql = run_query(path_sql, pred_query)
    else:
        # get_stats returns "" when no final answer was captured. Do not hand
        # an empty statement to SQLite: it may execute without error and yield
        # zero rows, which would score as correct whenever the gold query is
        # empty too.
        rows_pred, is_valid_sql = [], False
    pred_res = [{"results": rows_pred}]

    if is_valid_sql:
        acc = compute_execution_accuracy(gt_res, pred_res)
    else:
        trace_accuracy = 0
        acc = 0

    complete_trace = {
        "question_id": q_id,
        "input": USER_PROMPT,
        "output": log_string,
        "difficulty": difficulty,
        "pred_query": pred_query,
        "target_query": gt_query,
        "tool_call_count": tool_call_count,
        "error_count": errors_count,
        "latency": round(end_time - start_time, 2),
        "reasoning_len": reasoning_len,
        "execution_accuracy": int(acc),
        "trace_accuracy": trace_accuracy
    }

    traces.append(complete_trace)


# Per-database trace file. ensure_ascii=False emits raw non-ASCII characters,
# so the encoding is pinned to UTF-8 instead of relying on the platform default.
with open(f"traces_TCA/{db}_traces.json", "w", encoding="utf-8") as f:
    json.dump(traces, f, indent=2, ensure_ascii=False)

# Master file accumulating the traces of every database.
# NOTE(review): re-running this cell appends the same traces again.
master_file_path = "traces_TCA/traces_TCA.json"
if os.path.exists(master_file_path):
    with open(master_file_path, "r", encoding="utf-8") as f:
        file_content = f.read()
        master_traces = json.loads(file_content) if file_content else []
else:
    master_traces = []
master_traces.extend(traces)

with open(master_file_path, "w", encoding="utf-8") as f:
    json.dump(master_traces, f, indent=2, ensure_ascii=False)

--- Question 1 ---


--- Question 2 ---


--- Question 3 ---


--- Question 4 ---


--- Question 5 ---


--- Question 6 ---


--- Question 7 ---


--- Question 8 ---
