<a href="https://colab.research.google.com/github/vatsalagarwal09/GenAI/blob/main/Build_a_Text2SQL_AI_Workflow_with_LangChain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Build a Text2SQL AI Workflow with LangChain




![](https://i.imgur.com/7WrLz9I.png)



## Install OpenAI and LangChain dependencies

In [1]:
!pip install langchain==0.3.14
!pip install langchain-openai==0.3.0
!pip install langchain-community==0.3.14



In [2]:
!apt-get install sqlite3 -y

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
sqlite3 is already the newest version (3.37.2-2ubuntu0.5).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


## Enter Gemini API Key

In [3]:
import google.generativeai as genai

In [4]:
from getpass import getpass

gemini_key = getpass("Enter Gemini Key")

Enter Gemini Key··········


## Setup Environment Variables

In [5]:
import os

os.environ["GOOGLE_API_KEY"] = gemini_key
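When the notebook is re-run, it can be convenient to reuse a key that is already present in the environment and prompt only when it is missing. The helper below is a minimal sketch of that idea; `resolve_api_key` and its `prompt_fn` parameter are hypothetical names introduced for illustration, not part of any library.

```python
import os
from getpass import getpass


def resolve_api_key(env_name="GOOGLE_API_KEY", prompt_fn=getpass):
    """Return the key from the environment, prompting only when it is missing."""
    key = os.environ.get(env_name, "").strip()
    if not key:
        # Nothing stored yet: ask the user and remember the answer for later cells.
        key = prompt_fn(f"Enter value for {env_name}: ").strip()
        if not key:
            raise ValueError(f"{env_name} must not be empty")
        os.environ[env_name] = key
    return key
```

Calling `resolve_api_key()` in place of the `getpass` cell would skip the prompt whenever `GOOGLE_API_KEY` is already set.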

## Get SQL DB Script

In [6]:
# If the download fails, get the file manually from https://drive.google.com/file/d/16mZm3C7xKpPqp_86e64uzduLpM5mPUdq/view?usp=sharing and upload it
!gdown 16mZm3C7xKpPqp_86e64uzduLpM5mPUdq

Downloading...
From: https://drive.google.com/uc?id=16mZm3C7xKpPqp_86e64uzduLpM5mPUdq
To: /content/comicdb_create_script.sql
100% 14.5k/14.5k [00:00<00:00, 55.6MB/s]


## Create Comic Store Database

In [7]:
!sqlite3 --version

3.37.2 2022-01-06 13:25:41 872ba256cbf61d9290b571c0e6d82a20c224ca3ad82971edc46b29818d5dalt1


In [8]:
!sqlite3 ComicStore.db ".read ./comicdb_create_script.sql"

In [9]:
!sqlite3 ComicStore.db "SELECT name FROM sqlite_master WHERE type='table';"

Branch
Employee
Publisher
Comic
Inventory
Customer
Sale
SaleTransactions
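The same table listing can be obtained from Python with the standard-library `sqlite3` module. The sketch below runs against an in-memory database with two stand-in tables so that it is self-contained; pointing `sqlite3.connect` at `ComicStore.db` instead should list the tables created above. The helper name `list_tables` is hypothetical.

```python
import sqlite3


def list_tables(conn):
    """Return the names of all tables in the database, sorted alphabetically."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
    ).fetchall()
    return [name for (name,) in rows]


# Demo on an in-memory database with two stand-in tables (not the full schema).
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Publisher (PublisherId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Comic (ComicId INTEGER PRIMARY KEY, Title TEXT);
    """
)
tables = list_tables(conn)
print(tables)
```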


In [10]:
%%bash
sqlite3 ComicStore.db <<EOF
.headers on
.mode column
SELECT * FROM Comic LIMIT 10;
EOF

ComicId  Title                            PublisherId  Genre            Price  ReleaseDate
-------  -------------------------------  -----------  ---------------  -----  -----------
1        Spider-Man: Homecoming           1            Superhero        19.99  2017-07-07 
2        Batman: Year One                 2            Superhero        14.99  1987-02-01 
3        Hellboy: Seed of Destruction     3            Supernatural     24.99  1994-10-01 
4        Saga Volume 1                    4            Fantasy          12.99  2012-03-14 
5        Transformers: All Hail Megatron  5            Science Fiction  25.99  2008-09-01 
6        X-Men: Days of Future Past       1            Superhero        18.99  1981-01-01 
7        The Killing Joke                 2            Superhero        14.99  1988-03-29 
8        Sin City: The Hard Goodbye       4            Noir             22.99  1991-06-01 
9        Usagi Yojimbo Volume 1           5            Adventure        20.99  1987-09-01 

In [10]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///ComicStore.db")
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x7e8b8a9e7d10>

In [12]:
print(db.dialect)
print(db.get_usable_table_names())

sqlite
['Branch', 'Comic', 'Customer', 'Employee', 'Inventory', 'Publisher', 'Sale', 'SaleTransactions']


## Overview of the Comic Store Database

In this project, we use a **Comic Store Database** to demonstrate a Text2SQL workflow. The database schema includes the following entities and relationships:

1. **Branch**: Stores details about comic store branches, including their location and contact information.

2. **Publisher**: Contains information about publishers, such as their name, country, and the year they were established.

3. **Comic**: Represents the comics, including their title, genre, price, release date, and associated publisher.

4. **Customer**: Tracks customer details, including their contact information and location.

5. **Employee**: Holds data about store employees, including their branch, title, and hire date.

6. **Inventory**: Manages the stock of comics available at different branches.

7. **Sale**: Records sales transactions, including the employee and customer involved, as well as the total amount and sale date.

8. **SaleTransactions**: Tracks individual items within a sale, including the quantity and price of each comic sold.

### Relationships:
- A **Publisher** publishes multiple **Comics**.
- A **Branch** stocks multiple **Comics** through the **Inventory** table.
- A **Customer** makes **Sales**, which are processed by **Employees**.
- Each **Sale** contains multiple items recorded in **SaleTransactions**.

This database schema is well-suited for queries related to inventory management, sales analysis, customer interactions, and employee performance in the context of a comic store business.
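To make the relationships concrete, here is a small sketch that follows the path Publisher → Comic → SaleTransactions with two joins. It uses a simplified in-memory schema: the table names and the `ComicId`, `Title`, `PublisherId`, `Price`, and `Quantity` columns mirror the notebook, while `Publisher.Name`, `SaleId`, and all sample rows are assumptions made for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Publisher (PublisherId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Comic (
        ComicId INTEGER PRIMARY KEY,
        Title TEXT,
        PublisherId INTEGER REFERENCES Publisher(PublisherId),
        Price REAL
    );
    CREATE TABLE SaleTransactions (
        SaleId INTEGER,
        ComicId INTEGER REFERENCES Comic(ComicId),
        Quantity INTEGER
    );
    -- Made-up sample rows, not taken from the real database
    INSERT INTO Publisher VALUES (1, 'Publisher A'), (2, 'Publisher B');
    INSERT INTO Comic VALUES (1, 'Comic One', 1, 10.0), (2, 'Comic Two', 2, 20.0);
    INSERT INTO SaleTransactions VALUES (100, 1, 2), (100, 2, 1), (101, 1, 1);
    """
)

# Units sold per publisher: each sale line joins to its comic, then to the publisher.
units_per_publisher = conn.execute(
    """
    SELECT p.Name, SUM(st.Quantity) AS UnitsSold
    FROM SaleTransactions AS st
    JOIN Comic AS c ON c.ComicId = st.ComicId
    JOIN Publisher AS p ON p.PublisherId = c.PublisherId
    GROUP BY p.Name
    ORDER BY UnitsSold DESC;
    """
).fetchall()
print(units_per_publisher)
```

This is the kind of multi-table query the Text2SQL workflow is expected to generate from a plain-language question.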


![](https://i.imgur.com/YzNCLpV.png)

In [None]:
db.run("SELECT * FROM Comic LIMIT 10;", include_columns=True)

In [None]:
db.run("SELECT * FROM Employee LIMIT 10;", include_columns=True)

In [None]:
db.run("SELECT * FROM Sale LIMIT 10;", include_columns=True)

In [None]:
db.run("SELECT * FROM SaleTransactions LIMIT 10;", include_columns=True)

In [None]:
print(db.get_table_info(table_names=['Comic', 'Sale']))

## Build Text2SQL Components for AI Workflow

In [11]:
# This prompt is adapted from LangChain's built-in SQL prompts:
# https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/sql_database/prompt.py
# You may need to customize it for the LLM you are using: output formats vary
# between models, so state the expected format explicitly in the prompt.

from langchain_core.prompts.prompt import PromptTemplate

PROMPT_SUFFIX = """Only use the following tables:
{table_info}

Question: {input}"""

_sqlite_prompt = """You are a SQLite expert.
Given an input question, first create a syntactically correct SQLite query to run,
then look at the results of the query and return the answer to the input question.

Unless the user specifies in the question a specific number of examples to obtain,
query for at most {top_k} results using the LIMIT clause as per SQLite.

You can order the results to return the most informative data in the database.
Never query for all columns from a table.

You must query only the columns that are needed to answer the question.
Wrap each column name in double quotes (") to denote them as delimited identifiers.

Pay attention to use only the column names you can see in the tables below.
Be careful to not query for columns that do not exist.
Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".
Pay attention to use table JOINS as necessary if you are adding relevant fields from different tables.

Generate the output in the exact following format:

SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

The SQLQuery field above should have the correct SQLite query as plain text without any formatting or code blocks.
Do not include sql or similar markers.
Do not try to explain the query, just provide the query as-is, like this: SELECT ...
"""

SQLITE_PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=_sqlite_prompt + PROMPT_SUFFIX,
)
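The three placeholders (`{input}`, `{table_info}`, `{top_k}`) are filled in before the text is sent to the model. The sketch below imitates that substitution with plain `str.format` on a shortened stand-in template; it only illustrates the idea and is not how LangChain implements `PromptTemplate`. The table definition is a made-up example.

```python
# Shortened stand-in for the notebook's prompt text
template = (
    "You are a SQLite expert.\n"
    "Query for at most {top_k} results using the LIMIT clause.\n\n"
    "Only use the following tables:\n{table_info}\n\n"
    "Question: {input}"
)

# Each keyword argument replaces the placeholder of the same name.
filled = template.format(
    top_k=5,
    table_info="CREATE TABLE Comic (ComicId INTEGER, Title TEXT)",
    input="Top 5 most popular comics",
)
print(filled)
```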

In [12]:
SQLITE_PROMPT

PromptTemplate(input_variables=['input', 'table_info', 'top_k'], input_types={}, partial_variables={}, template='You are a SQLite expert.\nGiven an input question, first create a syntactically correct SQLite query to run,\nthen look at the results of the query and return the answer to the input question.\n\nUnless the user specifies in the question a specific number of examples to obtain,\nquery for at most {top_k} results using the LIMIT clause as per SQLite.\n\nYou can order the results to return the most informative data in the database.\nNever query for all columns from a table.\n\nYou must query only the columns that are needed to answer the question.\nWrap each column name in double quotes (") to denote them as delimited identifiers.\n\nPay attention to use only the column names you can see in the tables below.\nBe careful to not query for columns that do not exist.\nAlso, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the cur

In [13]:
import google.generativeai as genai

genai.configure(api_key=gemini_key)

In [14]:
!pip install langchain_google_genai

Collecting langchain-core<0.4.0,>=0.3.68 (from langchain_google_genai)
  Using cached langchain_core-0.3.74-py3-none-any.whl.metadata (5.8 kB)
Collecting langsmith>=0.3.45 (from langchain-core<0.4.0,>=0.3.68->langchain_google_genai)
  Using cached langsmith-0.4.16-py3-none-any.whl.metadata (14 kB)
Using cached langchain_core-0.3.74-py3-none-any.whl (443 kB)
Using cached langsmith-0.4.16-py3-none-any.whl (375 kB)
Installing collected packages: langsmith, langchain-core
  Attempting uninstall: langsmith
    Found existing installation: langsmith 0.2.11
    Uninstalling langsmith-0.2.11:
      Successfully uninstalled langsmith-0.2.11
  Attempting uninstall: langchain-core
    Found existing installation: langchain-core 0.3.63
    Uninstalling langchain-core-0.3.63:
      Successfully uninstalled langchain-core-0.3.63
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conf

In [15]:
# Implementation details in langchain source code:
# https://api.python.langchain.com/en/latest/_modules/langchain/chains/sql_database/query.html#create_sql_query_chain
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain_google_genai import ChatGoogleGenerativeAI

# chatgpt = ChatOpenAI(model="gpt-4o", temperature=0)
gemini = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0)
text2sql_chain = create_sql_query_chain(llm=gemini,
                                        db=db,
                                        prompt=SQLITE_PROMPT,
                                        k=5)
text2sql_chain

RunnableAssign(mapper={
  input: RunnableLambda(...),
  table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for k, v in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], input_types={}, partial_variables={'top_k': '5'}, template='You are a SQLite expert.\nGiven an input question, first create a syntactically correct SQLite query to run,\nthen look at the results of the query and return the answer to the input question.\n\nUnless the user specifies in the question a specific number of examples to obtain,\nquery for at most {top_k} results using the LIMIT clause as per SQLite.\n\nYou can order the results to return the most informative data in the database.\nNever query for all columns from a table.\n\nYou must query only the columns that are needed to answer the question.\nWrap each column name in double quotes (") to denote them as delimited identifiers.\n\nPay attention to use only the column names you 

In [16]:
text2sql_chain.get_prompts()[0].pretty_print()

You are a SQLite expert.
Given an input question, first create a syntactically correct SQLite query to run,
then look at the results of the query and return the answer to the input question.

Unless the user specifies in the question a specific number of examples to obtain,
query for at most 5 results using the LIMIT clause as per SQLite.

You can order the results to return the most informative data in the database.
Never query for all columns from a table.

You must query only the columns that are needed to answer the question.
Wrap each column name in double quotes (") to denote them as delimited identifiers.

Pay attention to use only the column names you can see in the tables below.
Be careful to not query for columns that do not exist.
Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".
Pay attention to use table JOINS as necessary if you are adding relevant fields from differe

In [17]:
response = text2sql_chain.invoke({"question": "Top 5 most popular comics"})
print(response)

SQLQuery: SELECT T1.Title, SUM(T2.Quantity) AS TotalQuantitySold FROM Comic AS T1 INNER JOIN SaleTransactions AS T2 ON T1.ComicId = T2.ComicId GROUP BY T1.ComicId, T1.Title ORDER BY TotalQuantitySold DESC LIMIT 5


In [None]:
db.run(response)

"[('Wolverine: Old Man Logan', 3), ('V for Vendetta', 2), ('Usagi Yojimbo Volume 1', 2), ('Transformers: All Hail Megatron', 2), ('The Killing Joke', 2)]"

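Depending on the model, the generated text may arrive with a leading `SQLQuery:` label, a trailing `SQLResult:` section, or Markdown code fences, none of which SQLite will accept. A minimal cleanup sketch is shown below; `extract_sql` is a hypothetical helper, and the patterns it strips are assumptions about typical model output rather than anything guaranteed by LangChain.

```python
import re

# Markdown code-fence marker (three backticks), built programmatically
FENCE = "`" * 3


def extract_sql(llm_output):
    """Best-effort cleanup of a model reply so that only the SQL statement remains."""
    text = llm_output.strip()
    # Drop code-fence markers, with or without a language tag
    text = re.sub(re.escape(FENCE) + r"(?:sqlite|sql)?", "", text, flags=re.IGNORECASE)
    # Keep only what follows the last 'SQLQuery:' label, if one is present
    text = text.split("SQLQuery:")[-1]
    # Discard anything from 'SQLResult:' onward
    text = text.split("SQLResult:")[0]
    return text.strip()


print(extract_sql("SQLQuery: SELECT COUNT(*) FROM Customer"))
```

A helper like this could be applied to the chain output before passing it to `db.run`.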
## Create SQL Query Write & Execute Workflow Chains

In [18]:
from langchain_community.tools import QuerySQLDatabaseTool

execute_query_tool = QuerySQLDatabaseTool(db=db)
execute_query_tool

QuerySQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7e8b8a9e7d10>)

In [39]:
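from langchain_core.runnables import RunnablePassthrough, RunnableParallel

# Chain that turns a natural-language question into a SQL query
query_write_chain = create_sql_query_chain(llm=gemini,
                                           db=db,
                                           prompt=SQLITE_PROMPT,
                                           k=10)

# Simpler alternative that returns only the query result:
# query_execute_chain = (query_write_chain
#                             |
#                        execute_query_tool)

# Write the query once, then return both the executed result and the query itself
query_execute_chain = (
    RunnablePassthrough.assign(query=query_write_chain)
    |
    RunnableParallel(
        result=execute_query_tool,   # executes the SQL stored under the 'query' key
        query=lambda x: x['query'],  # passes the generated query through to the next step
    )
)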


query_execute_chain.invoke({"question": "Top 5 most popular comics"})

{'result': "[('Wolverine: Old Man Logan', 3), ('V for Vendetta', 2), ('Usagi Yojimbo Volume 1', 2), ('Transformers: All Hail Megatron', 2), ('The Killing Joke', 2)]",
 'query': 'SELECT\n  T2."Title",\n  SUM(T1."Quantity") AS "TotalQuantitySold"\nFROM SaleTransactions AS T1\nINNER JOIN Comic AS T2\n  ON T1."ComicId" = T2."ComicId"\nGROUP BY\n  T2."Title"\nORDER BY\n  "TotalQuantitySold" DESC\nLIMIT 5;'}
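The data flow of this chain can be mimicked with plain dictionaries and functions. The sketch below is a simplified emulation for intuition only, not LangChain's implementation: `assign` keeps the incoming keys and adds new ones, while `parallel` builds a fresh dictionary from the step outputs. All function names and the fixed SQL string are hypothetical stand-ins.

```python
def write_query(state):
    """Stand-in for query_write_chain: returns a fixed SQL string."""
    return "SELECT COUNT(*) FROM Customer"


def execute_query(state):
    """Stand-in for the query tool: pretends to run state['query']."""
    return f"executed: {state['query']}"


def assign(state, **steps):
    """Keep the existing keys and add one new key per step."""
    return {**state, **{key: fn(state) for key, fn in steps.items()}}


def parallel(state, **steps):
    """Build a new dict that contains only the step outputs."""
    return {key: fn(state) for key, fn in steps.items()}


state = {"question": "Total number of customers"}
state = assign(state, query=write_query)  # now holds 'question' and 'query'
output = parallel(state, result=execute_query, query=lambda s: s["query"])
print(output)
```

Because `parallel` returns only `result` and `query`, the original `question` key is dropped at that point, which matches the two-key dictionary shown in the output above.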

In [20]:
query_execute_chain.invoke({"question": "Top 5 customers with most comics purchased"})

"[('Tony', 'Stark', 8), ('Sarah', 'Connor', 7), ('Robert', 'Taylor', 6), ('Diana', 'Prince', 6), ('Clark', 'Kent', 6)]"

In [None]:
query_execute_chain.invoke({"question": "Top 5 customers with most money spent"})

"[('Tony', 'Stark', 164.94), ('Bruce', 'Wayne', 139.94), ('Sarah', 'Connor', 124.96), ('Clark', 'Kent', 114.96), ('Natasha', 'Romanoff', 111.96000000000001)]"

In [None]:
query_execute_chain.invoke({"question": "Top 3 salesman with highest revenue"})

"[('John', 'Doe', 255.89), ('Alice', 'Brown', 234.94), ('Jane', 'Smith', 234.88)]"

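The query results come back as a single string that looks like a Python list of tuples, and floating-point sums can show artifacts such as `111.96000000000001`. If structured rows are needed, a sketch like the following can parse the string with `ast.literal_eval` and format the amounts. The helper `parse_rows` is hypothetical, and this approach assumes every value in the string is a plain Python literal (strings, numbers, `None`).

```python
import ast


def parse_rows(result):
    """Convert the string form of a result set into a list of tuples."""
    rows = ast.literal_eval(result)
    if not isinstance(rows, list):
        raise ValueError("expected a list of rows")
    return rows


# Two rows copied from the 'most money spent' output above
raw = "[('Tony', 'Stark', 164.94), ('Natasha', 'Romanoff', 111.96000000000001)]"
for first, last, total in parse_rows(raw):
    # Format to two decimals to hide floating-point noise
    print(f"{first} {last}: {total:.2f}")
```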
## Create Text2SQL AI Workflow Chain

In [55]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result,
       create a helpful answer to the user question.

       When generating the final answer in markdown from the results,
       if there are special characters in the text, such as the dollar symbol,
       ensure they are escaped properly for correct rendering e.g $25.5 should become \\$25.5

       Question: {question}
       SQL Query: {query}
       SQL Result: {result}
       Answer:
    """
)

text2sql_chain = (
    RunnablePassthrough.assign(result=query_execute_chain)
    |
    # This step pulls the 'query' and 'result' out of the nested 'result' dictionary
    RunnablePassthrough.assign(
        query=lambda x: x['result']['query'],
        result=lambda x: x['result']['result']
    )
        |
    answer_prompt
        |
    gemini
        |
    StrOutputParser()
)

# The earlier version below invoked query_write_chain twice (once directly and once
# inside query_execute_chain); the chain above calls it only once.
# text2sql_chain = (
#     RunnablePassthrough.assign(query=query_write_chain)
#         |
#     RunnablePassthrough.assign(result=query_execute_chain)
#         |
#     answer_prompt
#         |
#     gemini
#         |
#     StrOutputParser()
# )
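The prompt asks the model to escape dollar signs so that Markdown renderers do not treat them as math delimiters. Since a model may not always comply, the same escaping can be applied deterministically to the final answer. The function below is a minimal sketch; `escape_dollars` is a hypothetical helper name.

```python
import re


def escape_dollars(text):
    """Prefix every '$' that is not already escaped with a backslash."""
    # (?<!\\) skips dollar signs that already follow a backslash
    return re.sub(r"(?<!\\)\$", r"\\$", text)


print(escape_dollars("Tony Stark spent $164.94"))
```

Wrapping the chain output as `display(Markdown(escape_dollars(response)))` would apply the fix regardless of how the model formatted its answer.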



## Test the Text2SQL AI Workflow

In [54]:
from IPython.display import display, Markdown
response = text2sql_chain.invoke({"question": "Total number of customers"})
response
# display(Markdown(response))

'The total number of customers is 20.'

In [23]:
response = text2sql_chain.invoke({"question": "What are the Top 10 most popular comics"})
display(Markdown(response))

Here are the Top 10 most popular comics:

1.  Wolverine: Old Man Logan - 3 units sold
2.  V for Vendetta - 2 units sold
3.  Usagi Yojimbo Volume 1 - 2 units sold
4.  Transformers: All Hail Megatron - 2 units sold
5.  The Killing Joke - 2 units sold
6.  The Boys Volume 1 - 2 units sold
7.  Superman: Red Son - 2 units sold
8.  Punisher: Welcome Back, Frank - 2 units sold
9.  Preacher Volume 1 - 2 units sold
10. Ms. Marvel Volume 1 - 2 units sold

In [24]:
response = text2sql_chain.invoke({"question": "Top 5 customers with most comics purchased"})
display(Markdown(response))

Here are the top 5 customers with the most comics purchased:

1.  **Tony Stark**: 8 comics
2.  **Sarah Connor**: 7 comics
3.  **Robert Taylor**: 6 comics
4.  **Diana Prince**: 6 comics
5.  **Clark Kent**: 6 comics

In [None]:
response = text2sql_chain.invoke({"question": "Which are the top 5 customers with most money spent"})
display(Markdown(response))

Here are the top 5 customers who have spent the most money:

1. **Tony Stark**: \$164.94
2. **Bruce Wayne**: \$139.94
3. **Sarah Connor**: \$124.96
4. **Clark Kent**: \$114.96
5. **Natasha Romanoff**: \$111.96

These customers have made significant purchases, contributing to their high total spending.

In [None]:
response = text2sql_chain.invoke({"question": "Which are the top 3 salesman with highest revenue"})
display(Markdown(response))

The top 3 salespeople with the highest revenue are:

1. **John Doe** with a total revenue of \$255.89
2. **Alice Brown** with a total revenue of \$234.94
3. **Jane Smith** with a total revenue of \$234.88

These individuals have achieved the highest sales figures in the company.