# Natural Language to SQL

In [1]:
import os
import random
from datetime import date, timedelta
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Date, Float
import sqlite3
import contextlib

In [2]:
%pip install -U -q langchain-community langchain-core langgraph langchain-groq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m49.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m438.3/438.3 kB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.9/154.9 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.4/129.4 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.2/44.2 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.0/50.0 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
from langchain_groq import ChatGroq
from google.colab import userdata

# Read the Groq API key from Colab's secret store so the credential never
# appears in the notebook source.
# NOTE: a key that was ever committed in plain text must be treated as leaked;
# revoke it and issue a new one.
os.environ['GROQ_API_KEY'] = userdata.get('GROQ_API_KEY')


# Module-level chat model; generate_answer uses this global instance.
llm = ChatGroq(model='llama-3.3-70b-versatile')

# Smoke test: confirms the key and model name are accepted by the API.
print(llm.invoke('who are you?').content)

BadRequestError: Error code: 400 - {'error': {'message': 'Organization has been restricted. Please reach out to support if you believe this was in error.', 'type': 'invalid_request_error', 'code': 'organization_restricted'}}

In [None]:
from langchain_community.utilities.sql_database import SQLDatabase

# Connect to the SQLite file that every node reads through the global `db`.
db = SQLDatabase.from_uri("sqlite:///example.db")

# Dump the dialect, then the table names and full schema, each preceded by
# a separator line.
print(db.dialect)

for describe in (db.get_usable_table_names, db.get_table_info):
    print("=" * 30, "\n")
    print(describe())

In [None]:
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, END


## Initial state
class AgentState(TypedDict):
    # Shared state dictionary passed between the workflow nodes.
    question: str            # natural-language question being processed
    rewritten_question: str  # rewrite stored by regenerate_query
    sql_query: str           # SQL produced by convert_nl_to_sql ("" on failure)
    query_result: str        # query output, or an error / fallback message
    query_rows: list         # declared for callers; no node in this file writes it
    attempts: int            # incremented by regenerate_query, capped by check_attempts
    relevance: str           # "relevant" or "not_relevant", set by check_relevance
    final_answer: str        # user-facing answer produced by a terminal node
    sql_error: bool          # True when SQL generation or execution failed

In [None]:
## Agent 1: Determines whether the user's natural-language input is relevant to the db schema

# Structured-output schema for the relevance check; check_relevance passes it
# to `llm.with_structured_output`.
class CheckRelevance(BaseModel):
    # The description asks for 'relevant' or 'not_relevant', but the field is a
    # plain `str`, so the value is not validated against those two options.
    relevance: str = Field(
        description="Indicates whether the question is related to the database schema. 'relevant' or 'not_relevant'."
    )

def check_relevance(state: AgentState):
    """Classify the question as answerable from the database schema or not.

    Reads ``state["question"]`` and the schema from the global ``db``, asks the
    LLM for a ``CheckRelevance`` verdict and stores the lower-cased, stripped
    verdict in ``state["relevance"]``.

    On any exception the question is treated as ``"not_relevant"`` and
    ``state["sql_error"]`` is set to ``True``.

    Returns the (mutated) state.
    """
    question = state["question"]
    print(f"Checking relevance of the question: {question}")
    system = """You are a highly skilled SQL expert. Your task is to evaluate whether a given natural language question is relevant to a database based on its schema.

    Use the schema below to determine whether the question can reasonably be answered using the available tables and columns.

    ---
    Schema:
    {schema}
    ---

    Instructions:
    1. Carefully read the user's question.
    2. Check whether the schema contains the necessary tables and columns to answer the question.
    3. If the question can be answered using the schema, respond with **"relevant"**.
    4. If the schema lacks the necessary information, respond with **"not_relevant"**.
    5. Your response must be either **"relevant"** or **"not_relevant"** only—do not explain or elaborate.

    Output Format:
    relevant
    or
    not_relevant
    """
    # The question is supplied as a template variable rather than interpolated
    # into the template text: any "{" or "}" typed by the user would otherwise
    # be parsed as a prompt placeholder and break prompt construction.
    check_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )

    llm = ChatGroq(model='llama-3.3-70b-versatile')
    structured_llm = llm.with_structured_output(CheckRelevance)
    relevance_checker = check_prompt | structured_llm
    try:
        relevance = relevance_checker.invoke(
            {"schema": db.get_table_info(), "question": question}
        )
        state["relevance"] = relevance.relevance.lower().strip()
        print(f"Relevance determined: {state['relevance']}")
    except Exception as e:
        print(f"Error during relevance check: {e}")
        state["relevance"] = "not_relevant"
        state["sql_error"] = True
    return state

In [None]:
## Agent 2: Converts natural language to SQL queries

# Structured-output schema for SQL generation; convert_nl_to_sql passes it to
# `llm.with_structured_output`.
class ConvertToSQL(BaseModel):
    # Generated SQL text; convert_nl_to_sql strips it before storing it in the state.
    sql_query: str = Field(
        description="The SQL query corresponding to the user's natural language question."
    )


def convert_nl_to_sql(state: AgentState):
    """Translate ``state["question"]`` into SQL for the global ``db``.

    The stripped query is stored in ``state["sql_query"]``. If generation
    fails for any reason, ``state["sql_query"]`` is set to ``""`` and
    ``state["sql_error"]`` to ``True``.

    Returns the (mutated) state.
    """
    question = state["question"]
    # `{dialect}` and `{table_info}` are left as template variables and filled
    # in at invoke time. Pre-formatting them into the string would make the
    # schema text part of the template itself, so any "{" or "}" in the schema
    # dump would be parsed as a placeholder.
    system = """
    You are a highly skilled SQL generation assistant. Your task is to convert a user's natural language question into a valid, syntactically correct, and semantically meaningful SQL query using the correct {dialect} dialect.

    Follow these strict rules:

    1. **Use only the tables and columns provided in the schema below**. Do not invent or reference tables/columns that are not explicitly listed.
    2. **Understand the relationships between tables** (e.g., foreign keys, primary keys) and use JOINs accordingly where appropriate.
    3. **Avoid using `SELECT *`**. Instead, return only the specific columns that are relevant to answering the user's question.
    4. Use appropriate **filters, sorting, and grouping** based on the user's intent (e.g., time ranges, categories, totals).
    5. If necessary, use **aggregations** (COUNT, AVG, MAX, etc.) when the question asks for summaries or statistics.
    6. Maintain clarity and simplicity. Prioritize correctness over cleverness.

    Before generating the SQL:
    - Carefully analyze the user's question.
    - Infer any implicit intent (e.g., filtering, ordering) only if it logically follows from the question.
    - Never assume facts that are not supported by the schema or the question.

    Schema:
    {table_info}

    Now, generate the SQL query that answers the following user question:
    """

    convert_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )

    llm = ChatGroq(model='llama-3.3-70b-versatile')
    structured_llm = llm.with_structured_output(ConvertToSQL)
    sql_generator = convert_prompt | structured_llm
    try:
        result = sql_generator.invoke(
            {
                "dialect": db.dialect,
                "table_info": db.get_table_info(),
                "question": question,
            }
        )
        state["sql_query"] = result.sql_query.strip()
    except Exception as e:
        print(f"Failed to generate SQL: {e}")
        state["sql_query"] = ""
        state["sql_error"] = True
    return state

In [None]:
from langchain_community.tools import QuerySQLDatabaseTool
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

## Agent 3: Executes generated SQL query -> retrieve the data from db -> returns updated state
def execute_query(state: AgentState):
    """Execute ``state["sql_query"]`` against ``db`` and record the outcome.

    Sets ``state["query_result"]`` to the tool output (or an error message) and
    ``state["sql_error"]`` to whether execution failed, so the ``execute_sql``
    router can choose between answering and regenerating.

    NOTE: a second ``execute_query(state, conn)`` is defined later in this
    file for the evaluation harness; running that cell rebinds this name.

    Returns the (mutated) state.
    """
    sql_query = state.get("sql_query", "")
    if not sql_query:
        # SQL generation failed upstream; there is nothing to run.
        state["sql_error"] = True
        state["query_result"] = "No SQL query was generated."
        return state

    execute_query_tool = QuerySQLDatabaseTool(db=db)
    try:
        result = execute_query_tool.invoke(sql_query)
    except Exception as e:
        print(f"Error during SQL execution: {e}")
        state["sql_error"] = True
        state["query_result"] = str(e)
        return state

    # The tool reports database errors by returning a message that starts with
    # "Error:" instead of raising, so the except branch above does not see
    # them. Detect that case explicitly so a retry can be triggered.
    failed = isinstance(result, str) and result.lstrip().startswith("Error:")
    if failed:
        print(f"Error during SQL execution: {result}")
    state["query_result"] = result
    state["sql_error"] = failed
    return state

In [None]:
## Agent 4: Based on SQL result, generates NL answer
def generate_answer(state: AgentState):
    """Turn the SQL result into a natural-language answer.

    Feeds the question, the SQL query and its result to the module-level
    ``llm`` and stores the reply in ``state["final_answer"]``.

    Returns the (mutated) state.
    """
    answer_prompt = PromptTemplate.from_template(
    """You are an intelligent data assistant. Your task is to help answer the user's natural language question by using the provided SQL query and its result.

    You will be given:
    1. The original user question.
    2. The SQL query that was generated to answer the question.
    3. The result returned by executing that SQL query.

    Use this information to provide a helpful, clear, and concise answer to the user's question. If the result is empty or insufficient to answer the question confidently, respond accordingly.

    ---
    Question: {question}
    SQL Query: {sql_query}
    SQL Result: {query_result}
    ---

    Final Answer:""")

    answer_chain = answer_prompt | llm | StrOutputParser()
    # Only these three state entries are exposed to the prompt.
    prompt_inputs = {
        field: state[field]
        for field in ("question", "sql_query", "query_result")
    }
    state["final_answer"] = answer_chain.invoke(prompt_inputs)
    return state

In [None]:
## Agent 5: Generates a funny response if the user's query is not relevant to the db
def generate_funny_response(state: AgentState):
    """Produce a light-hearted reply for a question unrelated to the database.

    Stores the LLM's reply in ``state["final_answer"]`` and returns the
    (mutated) state.
    """
    print("Generating a funny response for an unrelated question.")
    system = """
    You are a witty, charming, and funny assistant whose job is to entertain users when they ask questions unrelated to the database or when no relevant answer can be provided.

    Your responses should:
    - Be playful and light-hearted.
    - Stay appropriate and friendly.
    - Acknowledge that the question isn't answerable via the database.
    - Gently steer the user back on track with a smile (figuratively).

    You are not required to be helpful — just be delightfully unhelpful in a clever way.
    """

    # The question is passed as a template variable instead of being
    # interpolated into the template text, so braces typed by the user are not
    # mistaken for prompt placeholders.
    human_message = """
    The user asked a question that is unrelated to the database:
    '{question}'
    Craft a humorous and creative response."""

    funny_prompt = ChatPromptTemplate.from_messages([
        ("system", system),
        ("human", human_message),
    ])

    llm = ChatGroq(model='llama-3.3-70b-versatile')
    funny_response = funny_prompt | llm | StrOutputParser()
    message = funny_response.invoke({"question": state["question"]})
    state["final_answer"] = message
    print("Generated funny response.")
    return state

In [None]:
## Agent 6: Rewrites the original question if there isn't enough info
# Structured-output schema for the question rewrite; regenerate_query passes
# it to `llm.with_structured_output`.
class RewrittenQuestion(BaseModel):
    question: str = Field(description="The rewritten question.")

def regenerate_query(state: AgentState):
    """Rewrite the question after a failed SQL attempt and count the attempt.

    The rewrite is stored under ``state["rewritten_question"]`` and also
    written back to ``state["question"]``: no other node in this file reads
    ``rewritten_question``, while ``convert_nl_to_sql`` reads ``question``, so
    without the write-back every retry would regenerate SQL from the same
    text.

    If the LLM call fails or returns an empty rewrite, the current question is
    kept. ``state["attempts"]`` is incremented in every case so that
    ``check_attempts`` can still end the loop.

    Returns the (mutated) state.
    """
    question = state["question"]
    print("Regenerating the SQL query by rewriting the question.")
    system = """
    You are an expert in SQL and natural language understanding.

    Your task is to **rewrite a user's natural language question** so that:
    - It is clear, complete, and unambiguous.
    - It is optimized to be converted into a precise and valid SQL query.
    - All necessary details (e.g. filters, relationships between tables, required joins, and any implied logic) are included.
    - The reformulated version preserves the intent and meaning of the original question but improves its structure for programmatic interpretation.

    Avoid making assumptions not supported by the original question or schema.
    """

    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", system),
        ("human", "Original Question: {question}\n\nRewrite this question to make it clearer and more suitable for SQL generation, including all relevant details.")
    ])

    llm = ChatGroq(model='llama-3.3-70b-versatile', temperature=0)
    structured_llm = llm.with_structured_output(RewrittenQuestion)
    rewriter = rewrite_prompt | structured_llm

    try:
        rewritten_question = rewriter.invoke({"question": question}).question.strip()
    except Exception as e:
        # Same best-effort policy as the other LLM-backed nodes: report the
        # failure and continue with what is available.
        print(f"Failed to rewrite the question: {e}")
        rewritten_question = ""

    if not rewritten_question:
        rewritten_question = question

    state["rewritten_question"] = rewritten_question
    state["question"] = rewritten_question
    state["attempts"] = state.get("attempts", 0) + 1
    print(f"Rewritten question: {state['rewritten_question']}")
    return state

In [None]:
## Conditional nodes
def end_max_iter(state: AgentState):
    """Terminal node reached when the retry budget is exhausted.

    Stores the fallback message in ``state["query_result"]`` and also in
    ``state["final_answer"]``, so this terminal path fills the same
    user-facing field as ``generate_answer`` and ``generate_funny_response``
    instead of leaving it unset or stale.

    Returns the (mutated) state.
    """
    print("Maximum attempts reached. Ending the workflow.")
    fallback = "Please try again."
    state["query_result"] = fallback
    state["final_answer"] = fallback
    return state

def router(state: AgentState):
    """Choose the node that follows the relevance check.

    Returns ``"convert_to_sql"`` when ``state["relevance"]`` equals
    ``"relevant"`` (case-insensitively), else ``"generate_funny_response"``.
    """
    print("Routing based on relevance...")
    is_relevant = state["relevance"].lower() == "relevant"
    return "convert_to_sql" if is_relevant else "generate_funny_response"

def check_attempts(state: AgentState):
    """Decide whether another SQL-generation attempt is allowed.

    Returns ``"convert_to_sql"`` while fewer than three attempts have been
    made, otherwise ``"end_max_iter"``.
    """
    print(f"Attempt #{state['attempts']}")
    if state["attempts"] >= 3:
        return "end_max_iter"
    return "convert_to_sql"

def execute_sql(state: AgentState):
    """Route on the outcome of SQL execution.

    Returns ``"regenerate_query"`` when ``state["sql_error"]`` is truthy and
    ``"generate_answer"`` otherwise (a missing flag counts as no error).
    """
    print("Routing based on SQL execution result...")
    if state.get("sql_error", False):
        return "regenerate_query"
    return "generate_answer"

In [None]:
## Constructing the Graph
from langgraph.graph import StateGraph, END

workflow = StateGraph(AgentState)

# Register the nodes, in the same order as before.
for node_name, node_fn in (
    ("check_relevance", check_relevance),
    ("convert_to_sql", convert_nl_to_sql),
    ("execute_sql", execute_query),
    ("generate_answer", generate_answer),
    ("generate_funny_response", generate_funny_response),
    ("regenerate_query", regenerate_query),
    ("end_max_iter", end_max_iter),
):
    workflow.add_node(node_name, node_fn)

# Each routing function returns the name of the next node, so every path map
# below maps a name to itself.

# Relevant questions go to SQL generation, the rest get a humorous reply.
workflow.add_conditional_edges(
    "check_relevance",
    router,
    {target: target for target in ("convert_to_sql", "generate_funny_response")},
)

# Generated SQL is always executed.
workflow.add_edge("convert_to_sql", "execute_sql")

# A successful execution is answered; a failed one triggers a rewrite.
workflow.add_conditional_edges(
    "execute_sql",
    execute_sql,
    {target: target for target in ("generate_answer", "regenerate_query")},
)

# After a rewrite, retry until the attempt budget runs out.
workflow.add_conditional_edges(
    "regenerate_query",
    check_attempts,
    {target: target for target in ("convert_to_sql", "end_max_iter")},
)

# Terminal nodes.
for terminal_node in ("generate_answer", "generate_funny_response", "end_max_iter"):
    workflow.add_edge(terminal_node, END)

# Entry point.
workflow.set_entry_point("check_relevance")

# Compile into the runnable app.
app = workflow.compile()

In [None]:
from IPython.display import Image, display

# Rendering the diagram is best-effort and must not stop the notebook, but the
# failure reason is reported instead of being silently discarded. Catching
# `Exception` (not a bare `except`) lets KeyboardInterrupt through.
try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png(max_retries=5, retry_delay=2.0)))
except Exception as exc:
    print(f"Could not render the workflow graph: {exc}")

In [None]:
# display(Image(app.get_graph(xray=True).draw_mermaid_png(max_retries=5, retry_delay=2.0)))

In [None]:
"""
# Initialize state
state = {
    "question": "",
    "chat_history": [], # adding memory
    "sql_query": "",
    "query_result": "",
    "query_rows": [],
    "attempts": 0,
    "relevance": "",
    "final_answer": "",
    "sql_error": False,
}


while True:
    user_input = input("User: ").strip()
    if user_input.lower() in ["exit", "quit"]:
        print("Conversation ended.")
        break

    state["question"] = user_input
    state["attempts"] = 0  # reset attempts each new question

    result = app.invoke(state)

    answer = result.get("final_answer", "No response available.")

    print(f"Assistant: {answer}\n")

    state["chat_history"].append({"user": user_input, "assistant": answer})
"""

In [None]:
import re
import sqlparse
import sqlite3
import json
from difflib import SequenceMatcher

def load_spider_dev(json_path):
    """Load the dev-set JSON file at *json_path* and return the parsed object.

    The file is decoded as UTF-8 explicitly so that non-ASCII characters are
    read the same way regardless of the platform's default encoding.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)

def load_dev_gold(sql_path):
    """Return the gold SQL statements listed in *sql_path*.

    Each non-blank line is stripped and only the text before the first tab is
    kept. Blank lines are skipped, so the returned list can be shorter than
    the number of lines in the file.

    The file is decoded as UTF-8 explicitly to avoid depending on the
    platform's default encoding.
    """
    with open(sql_path, "r", encoding="utf-8") as f:
        return [line.strip().split('\t')[0] for line in f if line.strip()]

def get_tables_in_db(conn):
    """Return the lower-cased names of all tables in the SQLite database."""
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = {record[0].lower() for record in cur.fetchall()}
    cur.close()
    return table_names

def fix_table_names_in_sql(sql, existing_tables):
    """Replace pluralised table names after FROM/JOIN with existing tables.

    For each identifier that directly follows ``FROM`` or ``JOIN``
    (case-insensitive): if its lower-cased form is not in *existing_tables*
    but dropping a trailing "s" yields a name that is, the trailing "s" is
    removed. Everything else is left untouched.

    Args:
        sql: SQL text to patch.
        existing_tables: lower-cased table names present in the database.

    Returns:
        The patched SQL string.
    """
    def replacement(match):
        prefix, tbl = match.group(1), match.group(2)
        lowered = tbl.lower()
        if (
            lowered not in existing_tables
            and lowered.endswith('s')
            and lowered[:-1] in existing_tables
        ):
            tbl = tbl[:-1]
        return prefix + tbl

    # The keyword and the whitespace after it are captured and re-emitted
    # rather than matched with a lookbehind: a lookbehind must be fixed-width
    # and therefore only handled exactly one whitespace character, missing
    # e.g. "FROM\n    employees" or "JOIN  shops".
    pattern = re.compile(r"\b((?:FROM|JOIN)\s+)(\w+)", re.IGNORECASE)
    return pattern.sub(replacement, sql)

def normalize_sql(sql):
    """Return a canonical single-line, lower-cased form of *sql*.

    Surrounding whitespace and trailing semicolons are removed, the statement
    is passed through ``sqlparse.format`` (comments stripped), and all runs of
    whitespace are collapsed to single spaces.
    """
    trimmed = sql.strip().rstrip(';')
    formatted = sqlparse.format(
        trimmed,
        keyword_case='lower',
        strip_comments=True,
        reindent=True,
    )
    words = formatted.strip().split()
    return " ".join(words).lower()

def normalize_result(result):
    """Convert query rows into an order-insensitive, comparable form.

    Every cell is stringified and stripped, every row becomes a tuple, and the
    rows are collected into a set (so duplicates and row order are ignored).
    ``None`` is passed through unchanged.
    """
    if result is None:
        return None
    normalized_rows = set()
    for row in result:
        normalized_rows.add(tuple(str(cell).strip() for cell in row))
    return normalized_rows

def validate_sql(sql, conn):
    """Report whether the database accepts *sql*, without fetching its rows.

    Runs ``EXPLAIN QUERY PLAN`` on the statement; returns ``True`` if that
    succeeds and ``False`` if it raises any exception.
    """
    try:
        conn.execute("EXPLAIN QUERY PLAN " + sql)
    except Exception:
        return False
    return True

def execute_sql_direct(sql, conn):
    """Run *sql* on *conn* and return all rows, or ``None`` on any error."""
    try:
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        cur.close()
    except Exception:
        return None
    return rows

def execute_query(state, conn):
    """Run ``state["sql_query"]`` on *conn* and record the outcome in *state*.

    On success ``query_result`` holds the fetched rows and ``sql_error`` is
    ``False``; on any exception ``query_result`` is ``None`` and ``sql_error``
    is ``True``. Returns the (mutated) state.

    NOTE: this shares its name with the one-argument graph node defined
    earlier in this file; defining it rebinds the module-level name.
    """
    try:
        cur = conn.cursor()
        cur.execute(state["sql_query"])
        rows = cur.fetchall()
        cur.close()
    except Exception:
        state.update(query_result=None, sql_error=True)
    else:
        state.update(query_result=rows, sql_error=False)
    return state

def extract_components(sql):
    """Collect the tokens found in the SELECT, FROM and WHERE parts of *sql*.

    Only the first parsed statement is inspected. Its leaf tokens are walked
    in order; a token whose lower-cased text is ``select``, ``from`` or
    ``where`` switches the active clause, and later tokens are recorded under
    that clause when their token type is one of the accepted types below.

    Returns a dict with the keys ``"select"``, ``"from"`` and ``"where"``,
    each mapping to a set of lower-cased token strings (all empty when nothing
    could be parsed).
    """
    statements = sqlparse.parse(sql)
    if not statements:
        return {"select": set(), "from": set(), "where": set()}

    name_type = sqlparse.tokens.Name
    # Token types recorded for each clause.
    accepted_types = {
        "select": (name_type, sqlparse.tokens.Wildcard),
        "from": (name_type,),
        "where": (
            name_type,
            sqlparse.tokens.Literal.Number.Integer,
            sqlparse.tokens.Operator.Comparison,
        ),
    }
    collected = {"select": set(), "from": set(), "where": set()}

    active_clause = None
    # flatten() yields leaf tokens in document order, descending into grouped
    # tokens, which covers both the grouped and ungrouped cases in one loop.
    for leaf in statements[0].flatten():
        text = leaf.value.lower().strip()
        if text in accepted_types:
            active_clause = text
            continue
        if active_clause is not None and leaf.ttype in accepted_types[active_clause]:
            collected[active_clause].add(text)

    return collected

def jaccard_similarity(set1, set2):
    """Return |set1 ∩ set2| / |set1 ∪ set2|; two empty sets count as 1.0."""
    if not (set1 or set2):
        return 1.0
    shared = len(set1.intersection(set2))
    combined = len(set1.union(set2))
    if combined > 0:
        return shared / combined
    return 0

def string_similarity(a, b):
    """Return the ``difflib.SequenceMatcher`` ratio of *a* and *b*."""
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()

def precision_recall_f1(pred, gold):
    """Return ``(precision, recall, f1)`` of the set *pred* against *gold*.

    Two empty sets score ``(1.0, 1.0, 1.0)``; if exactly one of them is empty
    the score is ``(0.0, 0.0, 0.0)``.
    """
    if not pred and not gold:
        return (1.0, 1.0, 1.0)
    if not pred or not gold:
        return (0.0, 0.0, 0.0)

    overlap = len(pred.intersection(gold))
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    total = precision + recall
    # No overlap at all: define F1 as 0 instead of dividing by zero.
    f1 = (2 * precision * recall / total) if total else 0
    return (precision, recall, f1)

def evaluate_pipeline(spider_dev, dev_gold, conn):
    """Run NL->SQL generation over the examples and print evaluation metrics.

    For every example the question is turned into SQL with
    ``convert_nl_to_sql``, table names are patched with
    ``fix_table_names_in_sql`` and the result is checked with
    ``validate_sql``. Examples with empty or invalid generated SQL are skipped
    and do not count towards the totals. For the remaining examples the
    generated and gold queries are executed on *conn* and compared by
    normalised text, by result set and by per-clause token overlap.

    Args:
        spider_dev: sequence of example dicts, each with a ``"question"`` key.
        dev_gold: gold SQL strings, index-aligned with *spider_dev*.
        conn: open SQLite connection to the target database.

    Prints per-example details and a summary; returns ``None``.
    """
    n_valid = 0
    n_exact = 0
    n_exec = 0

    known_tables = get_tables_in_db(conn)

    for idx, example in enumerate(spider_dev):
        question = example["question"]
        gold_sql = dev_gold[idx].strip()

        state = {
            "question": question,
            "attempts": 0,
            "sql_query": "",
            "query_result": None,
            "sql_error": False,
            "relevance": "relevant",
        }

        # Generate SQL from the question.
        state = convert_nl_to_sql(state)
        gen_sql = state["sql_query"].strip()
        if not gen_sql:
            continue

        fixed_sql = fix_table_names_in_sql(gen_sql, known_tables)
        state["sql_query"] = fixed_sql
        if not validate_sql(fixed_sql, conn):
            continue

        # Execute both queries on the same connection.
        state = execute_query(state, conn)
        gen_rows = normalize_result(state["query_result"])
        gold_rows = normalize_result(execute_sql_direct(gold_sql, conn))

        norm_gen_sql = normalize_sql(fixed_sql)
        norm_gold_sql = normalize_sql(gold_sql)

        exact_match = norm_gen_sql == norm_gold_sql
        # A failed execution on either side can never count as a match.
        if gold_rows is None or gen_rows is None:
            exec_match = False
        else:
            exec_match = gold_rows == gen_rows

        gen_parts = extract_components(norm_gen_sql)
        gold_parts = extract_components(norm_gold_sql)

        select_p, select_r, select_f1 = precision_recall_f1(gen_parts["select"], gold_parts["select"])
        from_p, from_r, from_f1 = precision_recall_f1(gen_parts["from"], gold_parts["from"])
        where_p, where_r, where_f1 = precision_recall_f1(gen_parts["where"], gold_parts["where"])

        select_jaccard = jaccard_similarity(gen_parts["select"], gold_parts["select"])
        from_jaccard = jaccard_similarity(gen_parts["from"], gold_parts["from"])
        where_jaccard = jaccard_similarity(gen_parts["where"], gold_parts["where"])

        overall_str_sim = string_similarity(norm_gen_sql, norm_gold_sql)

        n_valid += 1
        n_exact += 1 if exact_match else 0
        n_exec += 1 if exec_match else 0

        print(f"\n--- Example {n_valid} ---")
        print(f"Question:          {question}")
        print(f"Gold SQL:   {norm_gold_sql}")
        print(f"Generated SQL (norm):    {norm_gen_sql}")
        print(f"Exact Match:       {exact_match}")
        print(f"Exec Match:        {exec_match}")
        print(f"Execution Result:  {gold_rows}")
        print("Component F1 Scores:")
        print(f"  SELECT - P: {select_p:.2f}, R: {select_r:.2f}, F1: {select_f1:.2f}")
        print(f"  FROM   - P: {from_p:.2f}, R: {from_r:.2f}, F1: {from_f1:.2f}")
        print(f"  WHERE  - P: {where_p:.2f}, R: {where_r:.2f}, F1: {where_f1:.2f}")
        print("Component Jaccard Similarity:")
        print(f"  SELECT - {select_jaccard:.2f}")
        print(f"  FROM   - {from_jaccard:.2f}")
        print(f"  WHERE  - {where_jaccard:.2f}")
        print(f"Overall SQL String Similarity (SeqMatch): {overall_str_sim:.2f}")

    print("\n=== Evaluation Summary ===")
    print(f"Valid examples:           {n_valid}")
    print(f"Exact match count:        {n_exact}")
    print(f"Execution match count:    {n_exec}")
    if n_valid > 0:
        print(f"Exact match accuracy:     {n_exact / n_valid:.2f}")
        print(f"Execution match accuracy: {n_exec / n_valid:.2f}")
    else:
        print("No valid examples to evaluate.")

In [None]:
# Load dev set
spider_dev = load_spider_dev("/content/dev.json")

# Load gold SQL. This is read inline rather than through load_dev_gold because
# that helper drops blank lines, which would shift the line-by-line alignment
# with dev.json that the zip() below relies on.
dev_gold = []
with open('/content/dev_gold.sql', 'r', encoding='utf-8') as f:
    for line in f:
        sql = line.strip().split('\t')[0]
        dev_gold.append(sql)

# zip() stops at the shorter input, so a length mismatch would otherwise go
# unnoticed.
if len(spider_dev) != len(dev_gold):
    print(
        f"Warning: {len(spider_dev)} dev examples but {len(dev_gold)} gold queries; "
        "extra entries are ignored."
    )

# Filter by db_id (from dev.json) and level
target_db = "employee_hire_evaluation"
target_level = "easy"  # change to e.g., "easy", "medium", "hard"

filtered_dev = []
filtered_gold = []

# NOTE(review): the level is read from a "type" key on each example; confirm
# that the dev.json in use actually carries that key.
for ex, sql in zip(spider_dev, dev_gold):
    if ex.get("db_id", "").lower() == target_db and ex.get("type", "").lower() == target_level:
        filtered_dev.append(ex)
        filtered_gold.append(sql)

print(f"Filtered {len(filtered_dev)} examples from database '{target_db}' with level '{target_level}'.")

# set number of examples to run (will skip ones with validation errors)
n = 50
filtered_dev = filtered_dev[:n]
filtered_gold = filtered_gold[:n]

# NOTE(review): convert_nl_to_sql builds its prompt from the global `db`, which
# this file creates from "sqlite:///example.db". Confirm that it has the same
# schema as the database opened below, otherwise the generated SQL targets the
# wrong tables.

# Connect to the specific database
conn = sqlite3.connect('/content/employee_hire_evaluation.sqlite')

# Run evaluation; the connection is closed even if the run raises.
try:
    evaluate_pipeline(filtered_dev, filtered_gold, conn)
finally:
    conn.close()