## Create SQLite DB inside Colab

In [121]:
# =========================
# Setup (Colab)
# =========================
!pip -q install langchain langchain-community langchain-experimental openai python-dotenv pandas tabulate sqlglot chromadb langchain-openai

import os, re, json, sqlite3, textwrap, warnings
warnings.filterwarnings("ignore")

import pandas as pd
from tabulate import tabulate
import re

# LangChain / OpenAI
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain.agents import initialize_agent, Tool
from langchain_experimental.sql import SQLDatabaseChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.agents import AgentType

In [122]:
os.environ["OPENAI_API_KEY"] = 'sk-proj-'

In [123]:
db_path = "sales.db"
# table_name_input = input("SQLite table name (leave blank to auto from filename): ").strip()
table_name_input = 'sales'
table_name = table_name_input if table_name_input else re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(os.path.basename(csv_path))[0])


In [124]:
# Load CSV
csv_path = '/content/sales_data_sample.csv'
df = pd.read_csv(csv_path, encoding='latin1')

In [125]:
# Create SQLite DB
conn = sqlite3.connect(db_path)
df.to_sql(table_name, conn, if_exists="replace", index=False)
conn.commit()

In [126]:

print(f"\n Loaded CSV into SQLite → {db_path}, table: {table_name}")
print("Sample rows:")
print(tabulate(df.head(5), headers="keys", tablefmt="github"))


 Loaded CSV into SQLite → sales.db, table: sales
Sample rows:
|    |   ORDERNUMBER |   QUANTITYORDERED |   PRICEEACH |   ORDERLINENUMBER |   SALES | ORDERDATE       | STATUS   |   QTR_ID |   MONTH_ID |   YEAR_ID | PRODUCTLINE   |   MSRP | PRODUCTCODE   | CUSTOMERNAME             | PHONE            | ADDRESSLINE1                  |   ADDRESSLINE2 | CITY          | STATE   |   POSTALCODE | COUNTRY   | TERRITORY   | CONTACTLASTNAME   | CONTACTFIRSTNAME   | DEALSIZE   |
|----|---------------|-------------------|-------------|-------------------|---------|-----------------|----------|----------|------------|-----------|---------------|--------|---------------|--------------------------|------------------|-------------------------------|----------------|---------------|---------|--------------|-----------|-------------|-------------------|--------------------|------------|
|  0 |         10107 |                30 |       95.7  |                 2 | 2871    | 2/24/2003 0:00  | Shipped  |    

In [127]:
#  Build SQL Utilities
# =========================
db_uri = f"sqlite:///{db_path}"
db = SQLDatabase.from_uri(db_uri)

In [128]:
def fetch_schema_text() -> str:
    # Compact schema, plus a few sample values as "data dictionary"
    with sqlite3.connect(db_path) as c:
        cur = c.cursor()
        cur.execute(f'PRAGMA table_info("{table_name}")')
        cols = cur.fetchall()
    lines = [f"CREATE TABLE {table_name} ("]
    for cid, name, coltype, notnull, default, pk in cols:
        lines.append(f"  {name} {coltype or ''} {'NOT NULL' if notnull else ''} {'PRIMARY KEY' if pk else ''}")
    lines.append(");")

    # Sample values for each column to help disambiguate
    dd_lines = ["\n-- Data Dictionary (sample values) --"]
    with sqlite3.connect(db_path) as c:
        for name, dtype in zip(df.columns, df.dtypes):
            try:
                vals = pd.read_sql(f'SELECT "{name}" AS v FROM "{table_name}" WHERE "{name}" IS NOT NULL LIMIT 5;', c)["v"].tolist()
            except Exception:
                vals = []
            dd_lines.append(f"{name} ({str(dtype)}): examples -> {vals}")
    return "\n".join(lines + dd_lines)

SCHEMA_TEXT = fetch_schema_text()

In [129]:
SCHEMA_TEXT

'CREATE TABLE sales (\n  ORDERNUMBER INTEGER  \n  QUANTITYORDERED INTEGER  \n  PRICEEACH REAL  \n  ORDERLINENUMBER INTEGER  \n  SALES REAL  \n  ORDERDATE TEXT  \n  STATUS TEXT  \n  QTR_ID INTEGER  \n  MONTH_ID INTEGER  \n  YEAR_ID INTEGER  \n  PRODUCTLINE TEXT  \n  MSRP INTEGER  \n  PRODUCTCODE TEXT  \n  CUSTOMERNAME TEXT  \n  PHONE TEXT  \n  ADDRESSLINE1 TEXT  \n  ADDRESSLINE2 TEXT  \n  CITY TEXT  \n  STATE TEXT  \n  POSTALCODE TEXT  \n  COUNTRY TEXT  \n  TERRITORY TEXT  \n  CONTACTLASTNAME TEXT  \n  CONTACTFIRSTNAME TEXT  \n  DEALSIZE TEXT  \n);\n\n-- Data Dictionary (sample values) --\nORDERNUMBER (int64): examples -> [10107, 10121, 10134, 10145, 10159]\nQUANTITYORDERED (int64): examples -> [30, 34, 41, 45, 49]\nPRICEEACH (float64): examples -> [95.7, 81.35, 94.74, 83.26, 100.0]\nORDERLINENUMBER (int64): examples -> [2, 5, 2, 6, 14]\nSALES (float64): examples -> [2871.0, 2765.9, 3884.34, 3746.7, 5205.27]\nORDERDATE (object): examples -> [\'2/24/2003 0:00\', \'5/7/2003 0:00\', \'7/1/

In [130]:
# Some seed example queries to help RAG (few-shot)
EXAMPLE_SQLS = [
    ("Total revenue by product code (top 5)",
     f'SELECT PRODUCTCODE, SUM(SALES) AS total_revenue FROM {table_name} GROUP BY PRODUCTCODE ORDER BY total_revenue DESC LIMIT 5;'),
    ("Revenue by year",
     f'SELECT YEAR_ID, SUM(SALES) AS total_revenue FROM {table_name} GROUP BY YEAR_ID ORDER BY YEAR_ID;'),
    ("Top 5 customers by spend",
     f'SELECT CUSTOMERNAME, SUM(SALES) AS total_spend FROM {table_name} GROUP BY CUSTOMERNAME ORDER BY total_spend DESC LIMIT 5;'),
    ("Monthly revenue trend for 2004",
     f'SELECT MONTH_ID, SUM(SALES) AS total_revenue FROM {table_name} WHERE YEAR_ID=2004 GROUP BY MONTH_ID ORDER BY MONTH_ID;'),
    ("Average order value by country",
     f'SELECT COUNTRY, AVG(SALES) AS avg_order FROM {table_name} GROUP BY COUNTRY ORDER BY avg_order DESC;'),
]

In [131]:
#  Build RAG index (schema + data dict + examples)
# =========================
emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

rag_texts = [SCHEMA_TEXT] + [f"Example: {title}\n{sql}" for title, sql in EXAMPLE_SQLS]
rag_metadatas = [{"type": "schema"}] + [{"type": "example", "title": t} for t,_ in EXAMPLE_SQLS]

vectordb = Chroma.from_texts(texts=rag_texts, embedding=emb, metadatas=rag_metadatas)
retriever = vectordb.as_retriever(search_kwargs={"k": 4})

In [132]:
def strip_sql_fences(s: str) -> str:
    print(f"strip_sql_fences :- ")
    print(f" S Before :- {s}")
    if not isinstance(s, str):
        return s
    s = s.strip()
    s = re.sub(r"^```(?:sql)?\s*", "", s, flags=re.IGNORECASE)
    s = re.sub(r"\s*```$", "", s)
    # Remove wrapping backticks (` ... `)
    if s.startswith("`") and s.endswith("`"):
        s = s[1:-1].strip()

    print(f" S After :- {s}")
    print(s.rstrip("; ") + ";")
    return s.rstrip("; ") + ";"
    # return s.strip(" ;") + ";"

def safe_sql(sql: str) -> tuple[bool, str]:
    banned = ["DROP ", "DELETE ", "UPDATE ", "INSERT ", "ALTER ", "TRUNCATE ", "ATTACH ", "DETACH ", "REPLACE "]
    up = re.sub(r"\s+", " ", (sql or "").upper() + " ")
    for b in banned:
        if b in up:
            return False, f"Blocked potentially destructive SQL: {b.strip()}"
    if not re.match(r"(?is)^\s*SELECT\b", sql or ""):
        return False, "Only SELECT statements are allowed."
    return True, ""

In [133]:
# =========================
#  Tools for the Agent
# =========================
def get_schema_tool(_: str = "") -> str:
    return SCHEMA_TEXT

def sample_rows_tool(table: str) -> str:
    print(f"************ sample_rows_tool **********")
    """
    Return a small sample of a table. Agent can call with a table name guess.
    """
    try:
        q = f'SELECT * FROM "{table}" LIMIT 5;'
        with sqlite3.connect(db_path) as c:
            sdf = pd.read_sql(q, c)
        return tabulate(sdf, headers="keys", tablefmt="github")
    except Exception as e:
        return f"[sample_rows_tool] error: {e}"

def safe_sql(sql: str) -> tuple[bool, str]:
    print(f"************ safe_sql **********")
    """
    Primitive safety: block mutating / dangerous statements.
    """
    banned = ["DROP ", "DELETE ", "UPDATE ", "INSERT ", "ALTER ", "TRUNCATE ", "ATTACH ", "DETACH "]
    sql_upper = re.sub(r"\s+", " ", sql.upper() + " ")
    for b in banned:
        if b in sql_upper:
            return False, f"Blocked potentially destructive SQL: {b.strip()}"
    return True, ""

def run_sql_tool(sql: str) -> str:
    sql = strip_sql_fences(sql or "")
    ok, why = safe_sql(sql)
    if not ok:
        return f"[SQL BLOCKED] {why}\nSQL: {sql}"
    try:
        with sqlite3.connect(db_path) as c:
            out = pd.read_sql(sql, c)
        pretty = tabulate(out.head(25), headers="keys", tablefmt="github")
        return f"SQL_OK\nROWS={len(out)}\n\n{pretty}"
    except Exception as e:
        return f"[SQL ERROR] {e}\nSQL: {sql}"

def explain_sql_tool(sql: str) -> str:
    print(f"************ explain_sql_tool **********")
    """
    Use SQLGlot to parse & reformat SQL; return pretty SQL and simple parse feedback.
    """
    from sqlglot import parse_one
    try:
        ast = parse_one(sql)
        formatted = ast.sql(pretty=True)
        return f"Formatted SQL:\n{formatted}"
    except Exception as e:
        return f"[explain_sql_tool] Could not parse: {e}"


In [134]:
GENERATOR_SYSTEM = """You are a Text-to-SQL generator for a SQLite database.
Rules:
- Use ONLY tables/columns that exist in the provided schema context.
- Prefer simple, correct SQL.
- For large results, include LIMIT.
- Return ONLY raw SQL (no explanations, no markdown, no ``` fences).
- Only SELECT statements are allowed.
"""

def generate_sql_tool(question: str) -> str:
    context = retrieve_context(question)
    prompt = f"""{GENERATOR_SYSTEM}
    # Schema & Examples (retrieved via RAG)
    {context}

    # Task
    User question: {question}

    # Return ONLY the SQL SELECT needed to answer the question. No text, no markdown.
    """
    sql = llm.predict(prompt)
    return strip_sql_fences(sql)

In [135]:

# =========================
#  LLM + Base SQL Chain
# =========================
# Low temperature for determinism
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # swap to 'gpt-4o'/'gpt-4.1' if available

In [136]:
sql_chain = SQLDatabaseChain.from_llm(
    llm=llm,
    db=db,
    verbose=True,
    top_k=5,               # limit rows when probing
    use_query_checker=True # adds an internal SQL safety/check pass
)

In [137]:

# =========================
#  Tool: RAG Retriever
# =========================
def retrieve_context(question: str) -> str:
    docs = retriever.get_relevant_documents(question)
    chunks = []
    for d in docs:
        meta = d.metadata.get("type","")
        title = d.metadata.get("title","")
        header = f"[{meta.upper()}] {title}" if title else f"[{meta.upper()}]"
        chunks.append(header + "\n" + d.page_content)
    return "\n\n---\n\n".join(chunks)

In [138]:
#   Tools for the Agent
# =========================
tools = [
    Tool(
        name="GetSchema",
        func=get_schema_tool,
        description=(
            "Use to understand available tables, columns, and types. "
            "Call BEFORE writing SQL and again if you are unsure."
        ),
    ),
    Tool(
        name="RetrieveContext",
        func=retrieve_context,
        description="Retrieve RAG context (schema snippets & example SQLs) for a question."
    ),
    Tool(
        name="SampleRows",
        func=sample_rows_tool,
        description=(
            "Use to peek at example rows for a given table name. "
            "Input should be exactly the table name (string)."
        ),
    ),
    Tool(
        name="GenerateSQL",
        func=generate_sql_tool,
        description=(
            "Generate and (optionally) execute a SQL query to answer a question. "
            "Always verify columns/tables exist using GetSchema first. Returns results or error."
        ),
    ),
    Tool(
        name="RunSQL",
        func=run_sql_tool,
        description=(
            "Execute a SQL SELECT statement and return results. "
            "Only SELECT queries are permitted."
        ),
    ),
    Tool(
        name="ExplainSQL",
        func=explain_sql_tool,
        description="Pretty-print and sanity-check a SQL string.",
    ),
]

In [139]:
# =========================
#  Agent Initialization
# =========================


agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,  # ReAct-style planning with tools
    verbose=True,
    max_iterations=6,
    handle_parsing_errors=True,
)

In [140]:
# =========================
#  Helper: Ask function (with auto-retry & visibility)
# =========================
SYSTEM_HINT = f"""
You are a careful Text-to-SQL assistant over a SQLite database at {db_path}.
Available table(s) include: {table_name}.
Rules:
- ALWAYS call GetSchema first.
- Validate table and column names before generating SQL.
- Prefer simple, correct SQL (LIMIT when results could be large).
- If a SQL fails, fix and retry at most twice using tool feedback.
- Return BOTH: (1) the final SQL, (2) a concise answer summary.
- NEVER run non-SELECT statements.
"""

In [141]:

def ask(nl_question: str) -> dict:
    prompt = f"""
{SYSTEM_HINT}

User question: {nl_question}

Return JSON with fields:
- "sql": the final SQL you would run
- "tool_to_use": "RunSQL" (preferred) or "GenerateSQL"
- "notes": brief reasoning
Then call the tool to execute. After execution, summarize the answer.
"""
    print("\n====================")
    print("🔎 QUESTION:", nl_question)
    print("====================\n")

    # First attempt
    result = agent.run(prompt)

    # Very light post-processing: if the agent didn't include SQL, try to elicit with ExplainSQL
    if "SELECT" not in result and "sql" not in result.lower():
        _ = agent.run(f"User asked: {nl_question}\nPlease produce the exact SELECT SQL only.")
        # (We leave result as-is; the verbose=True trace shows the tool calls.)

    print("\n--- Agent Final Output ---")
    print(result)
    return {"raw": result}

In [142]:
#  Demo Queries (edit these)
# =========================
print("\nYou can now ask questions! Examples:")
print(f' - "Show 5 random rows from {table_name}."')
print(' - "Top 5 products by total revenue." (assuming columns like Product, Quantity, Price)')
print(' - "Count of rows by Region and Year." (adjust to your schema)\n')



You can now ask questions! Examples:
 - "Show 5 random rows from sales."
 - "Top 5 products by total revenue." (assuming columns like Product, Quantity, Price)
 - "Count of rows by Region and Year." (adjust to your schema)



In [143]:

while True:
    q = input("Ask (or 'quit'): ").strip()
    if q.lower() in {"q", "quit", "exit"}:
        print("Bye! 👋")
        break
    try:
        _ = ask(q)
    except Exception as e:
        print("Error:", e)

Ask (or 'quit'): Show 5 random rows from sales

🔎 QUESTION: Show 5 random rows from sales



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mI need to start by checking the schema of the database to understand the structure of the "sales" table and its columns. This will help me ensure that I can write a correct SQL query to retrieve 5 random rows.

Action: GetSchema  
Action Input: ''  [0m
Observation: [36;1m[1;3mCREATE TABLE sales (
  ORDERNUMBER INTEGER  
  QUANTITYORDERED INTEGER  
  PRICEEACH REAL  
  ORDERLINENUMBER INTEGER  
  SALES REAL  
  ORDERDATE TEXT  
  STATUS TEXT  
  QTR_ID INTEGER  
  MONTH_ID INTEGER  
  YEAR_ID INTEGER  
  PRODUCTLINE TEXT  
  MSRP INTEGER  
  PRODUCTCODE TEXT  
  CUSTOMERNAME TEXT  
  PHONE TEXT  
  ADDRESSLINE1 TEXT  
  ADDRESSLINE2 TEXT  
  CITY TEXT  
  STATE TEXT  
  POSTALCODE TEXT  
  COUNTRY TEXT  
  TERRITORY TEXT  
  CONTACTLASTNAME TEXT  
  CONTACTFIRSTNAME TEXT  
  DEALSIZE TEXT  
);

-- Data Dictionary (sample values) --
OR