In [1]:
!pip install smolagents python-dotenv sqlalchemy --upgrade -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:

# Hugging Face API token used by the model client below.
# Leave this empty in the saved notebook; never commit a real token.
my_token = ""

# Alternative token source kept for reference (`userdata` is not defined in this notebook):
#my_token = userdata.get('HF_TOKEN')

# Persist the token to .env so that load_dotenv() can expose it as HF_TOKEN.
# Skip the write when no token was supplied, so an existing .env file is not
# overwritten with an empty value, and end the entry with a newline.
if my_token:
    with open('.env', 'w') as f:
        f.write(f"HF_TOKEN={my_token}\n")


In [4]:
from dotenv import load_dotenv
import os

# Load the variables of a .env file into the process environment;
# the cell output shows the boolean returned by load_dotenv().
load_dotenv()

True

In [5]:
from sqlalchemy import create_engine, inspect, text

# Single source of truth for the database file; the SQLAlchemy URL is derived from it.
db_name = "shakespeare.sqlite"
db_path = f"sqlite:///{db_name}"

# Only warn when the file is missing: the engine is still created below.
db_file_present = os.path.exists(db_name)
if not db_file_present:
    print("WARNING: db not found.")

engine = create_engine(db_path)



In [5]:
inspector = inspect(engine)
table_names = inspector.get_table_names()

# Collect the schema description line by line and join once at the end:
# one "Table:" header per table followed by one "  - name (type)" line per column.
schema_lines = ["Database Schema:"]
for table in table_names:
    schema_lines.append(f"Table: {table}")
    columns = inspector.get_columns(table)
    schema_lines.extend(f"  - {col['name']} ({col['type']})" for col in columns)

# Every line, including the last one, is newline-terminated.
schema = "\n".join(schema_lines) + "\n"

print(schema)

Database Schema:
Table: chapters
  - id (INTEGER)
  - Act (INTEGER)
  - Scene (INTEGER)
  - Description (TEXT)
  - work_id (INTEGER)
Table: characters
  - id (INTEGER)
  - CharName (TEXT)
  - Abbrev (TEXT)
  - Description (TEXT)
Table: paragraphs
  - id (INTEGER)
  - ParagraphNum (INTEGER)
  - PlainText (TEXT)
  - character_id (INTEGER)
  - chapter_id (INTEGER)
Table: works
  - id (INTEGER)
  - Title (TEXT)
  - LongTitle (TEXT)
  - Date (INTEGER)
  - GenreType (TEXT)



In [6]:
from smolagents import tool
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.

    Args:
        query: The query to perform.
    """
    # NOTE(review): the docstring above is presumably consumed by the @tool decorator
    # as the tool description shown to the model, so its wording is left untouched.
    # Each result row is rendered with str() and prefixed by a newline; a query that
    # yields no rows therefore returns the empty string.
    with engine.connect() as connection:
        result_rows = connection.execute(text(query))
        return "".join("\n" + str(row) for row in result_rows)

In [7]:
from smolagents import CodeAgent, InferenceClientModel, EMPTY_PROMPT_TEMPLATES

# Define the optimized system prompt combining PDF Examples + Smolagents Technical Logic.
#
# Notes for maintainers:
# - The text is a Jinja-style template: the {{ ... }} and {% ... %} placeholders
#   (tools, code block tags, authorized imports, custom instructions) are left
#   unrendered here.
# - Inside this triple-quoted literal, the escaped quote sequences in the worked
#   examples (\"\"\" and "\"\") both render as a plain triple double quote.
# - Every worked example must only call variables it has defined itself: the second
#   example assigns `test_query`, so each of its sql_engine calls passes `test_query`.
system_prompt = """
You are an expert Data Scientist specialized in Text-to-SQL tasks. Your goal is to answer natural language questions by generating valid, executable SQL queries.

You will be given a task to solve as best you can.
To do so, you have been given access to the tool 'sql_engine': this tool is basically a Python function which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of Thought, Code, and Observation sequences.

PROTOCOL:
1. Reasoning Trace: Explicitly state your plan in the 'Thought' section.
2. Schema Understanding: Use the provided database schema to understand:
   - table names
   - column names
   - primary/foreign key relationships
   - bridge tables
3. Test Your Logic: Do not generate the final answer yet.
   - You must run a "Test Query" to verify your hypothesis.
   - Check if your filters exist (e.g., `WHERE Name LIKE '%Apple%'` vs `WHERE Name = 'Apple'`).
   - Check if your JOINs return rows.
   - If the test returns 0 rows or an error, you must revise your query and test again.
4. After every exploratory or validation query executed through sql_engine, you MUST:
   - Read the result.
   - Produce a new explicit `Thought:` section interpreting the output.
   - Decide the next step based on that interpretation before running another code block.

CRITICAL RULES (YOU MUST FOLLOW):
1. Python Syntax Only: The code block contains PYTHON code.
2. SQL as Strings: Always pass SQL as a string to `sql_engine`: `print(sql_engine(query="SELECT ..."))`
3. HANDLING JOINS (NO SHORTCUTS):
    - Check Foreign Keys: Do NOT join on `id` unless you are sure it is a foreign key.
    - Bridge Tables: If Table A and Table C have no direct link, look for a middle table B:
      Example: Works -> Chapters -> Paragraphs -> Characters. You cannot jump from Works to Characters.
4. NO PYTHON DATA MANIPULATION: Do not fetch data into Python to filter it. Write a single SQL query using JOINs or Subqueries.
5. FINAL ANSWER FORMAT: NO BLIND SUBMISSIONS:
   - You must see the output of a successful Test Query in the logs before you are allowed to submit.
   - Your final goal is the SQL QUERY STRING, not the data result.
   - DO NOT EXECUTE the final query inside `final_answer`:
    * WRONG: `final_answer(sql_engine(query="SELECT..."))` <-- returns data (BAD)
    * CORRECT: `final_answer("SELECT...")` <-- returns SQL (GOOD)

WORKED EXAMPLE TRACE:

Task: "Return the names of companies in the ‘AI’ sector that were founded before 2015 and have at least two offices in France."

Thought: I need to find companies meeting three criteria: sector 'AI', founded before 2015, and having > 1 office in France. First, I need to inspect the tables to understand the schema. I will use Python to call the sql_engine.
<code>
# Inspect tables using Python strings for the queries
print(sql_engine(query="SELECT * FROM Company LIMIT 3"))
print(sql_engine(query="SELECT * FROM Office LIMIT 3"))
</code>

[ANS]
cid | name | sector | founded_year
1 | EURECOM | Education | 1990
2 | Acme AI | AI | 2012
3 | RoboCorp | AI | 2016

oid | cid | country | city
10 | 2 | France | Paris
11 | 2 | France | Lyon
12 | 2 | Germany | Berlin
[/ANS]

Thought: The tables are 'Company' and 'Office', joined by 'cid'.
- 'founded_year' is an integer.
- 'sector' is a string.
Now I will construct a validation query. I must assign the SQL to a Python variable as a string.

<code>
# Validating the logic with a count. Notice the query is a string.
validation_query = \"\"\"
SELECT c.cid, c.name, COUNT(*) AS fr_offices
FROM Company c
JOIN Office o ON o.cid=c.cid
WHERE c.sector='AI' AND c.founded_year<2015 AND o.country='France'
GROUP BY c.cid, c.name
HAVING COUNT(*)>=2
LIMIT 3
\"\"\"
print(sql_engine(query=validation_query))
</code>

[ANS]
cid | name | fr_offices
2 | Acme AI | 2
[/ANS]

Thought: The logic is correct. Acme AI is identified. I can now generate the final SQL query string.
<code>
final_sql = "\"\"
SELECT c.name
FROM Company AS c
JOIN Office AS o ON o.cid = c.cid
WHERE c.sector = 'AI'
AND c.founded_year < 2015
AND o.country = 'France'
GROUP BY c.cid, c.name
HAVING COUNT(*) >= 2
"\"\"
final_answer(final_sql)
</code>

Other example:
2. Task: How many scenes are there in Act 1 in Twelfth Night?

Thought: I analyze first the chapters table, to get a better understanding of its format.

<code>
test_query = "\"\"
SELECT *
from chapters
LIMIT 3
"\"\"
print(sql_engine(test_query))
</code>

[ANS]
id | Act | Scene | Description | work_id
18704 | 1 | 1 | DUKE ORSINO’s palace. | 1
18705 | 1 | 2 | The sea-coast. | 1
18706 | 1 | 3 | OLIVIA’S house. | 1
[/ANS]

Thought: I can use the work_id foreign key to perform a join between the works and chapters table. I need to consider only the work 'Twelfth Night', so a filter is needed.

<code>
test_query = "\"\"
SELECT w.Title, c.Act, c.Scene
FROM works w JOIN chapters c ON c.work_id = w.id
WHERE w.Title = 'Twelfth Night'
LIMIT 3
"\"\"
print(sql_engine(test_query))
</code>

[ANS]
Title | Act | Scene
Twelfth Night | 1 | 1
Twelfth Night | 1 | 2
Twelfth Night | 1 | 3
[/ANS]

Thought: As expected, I have now the information about all the acts and the scenes of the work 'Twelfth Night'. Since I need to count the number of scenes in Act 1, a further filter is needed.

<code>
test_query = "\"\"
SELECT w.Title, c.Act, c.Scene
FROM works w JOIN chapters c ON c.work_id = w.id
WHERE w.Title = 'Twelfth Night' AND c.Act = 1
LIMIT 3
"\"\"
print(sql_engine(test_query))
</code>

[ANS]
Title | Act | Scene
Twelfth Night | 1 | 1
Twelfth Night | 1 | 2
Twelfth Night | 1 | 3
[/ANS]

Thought: Now that I have only the instances related to Act 1, I can proceed with the final query in which the number of scenes of Act 1 in the work 'Twelfth Night' are counted. The column is renamed for better understanding.

<code>
test_query = "\"\"
SELECT COUNT(*) as n_Scenes
FROM works w JOIN chapters c ON c.work_id = w.id
WHERE w.Title = 'Twelfth Night' AND c.Act = 1
"\"\"
print(sql_engine(test_query))
</code>

[ANS]
n_Scenes
5
[/ANS]

Thought: The number of scenes in Act 1 in Twelfth Night is correctly retrieved, I can proceed with returning the final query.

<code>
final_sql = "\"\"
SELECT COUNT(*) as n_Scenes
FROM works w JOIN chapters c ON c.work_id = w.id
WHERE w.Title = 'Twelfth Night' AND c.Act = 1
"\"\"
final_answer(final_sql)
</code>

---

You only have access to these tools, behaving like regular python functions:
{{code_block_opening_tag}}
{%- for tool in tools.values() %}
{{ tool.to_code_prompt() }}
{% endfor %}
{{code_block_closing_tag}}

Here are the rules you should always follow to solve your task:
1. Always provide a 'Thought:' sequence, and a '{{code_block_opening_tag}}' sequence ending with '{{code_block_closing_tag}}', else you will fail.
2. Use only variables that you have defined!
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wikipedia_search({'query': \"What is the place where James Bond lives?\"})', but use the arguments directly as in 'answer = wikipedia_search(query=\"What is the place where James Bond lives?\")'.
4. For tools WITHOUT JSON output schema: Take care to not chain too many sequential tool calls in the same code block, as their output format is unpredictable. For instance, a call to wikipedia_search without a JSON output schema has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. For tools WITH JSON output schema: You can confidently chain multiple tool calls and directly access structured output fields in the same code block! When a tool has a JSON output schema, you know exactly what fields and data types to expect, allowing you to write robust code that directly accesses the structured response (e.g., result['field_name']) without needing intermediate print() statements.
6. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
7. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
8. Never create any notional variables in our code, as having these in your logs will derail you from the true variables.
9. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
10. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
11. Don't give up! You're in charge of solving the task, not providing directions to solve it.

{%- if custom_instructions %}
{{custom_instructions}}
{%- endif %}

Now Begin!
"""


# Shallow copy of the blank template set with only the system prompt replaced;
# the other template entries are reused as-is.
my_templates_dict = {**EMPTY_PROMPT_TEMPLATES, "system_prompt": system_prompt}

In [8]:
# Build the agent: `sql_engine` is registered as its tool, the Qwen/Qwen3-8B model is
# reached through InferenceClientModel using `my_token`, and the custom templates
# (including the system prompt defined above) replace the library defaults.
# NOTE(review): the token is passed directly from `my_token`, not read back from the
# HF_TOKEN environment variable loaded from .env — confirm this is intended.
agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id="Qwen/Qwen3-8B", token=my_token),
    prompt_templates=my_templates_dict,
    verbosity_level=2
)

In [9]:
# Natural-language question to translate into SQL, plus an "evidence" hint that maps
# the wording of the question onto schema columns (CharName, max(Date)).
question = "Give the title and the characters name of the most recent work of Shakespeare."
evidence = "characters name refers to CharName; most recent work refers to max(Date)"


In [10]:
# Task text handed to the agent: the schema dump followed by the evidence hint and the
# question (in that order) under the "Question:" header.
USER_PROMPT = f"""
DB Schema:
{schema}
Question:
{evidence}. {question}
"""

In [16]:
# Run the agent on the task; the cell output is the value returned by agent.run (here
# the final SQL string). The steps are read back later through agent.memory.steps.
agent.run(USER_PROMPT)


'\nSELECT DISTINCT w.Title, ch.CharName\nFROM works w\nJOIN chapters c ON w.id = c.work_id\nJOIN paragraphs p ON c.id = p.chapter_id\nJOIN characters ch ON p.character_id = ch.id\nWHERE w.id = 15\n'

In [17]:
import re

def _extract_sql_from_call(content):
    """Return the flattened SQL passed as `query = ...` in a tool-call string, or None."""
    # Triple-quoted literals are tried first so a multi-line query is captured whole
    # before falling back to a plain double-quoted string.
    patterns = (
        r'query\s*=\s*"""(.*?)"""',
        r"query\s*=\s*'''(.*?)'''",
        r'query\s*=\s*"(.*?)"',
    )
    for pattern in patterns:
        sql_match = re.search(pattern, content, re.DOTALL)
        if sql_match:
            raw_sql = sql_match.group(1).strip()
            # Flatten the SQL: real newlines and literal backslash-n sequences become
            # spaces, and runs of three spaces are collapsed to one.
            return raw_sql.replace('\n', ' ').replace('\\n', ' ').replace("   ", " ")
    return None


def _clean_observation(observations):
    """Strip the executor's boilerplate from an observation and put it on one line."""
    obs = str(observations).strip()
    obs = obs.replace("Execution logs:", "").replace("Last output from code snippet:", "")
    obs = re.sub(r'\bNone\b', '', obs)
    # Unwrap single-value text rows such as ('Hamlet',) or ('Hamlet') to Hamlet.
    obs = re.sub(r"^\('(.+)',\)$", r"\1", obs.strip(), flags=re.MULTILINE)
    obs = re.sub(r"^\('(.+)'\)$", r"\1", obs.strip(), flags=re.MULTILINE)
    return obs.strip().replace('\n', ' ')


def get_clean_log_string(agent):
    """
    Condense the agent's memory into a one-line trace and extract the final SQL query.

    For every step after the first one (the task step) the trace receives, in order:
    the model's thought with any <code>...</code> section removed, one "[CALL] <sql>"
    entry per tool call in which a `query = "..."` literal is found, one
    "[ANS] ... [/ANS]" entry for the observation (truncated to 200 characters, or
    "(no rows)" when nothing but brackets/commas/whitespace is left), and one
    "[ERROR] ..." entry when the step carries an error.

    Args:
        agent: Object exposing `memory.steps`, an iterable of step objects. The
            attributes `model_output`/`thought`, `tool_calls`, `observations`,
            `is_final_answer`, `action_output` and `error` are all read defensively,
            so steps lacking some of them are tolerated.

    Returns:
        A tuple `(log, sql_query)`: `log` is the trace entries joined with single
        spaces; `sql_query` is the stripped `action_output` of the step flagged as
        the final answer, or None when no step is flagged (or it has no output).
    """
    log_parts = []
    sql_query = None

    for i, step in enumerate(agent.memory.steps):
        if i == 0:
            continue  # Skip the task step

        # Thought: the model output without its code section.
        thought = getattr(step, 'model_output', getattr(step, 'thought', None))
        if thought:
            clean_thought = re.sub(r'<code>.*?</code>', '', thought, flags=re.DOTALL)
            clean_thought = clean_thought.replace('\n', ' ').strip()
            if clean_thought:
                log_parts.append(f"{clean_thought}")

        # Calls: the SQL passed to the tool, when one can be found.
        for tool_call in getattr(step, 'tool_calls', None) or []:
            content = getattr(tool_call, 'content', str(tool_call))
            flat_sql = _extract_sql_from_call(content)
            if flat_sql is not None:
                log_parts.append(f"[CALL] {flat_sql}")

        # Final answer: keep the query out of the trace and return it separately.
        # The flag is read with getattr so steps without it do not raise.
        if getattr(step, 'is_final_answer', False):
            final_output = getattr(step, 'action_output', None)
            if final_output is not None:
                # str() guards against a non-string final answer.
                sql_query = str(final_output).strip()
            continue

        # Observation: exactly one [ANS] entry per step.
        observations = getattr(step, 'observations', None)
        if observations:
            obs_clean = _clean_observation(observations)
            if not obs_clean or re.fullmatch(r'[\[\]\(\)\s,]*', obs_clean):
                # Nothing meaningful left (e.g. "[]"): report it once as "no rows"
                # instead of also logging the raw brackets.
                log_parts.append("[ANS] (no rows) [/ANS]")
            else:
                # Truncate if too long
                if len(obs_clean) > 200:
                    obs_clean = obs_clean[:200] + "... [truncated]"
                log_parts.append(f"[ANS] {obs_clean} [/ANS]")

        # Errors
        error = getattr(step, 'error', None)
        if error:
            err_clean = str(error).replace('\n', ' ')
            log_parts.append(f"[ERROR] {err_clean}")

    # Join all entries with a single space so the whole trace stays on one line.
    full_log_string = " ".join(log_parts)

    return full_log_string, sql_query


In [18]:
log_string, pred_query = get_clean_log_string(agent)

if pred_query is None:
    # No final answer was recorded, so there is no query to normalise or evaluate.
    raise RuntimeError("The agent did not return a final SQL query.")

# Drop backslashes and turn double quotes into single quotes in the trace.
log_string = log_string.replace("\\", "").replace("\"", "\'")
# Put the query on one line. Literal backslash-n sequences become a space (not an
# empty string) so that tokens from adjacent lines are not glued together.
pred_query = pred_query.replace("\\n", " ").replace("\n", " ").replace("\\'", "'").strip()
print("LOG:\n" + log_string)
print(f"\nFinal Query: {pred_query}")

LOG:
Thought: I need to find the most recent work by Shakespeare and retrieve the characters' names from it. First, I'll check the schema to confirm the relationships between tables. The works table contains the Date and LongTitle, which are essential for determining the most recent work. Characters are linked via paragraphs, chapters, and works. [CALL] SELECT * FROM works LIMIT 3 [ANS] (1, 'Twelfth Night', 'Twelfth Night, Or What You Will', 1599, 'Comedy') (2, 'All's Well That Ends Well', 'All's Well That Ends Well', 1602, 'Comedy') (3, 'Antony and Cleopatra', 'Antony and Cleopatra'... [truncated] [/ANS] Thought: The sample data shows that the works table has a 'Date' column which is an integer. To find Shakespeare's most recent work, I need to select the work with the maximum Date. Then, I must retrieve the characters associated with that work through the chapters and paragraphs tables. I'll start by finding the most recent work's ID. [CALL] SELECT id FROM works ORDER BY Date DESC LI

In [19]:
def compute_execution_accuracy(gt_results, predict_results):
    """
    Compute the fraction of queries whose predicted result set equals the ground truth.

    Results are compared as sets, so row order and duplicate rows are ignored.

    Args:
        gt_results: List of dicts, each with a 'results' key holding the hashable
            values returned by the reference query.
        predict_results: List of dicts with the same layout, aligned by position
            with `gt_results`.

    Returns:
        A float in [0, 1]. Returns 0.0 when `gt_results` is empty. A ground-truth
        entry without a prediction at the same position counts as incorrect.
    """
    num_queries = len(gt_results)
    if num_queries == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    num_predictions = len(predict_results)
    num_correct = 0
    for i, result in enumerate(gt_results):
        if i >= num_predictions:
            # Missing prediction: counted as a mismatch instead of raising IndexError.
            continue
        if set(result['results']) == set(predict_results[i]['results']):
            num_correct += 1

    return num_correct / num_queries

In [20]:
import sqlite3
def run_query(db_path, query, first_column_only=True):
    """
    Execute `query` against the SQLite database at `db_path` and return its rows.

    Args:
        db_path: Path handed to sqlite3.connect.
        query: SQL statement to execute.
        first_column_only: When True (the default, matching the historical
            behaviour) only the first column of every row is returned. Pass False
            to get each full row as a tuple, which is needed to compare queries
            that select several columns.

    Returns:
        A list with one entry per row: the first column's value, or the whole row
        as a tuple when `first_column_only` is False.

    Raises:
        sqlite3.Error: If the statement cannot be executed; the connection is
            closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(query).fetchall()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()

    if first_column_only:
        return [row[0] for row in rows]
    return [tuple(row) for row in rows]

In [21]:
# Reference (ground-truth) SQL for the question: title and character names of the
# work(s) whose date equals the maximum date in `works`.
gt_query = """SELECT DISTINCT w.Title, ch.CharName FROM works w JOIN chapters c ON w.id = c.work_id JOIN paragraphs p ON p.chapter_id = c.id JOIN characters ch ON ch.id = p.character_id WHERE w.date = ( SELECT max(date) FROM works w2 )"""

In [22]:
# Execute the reference and the predicted query and wrap each result in the
# {"results": [...]} layout expected by compute_execution_accuracy.
# NOTE(review): run_query returns only the first column of each row, so for these
# two-column queries only `Title` is compared and `CharName` is never checked.
rows_gt = run_query(db_name, gt_query)
gt_res = [{"results": rows_gt}]

rows_pred = run_query(db_name, pred_query)
pred_res = [{"results": rows_pred}]

In [23]:
# Execution accuracy over the single (ground truth, prediction) pair built above.
acc = compute_execution_accuracy(gt_res, pred_res)
print(f"Accuracy of the generated SQL query: {acc}")

Accuracy of the generated SQL query: 1.0


In [24]:
# Bundle the prompt, the cleaned trace, both queries and the score into one record.
# NOTE(review): int(acc) truncates, so any accuracy below 1.0 is stored as 0; that is
# only lossless while a single query is evaluated.
complete_trace = {
    "input": USER_PROMPT,
    "output": log_string,
    "pred_query": pred_query,
    "target_query": gt_query,
    "execution_accuracy": int(acc)
}
complete_trace

{'input': '\nDB Schema:\nDatabase Schema:\nTable: chapters\n  - id (INTEGER)\n  - Act (INTEGER)\n  - Scene (INTEGER)\n  - Description (TEXT)\n  - work_id (INTEGER)\nTable: characters\n  - id (INTEGER)\n  - CharName (TEXT)\n  - Abbrev (TEXT)\n  - Description (TEXT)\nTable: paragraphs\n  - id (INTEGER)\n  - ParagraphNum (INTEGER)\n  - PlainText (TEXT)\n  - character_id (INTEGER)\n  - chapter_id (INTEGER)\nTable: works\n  - id (INTEGER)\n  - Title (TEXT)\n  - LongTitle (TEXT)\n  - Date (INTEGER)\n  - GenreType (TEXT)\n\nQuestion:\ncharacters name refers to CharName; most recent work refers to max(Date). Give the title and the characters name of the most recent work of Shakespeare.\n',
 'output': "Thought: I need to find the most recent work by Shakespeare and retrieve the characters' names from it. First, I'll check the schema to confirm the relationships between tables. The works table contains the Date and LongTitle, which are essential for determining the most recent work. Characters a