<a href="https://colab.research.google.com/github/soumyaGhoshh/nl2sql/blob/main/nl2sql_2026_soumya_ghosh.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **NL2SQL Implementation using LangChain**

**1. Building a basic NL2SQL model**

**2. Dynamic few-shot example selection**

**3. Dynamic relevant table selection**

**4. Adding memory to the chatbot**

In [112]:
!pip install langchain langchain-google-genai langchain-community langchain-chroma pymysql chromadb -q

**1. Setup & Database Connection**

We start by setting up the environment and connecting to the database. We use the *classicmodels* sample MySQL schema.

In [113]:
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities import SQLDatabase

In [114]:
# Set up API keys (fill in your own key)
os.environ["GOOGLE_API_KEY"] = ""

In [115]:
# Set up LangSmith for tracing ("true" enables it; fill in your project name and API key)
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = ""
os.environ["LANGCHAIN_API_KEY"] = ""

print("Langsmith is set up.")

Langsmith is set up.


In [116]:
# connect to the database
db_user = ""
db_password = ""
db_host = ""
db_name = ""

In [117]:
# We use pymysql for MySQL connection
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")
print("database connected.")

database connected.


In [118]:
# verifying connection
print(db.dialect)
print(db.get_usable_table_names())
print(db.get_table_info())

mysql
['customers', 'employees', 'offices', 'orderdetails', 'orders', 'payments', 'productlines', 'products']

CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmit

In [119]:
# Initialize the LLM, the "brain" of our pipeline
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)

**2. Building a Basic NL2SQL Model**

In this step, we build the chain manually. Instead of using utilities such as *create_sql_query_chain*, we define the prompt and the parsing logic ourselves. This gives us full control over how the schema is presented to the LLM, and the prompt the model receives is explicit rather than hidden inside a helper.




In [120]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

In [121]:
# Helper function to get schema
def get_schema(_):
    return db.get_table_info()

In [122]:
# Define the prompt manually (create_sql_query_chain would supply its own built-in prompt)
# We explicitly tell the model how to behave and provide the schema as context
template = """You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run.
Unless the user specifies a specific number of examples, query for at most 5 results using the LIMIT clause.

Here is the relevant table info:
{schema}

Pay attention to use only the column names you can see in the tables above.
Be careful to not query for columns that do not exist.

Question: {question}
SQL Query:"""

# Build a chat prompt template from the string above
prompt = ChatPromptTemplate.from_template(template)

In [123]:
# Define the chain using LCEL (LangChain Expression Language)
# The chain flows: Input -> Get Schema -> Format Prompt -> LLM -> Parse Output
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm
    | StrOutputParser()
)
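The LCEL pipeline above can be mimicked with plain functions to make its data flow concrete. In this dependency-free sketch, `assign`, `pipe`, `format_prompt`, and `fake_llm` are hypothetical stand-ins, not LangChain APIs: `assign` imitates how `RunnablePassthrough.assign` keeps the input keys and adds a computed one, `pipe` imitates the `|` operator, and `fake_llm` returns a canned reply instead of calling Gemini.

```python
from typing import Any, Callable, Dict

Step = Callable[[Any], Any]


def assign(**computed: Step) -> Step:
    """Mimic RunnablePassthrough.assign: keep every input key, add computed ones."""
    def step(inputs: Dict[str, Any]) -> Dict[str, Any]:
        out = dict(inputs)
        for key, fn in computed.items():
            out[key] = fn(inputs)
        return out
    return step


def pipe(*steps: Step) -> Step:
    """Mimic the `|` operator: feed each step's output into the next step."""
    def run(value: Any) -> Any:
        for s in steps:
            value = s(value)
        return value
    return run


TEMPLATE = "Schema:\n{schema}\n\nQuestion: {question}\nSQL Query:"


def get_schema(_: Any) -> str:
    # Hypothetical one-table schema standing in for db.get_table_info()
    return "CREATE TABLE products (productCode VARCHAR(15), productName VARCHAR(70))"


def format_prompt(inputs: Dict[str, Any]) -> str:
    # Like the prompt template: fill placeholders from the input dict
    return TEMPLATE.format(**inputs)


def fake_llm(prompt: str) -> str:
    # Canned stand-in for the Gemini call (ignores the prompt)
    return "```sql\nSELECT COUNT(*) FROM products;\n```"


def parse_output(text: str) -> str:
    # Like StrOutputParser: hand back the model text as a plain string
    return text


sketch_chain = pipe(assign(schema=get_schema), format_prompt, fake_llm, parse_output)
print(sketch_chain({"question": "How many products are there ?"}))
```

The key point carries over to the real chain: after the `assign` step the `question` key is still present alongside the new `schema` key, which is why the prompt can reference both.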

In [124]:
# Test the basic chain with a question
generated_sql = sql_chain.invoke({"question": "How many products are there ?"})
print(generated_sql)

```sql
SELECT
  COUNT(*)
FROM products;
```


In [164]:
# Clean the generated query
import re

def clean_sql(text):
    # Remove markdown code fences (```sql, ```mysql, or bare ```) if present
    return re.sub(r"```(?:sql|mysql)?", "", text, flags=re.IGNORECASE).strip()


clean_query = clean_sql(generated_sql)
print(clean_query)
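`clean_sql` assumes the reply contains only the query, possibly fenced. If the model ever adds explanatory prose around the fence, a slightly more defensive variant can pull out just the fenced block's contents and fall back to the whole reply otherwise. `extract_sql` is a hypothetical helper, not part of the original pipeline.

```python
import re

# Match a ```sql / ```mysql / bare ``` fence and capture its body (non-greedy)
FENCE_RE = re.compile(r"```(?:sql|mysql)?\s*(.*?)```", re.IGNORECASE | re.DOTALL)


def extract_sql(text: str) -> str:
    """Return the SQL inside the first fenced block, or the stripped text if none."""
    match = FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    return text.strip()


reply = "Here is the query:\n```sql\nSELECT COUNT(*) FROM products;\n```\nIt counts all rows."
print(extract_sql(reply))  # SELECT COUNT(*) FROM products;
```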

In [126]:
# Execute the cleaned query with LangChain's SQL database tool
# (itemgetter is imported here for use in the chains below)
from operator import itemgetter
from langchain_community.tools import QuerySQLDatabaseTool

execute_query = QuerySQLDatabaseTool(db=db)

result = execute_query.invoke(clean_query)
print(result)

[(110,)]


In [127]:
# Prompt that rephrases the SQL result as a human-readable answer
answer_prompt = ChatPromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question in full comprehensive sentence.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)
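To see what the model receives at the answer step, the same template text can be filled with plain `str.format`, using the question, a single-line version of the query, and the tool output from the cells above. This only approximates `ChatPromptTemplate`, which wraps the filled text in a human message rather than returning a bare string.

```python
# Mirrors the answer_prompt template defined above
ANSWER_TEMPLATE = """Given the following user question, corresponding SQL query, and SQL result, answer the user question in full comprehensive sentence.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """

# Values from the earlier cells: the question, the cleaned query, and the tool's string output
filled = ANSWER_TEMPLATE.format(
    question="How many products are there ?",
    query="SELECT COUNT(*) FROM products;",
    result="[(110,)]",
)
print(filled)
```

Note that the SQL result reaches the model as the tool's string output (`"[(110,)]"`), not as Python objects.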

In [128]:
chain = (
    RunnablePassthrough.assign(query= sql_chain | RunnableLambda(clean_sql))
    .assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | llm
    | StrOutputParser()
)

result = chain.invoke({"question": "How many products are there ?"})
print(result)

There are 110 products.


**3. Adding Few-Shot Examples**

Here we provide example question/SQL pairs. Few-shot examples help the model handle complex queries and domain-specific business logic that are hard to infer from the schema alone.
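Conceptually, a few-shot chat prompt turns each example into a human turn (the question) followed by an AI turn (the SQL), so the model sees the pattern it should continue. The plain-Python sketch below renders a two-item excerpt of the examples defined next; `format_examples` is a hypothetical helper, and the real `FewShotChatMessagePromptTemplate` produces message objects rather than one string.

```python
from typing import Dict, List

# Two-item excerpt of the `examples` list that follows
sample_examples: List[Dict[str, str]] = [
    {"input": "List all customers from the USA.",
     "query": "SELECT * FROM customers WHERE country = 'USA';"},
    {"input": "Find the average payment amount.",
     "query": "SELECT AVG(amount) FROM payments;"},
]


def format_examples(examples: List[Dict[str, str]]) -> str:
    """Render each example as a Human turn followed by an AI turn."""
    lines = []
    for ex in examples:
        lines.append(f"Human: {ex['input']}")
        lines.append(f"AI: {ex['query']}")
    return "\n".join(lines)


print(format_examples(sample_examples))
```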

In [176]:
examples = [
    # -------- Customers --------
    {
        "input": "List all customers from the USA.",
        "query": "SELECT * FROM customers WHERE country = 'USA';"
    },
    {
        "input": "Find the number of customers in each country.",
        "query": "SELECT country, COUNT(*) FROM customers GROUP BY country;"
    },
    {
        "input": "Show customers whose credit limit is between 50000 and 100000.",
        "query": "SELECT customerName, creditLimit FROM customers WHERE creditLimit BETWEEN 50000 AND 100000;"
    },
    {
        "input": "List customers who do not have a sales representative.",
        "query": "SELECT customerName FROM customers WHERE salesRepEmployeeNumber IS NULL;"
    },

    # -------- Employees --------
    {
        "input": "List all employees working in the Sales department.",
        "query": "SELECT firstName, lastName FROM employees WHERE jobTitle LIKE '%Sales%';"
    },
    {
        "input": "Find employees working in the San Francisco office.",
        "query": "SELECT e.firstName, e.lastName FROM employees e JOIN offices o ON e.officeCode = o.officeCode WHERE o.city = 'San Francisco';"
    },
    {
        "input": "Count the number of employees in each office.",
        "query": "SELECT officeCode, COUNT(*) FROM employees GROUP BY officeCode;"
    },

    # -------- Products --------
    {
        "input": "List all products sorted by buy price in descending order.",
        "query": "SELECT productName, buyPrice FROM products ORDER BY buyPrice DESC;"
    },
    {
        "input": "Find the average buy price of products in each product line.",
        "query": "SELECT productLine, AVG(buyPrice) FROM products GROUP BY productLine;"
    },
    {
        "input": "Show products with MSRP greater than 100.",
        "query": "SELECT productName, MSRP FROM products WHERE MSRP > 100;"
    },

    # -------- Orders --------
    {
        "input": "List all orders placed in the year 2004.",
        "query": "SELECT * FROM orders WHERE YEAR(orderDate) = 2004;"
    },
    {
        "input": "Count the number of orders for each customer.",
        "query": "SELECT customerNumber, COUNT(*) FROM orders GROUP BY customerNumber;"
    },
    {
        "input": "Find orders that are still in process.",
        "query": "SELECT orderNumber, status FROM orders WHERE status = 'In Process';"
    },

    # -------- Order Details --------
    {
        "input": "Find total quantity ordered for each product.",
        "query": "SELECT productCode, SUM(quantityOrdered) FROM orderdetails GROUP BY productCode;"
    },
    {
        "input": "List products ordered more than 1000 units in total.",
        "query": "SELECT productCode, SUM(quantityOrdered) AS total FROM orderdetails GROUP BY productCode HAVING total > 1000;"
    },

    # -------- Payments --------
    {
        "input": "Calculate the total payment amount received from each customer.",
        "query": "SELECT customerNumber, SUM(amount) FROM payments GROUP BY customerNumber;"
    },
    {
        "input": "Find the average payment amount.",
        "query": "SELECT AVG(amount) FROM payments;"
    },
    {
        "input": "List payments made after January 1, 2004.",
        "query": "SELECT * FROM payments WHERE paymentDate > '2004-01-01';"
    },

    # -------- Joins (Important for NL2SQL) --------
    {
        "input": "List customer names along with their order numbers.",
        "query": "SELECT c.customerName, o.orderNumber FROM customers c JOIN orders o ON c.customerNumber = o.customerNumber;"
    },
    {
        "input": "Show order numbers along with the total amount for each order.",
        "query": "SELECT od.orderNumber, SUM(od.quantityOrdered * od.priceEach) AS totalAmount FROM orderdetails od GROUP BY od.orderNumber;"
    },
    {
        "input": "Find customers who have placed more than 5 orders.",
        "query": "SELECT customerNumber FROM orders GROUP BY customerNumber HAVING COUNT(*) > 5;"
    }
]


In [130]:
# Create the example Few-Shot Prompt Template
from langchain_core.prompts import FewShotChatMessagePromptTemplate
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}"),
        ("ai", "{query}"),
    ]
)

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
)
print(few_shot_prompt.format(input="How many products are there?"))

Human: List all customers in France with a credit limit over 20,000.
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
AI: SELECT MAX(amount) FROM payments;
Human: Show product details for products in the 'Motorcycles' product line.
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';
Human: Retrieve the names of employees who report to employee number 1002.
AI: SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;
Human: List all products with a stock quantity less than 7000.
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: what is price of `1968 Ford Mustang`
AI: SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;




**4. Dynamic Few-Shot Example Selection**

Sending every example with every request wastes tokens and can confuse the model. Instead, we use *SemanticSimilarityExampleSelector* to retrieve only the examples most similar to the user's question, using an embedding model (sentence-transformers all-MiniLM-L6-v2).
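The selector embeds every example's `input`, embeds the incoming question, and returns the `k` nearest examples from the vector store. The sketch below shows that ranking with cosine similarity, but swaps MiniLM for a toy bag-of-words embedding (an assumption made only so the code runs without downloading a model); `embed`, `cosine`, and `select_examples` are hypothetical helpers.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def embed(text: str) -> Counter:
    # Toy embedding: lowercase word counts (stand-in for all-MiniLM-L6-v2)
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def select_examples(question: str, examples: List[Dict[str, str]], k: int = 2) -> List[Dict[str, str]]:
    """Return the k examples whose 'input' is most similar to the question."""
    q = embed(question)
    ranked = sorted(examples, key=lambda ex: cosine(q, embed(ex["input"])), reverse=True)
    return ranked[:k]


pool = [
    {"input": "List all customers from the USA.", "query": "SELECT * FROM customers WHERE country = 'USA';"},
    {"input": "Show products with MSRP greater than 100.", "query": "SELECT productName, MSRP FROM products WHERE MSRP > 100;"},
    {"input": "Find the average payment amount.", "query": "SELECT AVG(amount) FROM payments;"},
]

for ex in select_examples("How many products are there?", pool, k=1):
    print(ex["input"])  # Show products with MSRP greater than 100.
```

A real sentence embedding also matches paraphrases with no shared words (e.g. "items" vs "products"), which is the reason for using MiniLM instead of word overlap.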

In [131]:
!pip install -U langchain-huggingface -q

In [132]:
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

In [133]:
# Create the Example Selector
# This stores examples in a vector DB and retrieves the top k most similar ones
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    Chroma,
    k=2,
    input_keys=["input"],
)

In [134]:
# Create the Few-Shot Prompt Template
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}"),
        ("ai", "{query}"),
    ]
)

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    input_variables=["input"],
)
print(few_shot_prompt.format(input="How many products are there?"))

Human: List all products with a stock quantity less than 7000.
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: List all products with a stock quantity less than 7000.
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
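The same example appears twice in the output above. One likely cause (an assumption, not verified in this notebook) is that re-running the selector cell adds the examples again to the same in-memory Chroma collection, so identical vectors come back as separate hits. Restarting the runtime before rebuilding the selector avoids this; as a defensive measure, the selected examples can also be de-duplicated before they reach the prompt. `dedupe_examples` is a hypothetical helper.

```python
from typing import Dict, List


def dedupe_examples(examples: List[Dict[str, str]], key: str = "input") -> List[Dict[str, str]]:
    """Drop examples whose `key` value was already seen, keeping the first occurrence."""
    seen = set()
    unique = []
    for ex in examples:
        if ex[key] not in seen:
            seen.add(ex[key])
            unique.append(ex)
    return unique


selected = [
    {"input": "List all products with a stock quantity less than 7000.",
     "query": "SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;"},
    {"input": "List all products with a stock quantity less than 7000.",
     "query": "SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;"},
]
print(len(dedupe_examples(selected)))  # 1
```

One way to apply it is to call `example_selector.select_examples({"input": question})` yourself and pass the de-duplicated list as `examples=` to a per-request `FewShotChatMessagePromptTemplate`; note that fewer than `k` examples may then remain.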


In [135]:
# Integrate into Main Prompt with dynamic few shots
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Use the following schema to answer the question.\n{schema}"),
        few_shot_prompt,
        ("human", "{question}"),
    ]
)

In [148]:
# Upgrade the SQL chain to use dynamic few-shot examples
few_shot_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | RunnablePassthrough.assign(input=itemgetter("question")) # Map 'question' to 'input' for few-shot prompt
    | final_prompt
    | llm
    | StrOutputParser()
)

# Final chain with dynamic few-shot examples
chain = (
    RunnablePassthrough.assign(query= few_shot_chain | RunnableLambda(clean_sql))
    .assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | llm
    | StrOutputParser()
)

result = chain.invoke({"question": "How many products are there ?"})
print(result)

There are 110 products.


**5. Dynamic Relevant Table Selection (for Large Databases)**

We add a pre-processing step: the LLM first looks at every table together with a description generated from the database (the `CREATE TABLE` statement and sample rows returned by `db.get_table_info`) and decides which tables are relevant to the question. We then fetch the schema for those tables only, so the SQL-generation prompt stays small.
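The model's table list should not be trusted blindly: it could return a misspelled or non-existent name, and `SQLDatabase.get_table_info` raises a `ValueError` for table names it does not know. A minimal guard, sketched here as a hypothetical `filter_selected_tables` helper, keeps only names that actually exist (matched case-insensitively) and falls back to all tables if nothing valid remains.

```python
from typing import List


def filter_selected_tables(selected: List[str], usable: List[str]) -> List[str]:
    """Keep LLM-selected tables that exist in the database, using the database's spelling."""
    lookup = {name.lower(): name for name in usable}
    valid = []
    for name in selected:
        canonical = lookup.get(name.strip().strip("`").lower())
        if canonical and canonical not in valid:
            valid.append(canonical)
    # Fall back to every usable table rather than passing an empty list
    return valid or list(usable)


usable_tables = ['customers', 'employees', 'offices', 'orderdetails',
                 'orders', 'payments', 'productlines', 'products']
print(filter_selected_tables(["Customers", "orders", "invoices"], usable_tables))
# ['customers', 'orders']
```

Inside `get_relevant_schema`, this would wrap the model output, e.g. `selected_tables = filter_selected_tables(selected_tables_obj.table_names, all_tables)`.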

In [149]:
from typing import List
from pydantic import BaseModel, Field
from langchain_core.output_parsers import PydanticOutputParser

In [150]:
# Define Output Structure for Table Selection
class TableSelection(BaseModel):
    table_names: List[str] = Field(description="List of relevant table names")

In [151]:
# Define Table Selection Prompt
table_selection_prompt = ChatPromptTemplate.from_template(
    """Return the names of ALL the SQL tables that MIGHT be relevant to the user question.
The available tables are:
{table_list}

Question: {question}

Return ONLY the list of table names. output must be valid JSON."""
)

In [152]:
# Create the Table Selection Chain
# Use the model's structured-output mode so the reply is parsed into a TableSelection object
table_chain = (
    table_selection_prompt
    | llm.with_structured_output(TableSelection)
)

In [157]:
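# Improved prompt: give the model full table descriptions instead of bare names
table_selection_prompt = ChatPromptTemplate.from_template(
    """Return the names of ALL the SQL tables that MIGHT be relevant to the user question.
You should use the table descriptions below to inform your decision.
{table_descriptions}

Question: {question}

Return ONLY the list of table names. output must be valid JSON."""
)

# Rebuild the chain: table_chain still references the earlier prompt, which expects {table_list}
table_chain = table_selection_prompt | llm.with_structured_output(TableSelection)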

In [158]:
# Select the relevant tables, then return the schema for those tables only
def get_relevant_schema(inputs):
    question = inputs["question"]
    all_tables = db.get_usable_table_names()

    # Fetch detailed schema for all usable tables
    table_descriptions_list = []
    for table_name in all_tables:
        table_descriptions_list.append(db.get_table_info(table_names=[table_name]))

    table_descriptions_str = "\n\n".join(table_descriptions_list)

    # Ask LLM which tables are needed using the detailed descriptions
    selected_tables_obj = table_chain.invoke({"table_descriptions": table_descriptions_str, "question": question})
    selected_tables = selected_tables_obj.table_names

    print(f"Dynamically selected tables: {selected_tables}")

    # Fetch schema ONLY for selected tables
    return db.get_table_info(table_names=selected_tables)
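As written, `get_relevant_schema` sends the full `CREATE TABLE` statement plus sample rows for every table to the selector, which on a genuinely large database can cost nearly as many tokens as sending the whole schema to the SQL step. One possible refinement, assuming you already have each table's column names (for example via SQLAlchemy's `inspect(engine).get_columns(...)`), is to give the selector only a one-line summary per table. `compact_table_descriptions` is a hypothetical helper, and the column lists below are a subset of columns that appear in the schema and example queries above.

```python
from typing import Dict, List


def compact_table_descriptions(columns_by_table: Dict[str, List[str]]) -> str:
    """One line per table, 'table: col1, col2, ...' (much shorter than full DDL)."""
    return "\n".join(
        f"{table}: {', '.join(cols)}" for table, cols in sorted(columns_by_table.items())
    )


columns = {
    "orders": ["orderNumber", "orderDate", "status", "customerNumber"],
    "customers": ["customerNumber", "customerName", "country", "creditLimit"],
}
print(compact_table_descriptions(columns))
```

The resulting string would be passed as `table_descriptions` to `table_chain`; the full schema is still fetched afterwards, but only for the selected tables.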

In [169]:
# SQL-generation chain with dynamic table selection
dynamic_table_chain = (
    RunnablePassthrough.assign(schema=get_relevant_schema)  # uses dynamic table selection
    | RunnablePassthrough.assign(input=itemgetter("question"))
    | final_prompt
    | llm
    | StrOutputParser()
)

# Test it
question = "give me details of customer and their order count"
print(dynamic_table_chain.invoke({"question": question}))

Dynamically selected tables: ['customers', 'orders']
```sql
SELECT
  c.customerNumber,
  c.customerName,
  c.contactLastName,
  c.contactFirstName,
  COUNT(o.orderNumber) AS orderCount
FROM customers AS c
LEFT JOIN orders AS o
  ON c.customerNumber = o.customerNumber
GROUP BY
  c.customerNumber,
  c.customerName,
  c.contactLastName,
  c.contactFirstName
ORDER BY
  orderCount DESC,
  c.customerName;
```


In [170]:
# Final chain with dynamic table selection
chain = (
    RunnablePassthrough.assign(query= dynamic_table_chain | RunnableLambda(clean_sql))
    .assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | llm
    | StrOutputParser()
)

result = chain.invoke({"question": "give me details of customer and their order count"})
print(result)

Dynamically selected tables: ['customers', 'orders']
Euro+ Shopping Channel has the highest number of orders with 26, followed by Mini Gifts Distributors Ltd. with 17 orders. Several customers, including Australian Collectors, Co., Danish Wholesale Imports, Down Under Souveniers, Inc, Dragon Souveniers, Ltd., and Reims Collectables, each have 5 orders. There are also customers like Anna's Decorations, Ltd, Baane Mini Imports, and Blauer See Auto, Co. with 4 orders each, and many others with 3 or 2 orders. Additionally, 21 customers, such as American Souvenirs Inc, ANG Resellers, and Anton Designs, Ltd., currently have no orders.


In [171]:
chain.invoke({"question": "Can you list their names?"})

Dynamically selected tables: ['customers', 'employees', 'products', 'productlines']


'Mary Patterson and Jeff Firrelli are the names of the employees.'

**6. Adding Memory for Follow-up Questions**

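As the previous cell shows, the chain has no memory: the follow-up "Can you list their names?" is answered without knowing who "they" are. To handle a follow-up such as "What about in France?" after asking about customers in the USA, we need conversational memory. We use *RunnableWithMessageHistory*, which loads the session's message history into the prompt before each call and stores the new turn afterwards.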
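Conceptually, the wrapper keeps one message list per `session_id`, injects that list into the prompt's `chat_history` placeholder, and appends the new question and the model's reply after each call. A plain-Python sketch of that bookkeeping, with a hypothetical `fake_model` standing in for the LLM chain and `(role, content)` tuples standing in for message objects:

```python
from typing import Callable, Dict, List, Tuple

Message = Tuple[str, str]  # (role, content)

history_store: Dict[str, List[Message]] = {}


def get_history(session_id: str) -> List[Message]:
    # Same idea as get_session_history: create the list on first use
    return history_store.setdefault(session_id, [])


def run_with_history(model: Callable[[List[Message]], str], question: str, session_id: str) -> str:
    history = get_history(session_id)
    # Prior turns first (the chat_history placeholder), then the new question
    messages = history + [("human", question)]
    reply = model(messages)
    # Record the new turn so the next call in this session can see it
    history.extend([("human", question), ("ai", reply)])
    return reply


def fake_model(messages: List[Message]) -> str:
    # Stand-in for the LLM: report how much earlier context it received
    return f"(model saw {len(messages) - 1} earlier messages)"


print(run_with_history(fake_model, "How many customers are in USA?", "user_123"))
print(run_with_history(fake_model, "What about in France?", "user_123"))
print(run_with_history(fake_model, "What about in France?", "other_user"))
```

In `full_turn`, the AI turn stored in history is the generated SQL (the output of `chain_with_memory`), not the natural-language answer; as the output below shows, that was enough context for the France follow-up.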

In [172]:
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

# 1. Define History Store
store = {}

def get_session_history(session_id: str):
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]

In [173]:
# 2. Update Prompt to include History
memory_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", final_prompt.messages[0].prompt.template),
        ("placeholder", "{chat_history}"),
        ("human", "{question}"),
    ]
)

In [174]:
# 3. Create the Chain with Memory
chain_with_memory = (
    RunnablePassthrough.assign(schema=get_relevant_schema)
    | memory_prompt
    | llm
    | StrOutputParser()
)

In [175]:
# 4. Wrap with History Runnable
final_conversational_chain = RunnableWithMessageHistory(
    chain_with_memory,
    get_session_history,
    input_messages_key="question",
    history_messages_key="chat_history",
)

def full_turn(question, session_id):
    # 1. Generate SQL
    generated_sql = final_conversational_chain.invoke(
        {"question": question},
        config={"configurable": {"session_id": session_id}}
    )
    clean_query = clean_sql(generated_sql)
    print(f"SQL: {clean_query}")

    # 2. Execute SQL
    try:
        sql_result = execute_query.invoke(clean_query)
    except Exception as e:
        return f"Error executing SQL: {e}"

    answer_chain = answer_prompt | llm | StrOutputParser()
    final_answer = answer_chain.invoke({"question": question, "query": clean_query, "result": sql_result})

    return final_answer

# Usage
session_id = "user_123"
print(full_turn("How many customers are in USA?", session_id))
print(full_turn("What about in France?", session_id)) # Follow-up

Dynamically selected tables: ['customers']
SQL: SELECT COUNT(customerNumber) FROM customers WHERE country = 'USA'
There are 36 customers in the USA.
Dynamically selected tables: ['customers', 'employees', 'offices', 'orders', 'payments']
SQL: SELECT COUNT(customerNumber) FROM customers WHERE country = 'France'
There are 12 customers in France.
