In [11]:
import os
import re
import sqlite3
from pathlib import Path

import ollama

# Configuration
# Both settings can be overridden through environment variables so the notebook
# also runs on machines where the database lives somewhere else; the defaults
# are the values that were previously hard-coded.
DB_PATH = os.environ.get("SQL_AGENT_DB_PATH", "D:/SQL_agent/Chinook_Sqlite.sqlite")
MODEL_NAME = os.environ.get("SQL_AGENT_MODEL", 'qwen2.5-coder:7b')

In [12]:
"""
def create_dummy_db():
    #Creates a sample database to test the agent.
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    
    # Create a 'users' table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY,
            username TEXT NOT NULL,
            email TEXT,
            signup_date DATE
        )
    ''')
    
    # Create a 'orders' table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS orders (
            order_id INTEGER PRIMARY KEY,
            user_id INTEGER,
            amount REAL,
            order_date DATE,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
    ''')
    
    conn.commit()
    conn.close()
    print(f"âœ… Database '{DB_PATH}' ready.")
"""

'\ndef create_dummy_db():\n    #Creates a sample database to test the agent.\n    conn = sqlite3.connect(DB_PATH)\n    cursor = conn.cursor()\n\n    # Create a \'users\' table\n    cursor.execute(\'\'\'\n        CREATE TABLE IF NOT EXISTS users (\n            id INTEGER PRIMARY KEY,\n            username TEXT NOT NULL,\n            email TEXT,\n            signup_date DATE\n        )\n    \'\'\')\n\n    # Create a \'orders\' table\n    cursor.execute(\'\'\'\n        CREATE TABLE IF NOT EXISTS orders (\n            order_id INTEGER PRIMARY KEY,\n            user_id INTEGER,\n            amount REAL,\n            order_date DATE,\n            FOREIGN KEY(user_id) REFERENCES users(id)\n        )\n    \'\'\')\n\n    conn.commit()\n    conn.close()\n    print(f"âœ… Database \'{DB_PATH}\' ready.")\n'

In [13]:
def get_database_schema(db_path):
    """
    Return the CREATE TABLE statements of every table in a SQLite database.

    The statements are read from ``sqlite_master`` and concatenated, each one
    followed by a semicolon and a newline, so the result can be placed in an
    LLM prompt to tell the model exactly which columns and types exist.

    Args:
        db_path: Path of the SQLite database file to inspect.

    Returns:
        One string holding a ``CREATE TABLE ...;`` line group per table, or an
        empty string when no table definitions are found.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type='table';"
        ).fetchall()
    finally:
        # Close the connection even when the query raises, so it is not leaked.
        conn.close()

    # Each row is a 1-tuple holding the full "CREATE TABLE..." text; rows whose
    # text is NULL or empty are skipped.
    return "".join(f"{ddl};\n" for (ddl,) in rows if ddl)

In [14]:
# Matches the first complete markdown fence in a model reply. An optional SQL
# language tag directly after the opening fence is skipped; group 1 is the
# fenced content without surrounding whitespace.
_SQL_FENCE_RE = re.compile(
    r"```(?:(?:sqlite3|sqlite|sql)\b)?\s*(.*?)\s*```",
    re.DOTALL | re.IGNORECASE,
)


def _strip_code_fences(text):
    """Return the SQL inside a markdown fence, or the cleaned text if unfenced."""
    fenced = _SQL_FENCE_RE.search(text)
    if fenced:
        return fenced.group(1).strip()
    # No complete fence (for example the closing marker is missing): fall back
    # to simply dropping any stray fence markers.
    return text.replace("```sql", "").replace("```", "").strip()


def generate_sql(question, schema):
    """
    Ask the Ollama model to write a SQLite query answering ``question``.

    The schema is embedded in the system prompt and the question is sent as the
    user message. Models sometimes wrap the query in a markdown fence and add
    explanations despite the instructions, so only the content of the first
    fenced block is kept when one is present.

    Args:
        question: Natural-language question (may include retry hints).
        schema: Schema text describing the target database.

    Returns:
        The SQL text from the model reply with markdown fences removed.
    """

    system_prompt = f"""
    You are an expert SQL assistant. 
    1. Your task is to write a valid SQLite query to answer the user's question.
    2. Use ONLY the tables and columns defined in the schema below.
    3. Output ONLY the SQL query. Do not wrap it in markdown or add explanations.
    
    Target Database Schema:
    {schema}
    """

    print(f"ðŸ¤– Asking {MODEL_NAME}...")

    response = ollama.chat(model=MODEL_NAME, messages=[
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': question},
    ])

    # Clean up response (sometimes models add ```sql ... ``` and extra prose)
    return _strip_code_fences(response['message']['content'].strip())


In [15]:
def connect_readonly(db_path: str) -> sqlite3.Connection:
    """
    Open a SQLite database in read-only mode.

    Read-only connection: prevents INSERT/UPDATE/DELETE/DROP even if a query
    slips through.

    Args:
        db_path: Filesystem path of the database (relative paths are resolved
            against the current working directory).

    Returns:
        An open ``sqlite3.Connection`` created with ``mode=ro``.

    Raises:
        sqlite3.OperationalError: If the database file cannot be opened.
    """
    # Build the URI with Path.as_uri() so characters that carry meaning in a
    # URI ('?', '#', '%') are percent-encoded instead of truncating the path or
    # hiding the mode=ro parameter.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    return sqlite3.connect(uri, uri=True)


In [None]:
"""
import re
from typing import Tuple

BLOCKLIST = {
    "insert", "update", "delete", "drop", "alter", "create", "truncate",
    "attach", "detach", "vacuum", "reindex", "replace",
    "pragma", "begin", "commit", "rollback"
}

def validate_sql_readonly(sql: str) -> Tuple[bool, str]:
    if not sql or not sql.strip():
        return False, "Empty SQL."

    s = sql.strip()

    # Disallow multiple statements (common injection / jailbreak vector)
    # Allow at most one trailing semicolon.
    if s.count(";") > 1 or (";" in s and not s.endswith(";")):
        return False, "Multiple statements are not allowed."

    # Disallow SQL comments
    if "--" in s or "/*" in s or "*/" in s:
        return False, "SQL comments are not allowed."

    # Normalize for keyword checks
    lowered = re.sub(r"\s+", " ", s.lower())

    # Must start with SELECT or WITH (CTE) (optionally EXPLAIN)
    if not (lowered.startswith("select") or lowered.startswith("with") or lowered.startswith("explain select") or lowered.startswith("explain with")):
        return False, "Only SELECT queries (optionally WITH/EXPLAIN) are allowed."

    # Block dangerous keywords anywhere
    tokens = set(re.findall(r"[a-z_]+", lowered))
    bad = sorted(tokens.intersection(BLOCKLIST))
    if bad:
        return False, f"Disallowed keyword(s) found: {', '.join(bad)}"

    return True, "OK"
"""


In [None]:
"""def enforce_limit(sql: str, default_limit: int = 200) -> str:
    s = sql.strip().rstrip(";")
    # naive check: if LIMIT already present, leave it
    if re.search(r"\blimit\b", s, re.IGNORECASE):
        return s + ";"
    return f"{s}\nLIMIT {default_limit};"
"""

In [20]:
def execute_sql(db_path, sql):
    """
    Run a SQL statement over a read-only connection and fetch all rows.

    Args:
        db_path: Path of the SQLite database file.
        sql: SQL text to execute.

    Returns:
        A ``(cols, rows)`` tuple: ``cols`` is the list of result column names
        (empty when the statement yields no result description) and ``rows``
        is the list of fetched row tuples.

    Raises:
        sqlite3.Error: Propagated from the failing statement; the connection is
            closed before the exception reaches the caller.
    """
    conn = connect_readonly(db_path)
    try:
        cur = conn.execute(sql)
        rows = cur.fetchall()
        cols = [d[0] for d in cur.description] if cur.description else []
    finally:
        # Failing SQL is an expected case here, so never leak the connection.
        conn.close()
    return cols, rows


In [17]:
# Legacy direct-flow main removed.
# LangGraph-driven entrypoint is defined in the next cell.


In [21]:
# pip install -U langgraph

from typing import TypedDict, Optional, Any
from langgraph.graph import StateGraph, START, END

class SQLState(TypedDict):
    """State dictionary handed from node to node in the SQL agent graph."""
    # Natural-language question supplied to run_sql_agent.
    question: str
    # Schema text stored by node_load_schema (output of get_database_schema).
    schema: str
    # Most recent SQL text stored by node_generate_sql.
    sql: str
    # {"columns": [...], "rows": [...]} after a successful execution; None
    # initially and after a failed execution.
    result: Any
    # Message of the last execution error, or None when there is none.
    error: Optional[str]
    # Retry counter; starts at 0 and is incremented by node_inc_tries.
    tries: int

# Retry budget: route_after_execute stops retrying once state["tries"] reaches
# this value, so a question gets at most 1 + MAX_TRIES generation attempts.
MAX_TRIES = 2

def node_load_schema(state: SQLState) -> SQLState:
    """Graph node: store the schema of the database at DB_PATH in the state."""
    schema_text = get_database_schema(DB_PATH)
    state["schema"] = schema_text
    return state

def node_generate_sql(state: SQLState) -> SQLState:
    """
    Graph node: generate SQL for the question and store it in ``state["sql"]``.

    On a retry (``state["error"]`` is set) the prompt is extended with the SQL
    that failed and the error it produced. Each model call is independent, so
    without the failed SQL the model would be asked to fix a query it cannot
    see. The error is cleared once new SQL has been generated.

    Args:
        state: Current graph state; reads ``question``, ``schema``, ``error``
            and ``sql``.

    Returns:
        The same state object with ``sql`` replaced and ``error`` set to None.
    """
    prompt = state["question"]
    last_error = state.get("error")
    if last_error:
        failed_sql = state.get("sql")
        if failed_sql:
            prompt = f"{prompt}\n\nThe previous SQL was:\n{failed_sql}"
        prompt = f"{prompt}\n\nThe previous SQL failed with this error:\n{last_error}\nFix the SQL."
    state["sql"] = generate_sql(prompt, state["schema"])
    state["error"] = None
    return state

def node_execute_sql(state: SQLState) -> SQLState:
    """
    Graph node: run ``state["sql"]`` against DB_PATH and record the outcome.

    On success ``result`` holds the columns and rows and ``error`` is None; on
    any exception ``result`` is None and ``error`` holds the exception text.
    """
    try:
        columns, records = execute_sql(DB_PATH, state["sql"])
    except Exception as exc:
        # Failures are recorded in the state rather than raised.
        outcome, failure = None, str(exc)
    else:
        outcome, failure = {"columns": columns, "rows": records}, None
    state["result"] = outcome
    state["error"] = failure
    return state

def route_after_execute(state: SQLState) -> str:
    """
    Choose the edge to follow after the execution node.

    Returns "done" when the last execution left no error or when the retry
    counter has reached MAX_TRIES; otherwise returns "retry".
    """
    # The counter is only consulted when there is an error to recover from.
    if state["error"] is None or state["tries"] >= MAX_TRIES:
        return "done"
    return "retry"

def node_inc_tries(state: SQLState) -> SQLState:
    """Graph node: add one to the retry counter and return the state."""
    state["tries"] = state["tries"] + 1
    return state

def build_sql_graph():
    """
    Assemble and compile the SQL agent graph.

    Flow: START -> load_schema -> gen_sql -> exec_sql. After exec_sql,
    route_after_execute either ends the run ("done") or goes through
    inc_tries back to gen_sql ("retry").
    """
    builder = StateGraph(SQLState)

    node_table = (
        ("load_schema", node_load_schema),
        ("gen_sql", node_generate_sql),
        ("exec_sql", node_execute_sql),
        ("inc_tries", node_inc_tries),
    )
    for node_name, handler in node_table:
        builder.add_node(node_name, handler)

    # Linear part of the pipeline.
    linear_edges = (
        (START, "load_schema"),
        ("load_schema", "gen_sql"),
        ("gen_sql", "exec_sql"),
    )
    for source, target in linear_edges:
        builder.add_edge(source, target)

    # Branch on the execution outcome, then close the retry loop.
    routes = {"retry": "inc_tries", "done": END}
    builder.add_conditional_edges("exec_sql", route_after_execute, routes)
    builder.add_edge("inc_tries", "gen_sql")

    return builder.compile()

# Compiled graph, built once when this cell runs; run_sql_agent invokes it.
app = build_sql_graph()

def run_sql_agent(question: str) -> SQLState:
    """
    Run the compiled graph for one question and return what it produces.

    The run starts from an empty schema and SQL, no result, no error and a
    retry counter of zero.
    """
    seed_state = dict(
        question=question,
        schema="",
        sql="",
        result=None,
        error=None,
        tries=0,
    )
    return app.invoke(seed_state)

def main(question: Optional[str] = None) -> None:
    """
    Run the agent for one question and print the SQL and its outcome.

    Args:
        question: Question to ask; a built-in sample question is used when it
            is None or empty.

    Prints the question, the generated SQL, and then either the error left
    after the retries or the result columns followed by one line per row.
    """
    # The create_dummy_db() call is disabled.
    sample_question = "Show me the top 5 users who spent the most money, including their email."
    asked = question if question else sample_question
    print(f"??? Question: {asked}\n")

    final_state = run_sql_agent(asked)

    divider = "-" * 30
    print(divider)
    print("?? Generated SQL:")
    print(final_state["sql"])
    print(divider)

    failure = final_state["error"]
    if failure:
        print(f"? Error after retries: {failure}")
        return

    # Fall back to an empty table when the state carries no result.
    table = final_state["result"]
    if not table:
        table = {"columns": [], "rows": []}
    print(table["columns"])
    for record in table["rows"]:
        print(record)

# Run the sample question when this code executes as the main module.
if __name__ == "__main__":
    main()


??? Question: Show me the top 5 users who spent the most money, including their email.

ðŸ¤– Asking qwen2.5-coder:7b...
------------------------------
?? Generated SQL:
SELECT u.username, u.email, SUM(o.amount) as total_spent 
FROM users u 
JOIN orders o ON u.id = o.user_id 
GROUP BY u.id 
ORDER BY total_spent DESC 
LIMIT 5
------------------------------
['username', 'email', 'total_spent']
