In [1]:
from __future__ import annotations

import os
import sqlite3
from pathlib import Path

import ollama

# Configuration.
# Both values can be overridden through environment variables so the notebook
# runs on machines that do not have the original D:/ directory layout; the
# defaults are the values this notebook was written against.
DB_PATH = os.environ.get("SQL_AGENT_DB_PATH", "D:/SQL_agent/Chinook_Sqlite.sqlite")
MODEL_NAME = os.environ.get("SQL_AGENT_MODEL", 'qwen2.5-coder:7b')

In [2]:
"""
def create_dummy_db():
    #Creates a sample database to test the agent.
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    
    # Create a 'users' table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY,
            username TEXT NOT NULL,
            email TEXT,
            signup_date DATE
        )
    ''')
    
    # Create a 'orders' table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS orders (
            order_id INTEGER PRIMARY KEY,
            user_id INTEGER,
            amount REAL,
            order_date DATE,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
    ''')
    
    conn.commit()
    conn.close()
    print(f"✅ Database '{DB_PATH}' ready.")
"""

'\ndef create_dummy_db():\n    #Creates a sample database to test the agent.\n    conn = sqlite3.connect(DB_PATH)\n    cursor = conn.cursor()\n\n    # Create a \'users\' table\n    cursor.execute(\'\'\'\n        CREATE TABLE IF NOT EXISTS users (\n            id INTEGER PRIMARY KEY,\n            username TEXT NOT NULL,\n            email TEXT,\n            signup_date DATE\n        )\n    \'\'\')\n\n    # Create a \'orders\' table\n    cursor.execute(\'\'\'\n        CREATE TABLE IF NOT EXISTS orders (\n            order_id INTEGER PRIMARY KEY,\n            user_id INTEGER,\n            amount REAL,\n            order_date DATE,\n            FOREIGN KEY(user_id) REFERENCES users(id)\n        )\n    \'\'\')\n\n    conn.commit()\n    conn.close()\n    print(f"✅ Database \'{DB_PATH}\' ready.")\n'

In [3]:
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
import re

In [4]:

# The vector store lives on disk, so the schema index is not rebuilt on every run.
CHROMA_DIR = "./chroma_sql_rag"
COLLECTION_NAME = "schema_docs"

# Embedder for schema snippets; MiniLM-L6-v2 is a small, fast general-purpose model.
embedding_fn = SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

# Captures the table name after CREATE TABLE [IF NOT EXISTS], accepting SQLite's
# quoting styles ("name", [name], `name`) as well as bare identifiers.
_TABLE_NAME_RE = re.compile(
    r'create\s+table\s+(?:if\s+not\s+exists\s+)?'
    r'(?:"([^"]+)"|\[([^\]]+)\]|`([^`]+)`|([^\s(]+))',
    re.IGNORECASE,
)


def schema_to_docs(schema_text: str):
    """
    Split a schema dump into one document per CREATE TABLE statement.

    Args:
        schema_text: SQL statements separated by a semicolon followed by a
            newline (the format ``get_database_schema`` builds). ``None`` or
            an empty string yields no documents.

    Returns:
        A list of dicts, one per CREATE TABLE statement:
        ``{"id": "table::<name>", "text": <statement ending in one ';'>,
        "meta": {"table": <name>}}``. Quotes/brackets around the table name
        are removed. Statements that are not CREATE TABLE are skipped.
    """
    docs = []
    for raw in re.split(r";\s*\n", (schema_text or "").strip()):
        # After strip() the final statement keeps its ';' (no newline follows
        # it), so drop any terminator before re-adding exactly one below.
        stmt = raw.strip().rstrip(";").rstrip()
        if not stmt or not re.search(r"create\s+table", stmt, re.IGNORECASE):
            continue
        match = _TABLE_NAME_RE.search(stmt)
        if match:
            # Exactly one alternative of the name regex participates in a match.
            table = next(g for g in match.groups() if g)
            doc_id = f"table::{table}"
        else:
            table = "unknown"
            # Keep ids unique so the vector store does not reject duplicates.
            doc_id = f"table::unknown::{len(docs)}"
        docs.append({
            "id": doc_id,
            "text": stmt + ";",
            "meta": {"table": table},
        })
    return docs


  from .autonotebook import tqdm as notebook_tqdm


In [5]:

def build_or_load_chroma(schema_text: str, force_rebuild: bool = False):
    """
    Open the persistent Chroma collection of per-table schema documents,
    populating it from ``schema_text`` when it is empty.

    Args:
        schema_text: CREATE TABLE statements (e.g. from ``get_database_schema``).
        force_rebuild: delete the existing collection first so it is re-indexed
            from ``schema_text``.

    Returns:
        The Chroma collection.

    Raises:
        ValueError: the collection needs populating but ``schema_text``
            contains no CREATE TABLE statements.

    Note: a collection that already has documents is reused as-is, so schema
    changes are only picked up with ``force_rebuild=True``.
    """
    client = chromadb.PersistentClient(path=CHROMA_DIR)

    if force_rebuild:
        try:
            client.delete_collection(COLLECTION_NAME)
        except Exception:
            # Best effort: the collection may simply not exist yet.
            pass

    col = client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_fn
    )

    existing = col.count()
    if existing:
        print(f"✅ Chroma collection already has {existing} docs.")
        return col

    docs = schema_to_docs(schema_text)
    if not docs:
        # Chroma validates that add() receives a non-empty id list; fail with a
        # message pointing at the real cause (e.g. a wrong DB_PATH / empty DB).
        raise ValueError(
            "No CREATE TABLE statements found in schema_text; nothing to index."
        )
    col.add(
        ids=[d["id"] for d in docs],
        documents=[d["text"] for d in docs],
        metadatas=[d["meta"] for d in docs],
    )
    print(f"✅ Chroma populated with {len(docs)} schema docs.")
    return col

def retrieve_schema_context(col, question: str, k: int = 4) -> str:
    """
    Return the schema documents most similar to ``question``.

    Args:
        col: a Chroma collection (anything exposing ``count()`` and
            ``query(query_texts=..., n_results=...)``).
        question: natural-language question used as the query text.
        k: maximum number of documents to retrieve.

    Returns:
        The retrieved documents joined by blank lines; ``""`` when the
        collection is empty or ``k`` is not positive.
    """
    # Never request more neighbours than the collection holds; depending on the
    # Chroma version an oversized n_results either raises or logs a warning.
    n_results = min(k, col.count())
    if n_results <= 0:
        return ""
    res = col.query(query_texts=[question], n_results=n_results)
    docs = res["documents"][0] if res and res.get("documents") else []
    return "\n\n".join(docs)


In [6]:
def get_database_schema(db_path):
    """
    Return the CREATE TABLE statements stored in the SQLite database at ``db_path``.

    This text tells the LLM exactly which tables, columns and types exist.
    The database is opened read-only, so a mistyped path raises
    ``sqlite3.OperationalError`` instead of silently creating a new, empty
    database file (which is what a plain ``sqlite3.connect`` would do).

    Args:
        db_path: filesystem path of an existing SQLite database.

    Returns:
        Every table's stored CREATE TABLE SQL, each followed by a semicolon
        and a newline, concatenated; ``""`` if the database has no tables.
    """
    # Build a proper file: URI so spaces, '?', '#' or '%' in the path are escaped.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True)
    try:
        rows = conn.execute("SELECT sql FROM sqlite_master WHERE type='table';").fetchall()
    finally:
        conn.close()
    # Skip rows whose stored SQL is NULL/empty.
    return "".join(f"{sql};\n" for (sql,) in rows if sql)

In [7]:

# Opening Markdown fence with an optional language tag on its own line
# (```sql, ```SQL, ```sqlite, ``` ...).
_FENCE_OPEN = r"```(?:[ \t]*[\w+-]*[ \t\r]*\n)?"
_SQL_FENCE_RE = re.compile(_FENCE_OPEN + r"(.*?)```", re.DOTALL)
_FENCE_MARK_RE = re.compile(_FENCE_OPEN)


def strip_sql_fences(text: str) -> str:
    """
    Extract bare SQL from an LLM reply that may wrap it in a Markdown fence.

    If ``text`` contains a complete fenced block, its body is returned, and
    any prose before or after the fence is dropped. Otherwise stray fence
    markers (with their language tag) are removed. The result is stripped.
    """
    match = _SQL_FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    return _FENCE_MARK_RE.sub("", text).strip()


def generate_sql(question, schema, rag_context: str = ""):
    """
    Ask the Ollama model for a single read-only SQLite query.

    Args:
        question: natural-language question (possibly extended with repair
            instructions from a previous failure).
        schema: full CREATE TABLE text of the database.
        rag_context: schema snippets retrieved for this question.

    Returns:
        The SQL text, with any Markdown code fence removed.
    """
    # NOTE(review): the full schema is sent alongside the retrieved context, so
    # retrieval does not currently shrink the prompt; it only highlights tables.
    system_prompt = f"""
You are an expert SQLite SQL assistant.

Hard constraints:
- Produce a SINGLE read-only query: SELECT (optionally WITH/EXPLAIN).
- DO NOT use INSERT/UPDATE/DELETE/DROP/ALTER/CREATE/TRUNCATE/PRAGMA/ATTACH/DETACH/VACUUM.
- Output ONLY the SQL query (no markdown, no explanation).

Relevant schema context (retrieved):
{rag_context}

Full schema (fallback reference):
{schema}
"""

    response = ollama.chat(model=MODEL_NAME, messages=[
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': question},
    ])

    # Models often ignore "no markdown"; tolerate any fence/tag variant.
    return strip_sql_fences(response['message']['content'])


In [8]:
def connect_readonly(db_path: str) -> sqlite3.Connection:
    """
    Open ``db_path`` as a read-only SQLite connection.

    Writes (INSERT/UPDATE/DELETE/DROP, ...) fail even if such a query slips
    past the keyword filter, and a missing file raises
    ``sqlite3.OperationalError`` instead of being created.

    The path is converted with ``Path.as_uri()`` so characters that are
    special in SQLite URIs (space, ``?``, ``#``, ``%``) are percent-escaped.
    A raw ``f"file:{db_path}?mode=ro"`` would, for a path containing ``#``,
    cut the filename there and lose ``mode=ro``.
    """
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


In [9]:
def execute_sql(db_path, sql):
    """
    Run ``sql`` on a read-only connection to ``db_path``.

    Returns:
        ``(columns, rows)``: the column names (empty when the statement yields
        no result set) and all fetched rows.

    Raises:
        Whatever ``sqlite3`` raises (syntax errors, writes rejected by the
        read-only connection, ...). The connection is closed in every case.
    """
    conn = connect_readonly(db_path)
    try:
        cur = conn.execute(sql)
        rows = cur.fetchall()
        cols = [d[0] for d in cur.description] if cur.description else []
    finally:
        # Previously a failing execute() left the connection open.
        conn.close()
    return cols, rows


In [None]:


from typing import TypedDict, Optional, Any
from langgraph.graph import StateGraph, START, END

class SQLState(TypedDict):
    """State dict passed between (and mutated by) the LangGraph nodes of the SQL agent."""
    question: str  # the user's natural-language question
    schema: str  # CREATE TABLE text, filled in by node_load_schema
    rag_context: str  # schema snippets retrieved for the question, filled in by node_retrieve_rag
    sql: str  # latest SQL produced by node_generate_sql
    result: Any  # {"columns": [...], "rows": [...]} after a successful run, otherwise None
    error: Optional[str]  # last failure (blocked keyword or SQLite error message), None when OK
    tries: int  # repair rounds used so far; compared against MAX_TRIES when routing

# Retry budget. route_after_execute sends a failed run back for repair while
# tries < MAX_TRIES, so one question gets at most MAX_TRIES + 1 SQL generations.
MAX_TRIES: int = 2
# Module-level cache for the Chroma collection: None until node_build_rag_index
# opens it, after which it is reused for the rest of the Python process.
_chroma_collection: Optional[Any] = None

def node_build_rag_index(state: SQLState) -> SQLState:
    """
    Open (building if needed) the Chroma schema index on first use.

    The collection is cached in the module-level ``_chroma_collection``, so
    later runs skip this work. Schema changes made after that are therefore
    not re-indexed until the cache is reset.
    """
    global _chroma_collection
    if _chroma_collection is not None:
        return state
    _chroma_collection = build_or_load_chroma(state["schema"])
    return state

def node_retrieve_rag(state: SQLState) -> SQLState:
    """Store the 4 schema snippets most relevant to the question in ``rag_context``."""
    # Only reads the module-level cache, so no `global` declaration is needed.
    context = retrieve_schema_context(_chroma_collection, state["question"], k=4)
    state["rag_context"] = context
    return state

def node_load_schema(state: SQLState) -> SQLState:
    """Fill ``state["schema"]`` with the CREATE TABLE text of ``DB_PATH``."""
    schema_text = get_database_schema(DB_PATH)
    state["schema"] = schema_text
    return state

def build_repair_question(state: SQLState) -> str:
    """
    Build the user message for the next SQL-generation round.

    On a first attempt (``state["error"]`` empty) this is just the question.
    After a failure it also includes the SQL that failed and the error, so the
    model can see what it produced last time instead of guessing.
    """
    question = state["question"]
    error = state.get("error")
    if not error:
        return question
    failed_sql = state.get("sql") or ""
    return (
        f"{question}\n\n"
        f"The previous SQL was:\n{failed_sql}\n\n"
        f"It failed with this error:\n{error}\nFix the SQL."
    )


def node_generate_sql(state: SQLState) -> SQLState:
    """Generate SQL for the question, adding repair context after a failure."""
    prompt = build_repair_question(state)
    state["sql"] = generate_sql(prompt, state["schema"], rag_context=state.get("rag_context", ""))
    # Reset the failure flag; the security check / execution set it again if needed.
    state["error"] = None
    return state

# Case-insensitive deny-list of statement keywords that could modify the
# database or the connection. The \b boundaries keep identifiers such as
# "CreatedAt" from matching.
BLOCKED = re.compile(
    r'\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|PRAGMA|ATTACH|DETACH|VACUUM)\b',
    flags=re.IGNORECASE,
)

def node_security_check(state: SQLState) -> SQLState:
    """
    Flag generated SQL that contains a deny-listed keyword.

    On a hit, records the reason in ``state["error"]`` and clears
    ``state["result"]``. Otherwise the state is returned untouched.
    """
    hit = BLOCKED.search(state["sql"])
    if hit is None:
        return state
    state["error"] = "Blocked: query contains a disallowed keyword."
    state["result"] = None
    return state

def node_execute_sql(state: SQLState) -> SQLState:
    """
    Execute ``state["sql"]`` and store ``{"columns", "rows"}`` in ``state["result"]``.

    Sets ``state["error"]`` (and leaves ``result`` as None) when the query was
    blocked upstream, is empty, or fails in SQLite.
    """
    # The security check only records its verdict in state["error"]; the graph
    # still routes here unconditionally. Honour that verdict instead of running
    # the blocked query and overwriting the error on success.
    if state.get("error"):
        state["result"] = None
        return state

    # sqlite3 accepts an empty statement and returns no result set, which would
    # otherwise be reported as a successful (empty) answer.
    if not (state.get("sql") or "").strip():
        state["result"] = None
        state["error"] = "No SQL query was generated."
        return state

    try:
        cols, rows = execute_sql(DB_PATH, state["sql"])
        state["result"] = {"columns": cols, "rows": rows}
        state["error"] = None
    except Exception as e:
        # Surface the SQLite message so the next generation round can repair it.
        state["result"] = None
        state["error"] = str(e)
    return state

def route_after_execute(state: SQLState) -> str:
    """
    Choose the edge after execution.

    Returns ``"retry"`` only when the run failed and the retry budget
    (``MAX_TRIES``) is not exhausted; otherwise ``"done"``.
    """
    should_retry = state["error"] is not None and state["tries"] < MAX_TRIES
    return "retry" if should_retry else "done"

def node_inc_tries(state: SQLState) -> SQLState:
    """Count one more repair round before SQL is regenerated."""
    state["tries"] = state["tries"] + 1
    return state

def build_sql_graph():
    """
    Assemble and compile the text-to-SQL LangGraph workflow.

    Pipeline:
        START -> load_schema -> build_rag -> retrieve_rag -> gen_sql
              -> sec_check -> exec_sql
    After exec_sql, ``route_after_execute`` either ends the run or loops
    back through inc_tries -> gen_sql for another attempt.
    """
    graph = StateGraph(SQLState)

    node_table = (
        ("load_schema", node_load_schema),
        ("build_rag", node_build_rag_index),
        ("retrieve_rag", node_retrieve_rag),
        ("gen_sql", node_generate_sql),
        ("sec_check", node_security_check),
        ("exec_sql", node_execute_sql),
        ("inc_tries", node_inc_tries),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Wire the linear part of the pipeline as consecutive pairs.
    pipeline = (START, "load_schema", "build_rag", "retrieve_rag", "gen_sql", "sec_check", "exec_sql")
    for src, dst in zip(pipeline, pipeline[1:]):
        graph.add_edge(src, dst)

    graph.add_conditional_edges(
        "exec_sql",
        route_after_execute,
        {"retry": "inc_tries", "done": END},
    )
    graph.add_edge("inc_tries", "gen_sql")
    return graph.compile()

# Compiled once at import time and reused by every run_sql_agent call.
app = build_sql_graph()

def run_sql_agent(question: str) -> SQLState:
    """Run the compiled graph for one question and return its final state."""
    return app.invoke(
        {
            "question": question,
            "schema": "",
            "rag_context": "",
            "sql": "",
            "result": None,
            "error": None,
            "tries": 0,
        }
    )


In [None]:

def main(question: Optional[str] = None) -> None:
    """
    Ask the SQL agent one question and print the generated SQL and result.

    Args:
        question: natural-language question; a sample question is used when
            it is omitted or empty.
    """
    # create_dummy_db()  # only needed for the synthetic users/orders demo tables

    user_question = question or "Show me the top 5 users who spent the most money, including their email."
    print(f"Question: {user_question}\n")

    out = run_sql_agent(user_question)

    print("-" * 30)
    print("Generated SQL:")
    print(out["sql"])
    print("-" * 30)

    if out["error"]:
        # Previously prefixed with a mis-encoded character ("? Error ...").
        print(f"Error after retries: {out['error']}")
        return

    result = out["result"] or {"columns": [], "rows": []}
    print(result["columns"])
    if not result["rows"]:
        # Make an empty answer explicit instead of printing only the header.
        print("(no rows)")
    for row in result["rows"]:
        print(row)

if __name__ == "__main__":
    main()


??? Question: Show me the top 5 users who spent the most money, including their email.

✅ Chroma collection already has 13 docs.
------------------------------
?? Generated SQL:
SELECT u.username, u.email, SUM(o.amount) as total_spent FROM orders o JOIN users u ON o.user_id = u.id GROUP BY u.id ORDER BY total_spent DESC LIMIT 5;
------------------------------
['username', 'email', 'total_spent']
