# Many tables

One of the main pieces of information we need to include in our prompt is the schemas of the relevant tables. 

When there are a large number of tables, we can't fit all of their schemas in a single prompt.

What we can do in such cases is first extract the names of the tables related to the user input, and then include only their schemas.

One easy and reliable way to do this is to use tool calling.

Below, we show how to use this feature to obtain output that conforms to a desired format (in this case, a list of table names). We use the chat model's .bind_tools method to bind a tool defined as a Pydantic model, and feed the result into an output parser that reconstructs the object from the model's response.
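Conceptually, the parsing step is small: the model's reply carries a list of tool calls, each with a tool name and a dict of arguments, and the parser rebuilds an instance of the matching tool class from those arguments. The sketch below imitates that step with a plain dataclass so it runs without LangChain or an API key. The tool-call dicts are hand-written stand-ins loosely shaped like the name/args entries of a LangChain AIMessage's tool_calls, and parse_tool_calls is a hypothetical helper, not PydanticToolsParser's actual implementation.

```python
from dataclasses import dataclass


@dataclass
class Table:
    """Stand-in for the Pydantic Table tool defined later in the notebook."""

    name: str


def parse_tool_calls(tool_calls: list[dict], tools: dict[str, type]) -> list:
    """Rebuild one object per tool call whose name matches a known tool class."""
    parsed = []
    for call in tool_calls:
        cls = tools.get(call["name"])
        if cls is None:
            continue  # this sketch simply skips tool names it does not know
        parsed.append(cls(**call["args"]))
    return parsed


# Hand-written example of tool calls a model might return for a music question
fake_tool_calls = [
    {"name": "Table", "args": {"name": "Artist"}},
    {"name": "Table", "args": {"name": "Album"}},
    {"name": "Unknown", "args": {}},
]
print(parse_tool_calls(fake_tool_calls, {"Table": Table}))
# [Table(name='Artist'), Table(name='Album')]
```

Skipping unrecognized tool names is a choice made for this sketch; the real parser's handling of unexpected calls may differ.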

In [1]:
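import os
from dotenv import load_dotenv

# Load API keys (e.g. OPENAI_API_KEY, LANGCHAIN_API_KEY) from a local .env file
load_dotenv()

# Enable LangSmith tracing; this assumes LANGCHAIN_API_KEY is set in .env or the environment
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Q&A_over_SQL_data"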

# SQL DB Setup
Connect to the Chinook sample database (Chinook.db, a local SQLite file) that we will query.



In [2]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

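Note that db.run returns the rows as a single string (a printed list of tuples) rather than as Python objects, which makes the result easy to paste into a prompt. A rough stdlib-only approximation of that behavior is sketched below, assuming a toy in-memory Artist table holding the first three rows shown above. run_as_text is a hypothetical helper, and the real SQLDatabase.run does more (for example, it can truncate long string values). This sketch returns an empty string when no rows match, which is consistent with the '' result of In [22] further down.

```python
import sqlite3


def run_as_text(conn: sqlite3.Connection, query: str) -> str:
    """Execute a query and return its rows as a string, or "" if there are none."""
    rows = conn.execute(query).fetchall()
    return str(rows) if rows else ""


# Toy stand-in for Chinook's Artist table
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.executemany(
    "INSERT INTO Artist VALUES (?, ?)",
    [(1, "AC/DC"), (2, "Accept"), (3, "Aerosmith")],
)

print(run_as_text(conn, "SELECT * FROM Artist ORDER BY ArtistId LIMIT 2;"))
# [(1, 'AC/DC'), (2, 'Accept')]
```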
# Hook the DB up to the LLM

In [3]:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")

In [4]:
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)
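# Let the model respond with calls to the Table tool, typically one per relevant table
llm_with_tools = llm.bind_tools([Table])
# Turn the tool calls in the model's reply back into Table objects
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser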

table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Artist'),
 Table(name='Album'),
 Table(name='Genre'),
 Table(name='Track')]
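The chain only names tables; the schemas still have to be fetched for those names alone, which is what keeps the prompt small. In this notebook that job is left to create_sql_query_chain (through its table_names_to_use input), but the underlying idea can be sketched with Python's built-in sqlite3 module. The three-table in-memory schema and the schemas_for helper are hypothetical stand-ins, not part of LangChain.

```python
import sqlite3


def schemas_for(conn: sqlite3.Connection, table_names: list[str]) -> str:
    """Return the CREATE TABLE statements for the named tables only."""
    if not table_names:
        return ""
    placeholders = ", ".join("?" for _ in table_names)
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type = 'table' AND name IN ({placeholders}) ORDER BY name",
        table_names,
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)


# Tiny in-memory database standing in for a schema with many tables
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total REAL)")

# Only the tables judged relevant to the question go into the prompt
print(schemas_for(conn, ["Artist", "Album"]))
```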

In [5]:
system = """Return the names of any SQL tables that are relevant to the user question.
The tables are:

Music
Business
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

category_chain = prompt | llm_with_tools | output_parser
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music')]

In [6]:
from typing import List


def get_tables(categories: List[Table]) -> List[str]:
    """Expand the categories chosen by the model into concrete table names."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
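Choosing between two coarse categories is arguably a simpler task for the model than picking individual tables, and the category-to-table expansion in get_tables is plain Python. The same mapping could also live in a dictionary, which makes adding a category a one-line change and lets us drop duplicates if the model happens to return a category twice. A minimal sketch, using a dataclass in place of the Pydantic Table model so it runs on its own; CATEGORY_TABLES and get_tables_from_map are hypothetical names:

```python
from dataclasses import dataclass


@dataclass
class Table:
    """Stand-in for the Pydantic Table model above."""

    name: str


# Category -> tables, mirroring the grouping used in get_tables
CATEGORY_TABLES = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def get_tables_from_map(categories: list[Table]) -> list[str]:
    tables: list[str] = []
    for category in categories:
        # Unknown categories contribute no tables
        for table in CATEGORY_TABLES.get(category.name, []):
            if table not in tables:  # keep first-seen order, skip duplicates
                tables.append(table)
    return tables


print(get_tables_from_map([Table(name="Music")]))
```

Because LangChain coerces plain functions into runnables, a function like this could be piped after category_chain in the same way as get_tables.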

In [7]:
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough

query_chain = create_sql_query_chain(llm, db)


In [8]:
query_chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [9]:
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain
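Here the dict literal {"input": itemgetter("question")} renames the incoming question key to the input key that table_chain expects, and RunnablePassthrough.assign keeps the original input dict while adding a table_names_to_use key computed from it. create_sql_query_chain then uses table_names_to_use to decide which table schemas go into its prompt. The plain-Python sketch below mimics only this dict plumbing; fake_table_chain, fake_query_chain and assign are hypothetical stubs, not the real chains or LangChain's implementation.

```python
from operator import itemgetter


def fake_table_chain(inputs: dict) -> list[str]:
    """Stub for table_chain: reads the "input" key and returns table names."""
    return ["Album", "Artist", "Genre", "Track"] if inputs["input"] else []


def fake_query_chain(inputs: dict) -> str:
    """Stub for query_chain: reports the question and tables it received."""
    return f"question={inputs['question']!r}, tables={inputs['table_names_to_use']}"


def assign(inputs: dict, **computed) -> dict:
    """Mimic RunnablePassthrough.assign: keep every input key, add computed ones."""
    return {**inputs, **{key: fn(inputs) for key, fn in computed.items()}}


get_question = itemgetter("question")


def renamed_table_chain(inputs: dict) -> list[str]:
    """Equivalent of {"input": itemgetter("question")} | table_chain."""
    return fake_table_chain({"input": get_question(inputs)})


state = assign(
    {"question": "What are all the genres of Alanis Morissette songs"},
    table_names_to_use=renamed_table_chain,
)
print(fake_query_chain(state))
```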

In [10]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
- Giving output of query in backticks like ```query```


If there are any of the above mistakes, rewrite the query.
Give only the SQL query and no other characters, not even a header such as "SQLQuery:".
Don't wrap the query in backticks (```).
If there are no mistakes, just reproduce the original query with no further commentary.

Output the final SQL query only."""
prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)
validation_chain = prompt | llm | StrOutputParser()

write_query_with_validation = {"query": full_chain} | validation_chain
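The validation prompt tells the model not to wrap the query in backticks or prefix it with a "SQLQuery:" header, but instructions like these are not always followed. A defensive option is a small post-processing step before the query reaches db.run. The sketch below is a hypothetical clean_sql helper (not part of LangChain) that strips a surrounding Markdown code fence and a leading "SQLQuery:" label; appending it to validation_chain, e.g. prompt | llm | StrOutputParser() | clean_sql, is our assumption rather than something this notebook does.

```python
import re


def clean_sql(text: str) -> str:
    """Strip a Markdown code fence and a leading 'SQLQuery:' label from model output."""
    text = text.strip()
    # Remove an opening fence such as ``` or ```sql, plus the whitespace after it
    text = re.sub(r"^```[a-zA-Z]*\s*", "", text)
    # Remove a closing fence at the very end
    text = re.sub(r"\s*```$", "", text)
    # Drop a leading label like "SQLQuery:" (case-insensitive)
    text = re.sub(r"^SQLQuery:\s*", "", text, flags=re.IGNORECASE)
    return text.strip()


print(clean_sql('```sql\nSELECT "Name" FROM "Artist" LIMIT 5;\n```'))
# SELECT "Name" FROM "Artist" LIMIT 5;
```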

In [11]:
query = write_query_with_validation.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)

SELECT DISTINCT "Genre"."Name"
FROM "Track"
JOIN "Album" ON "Track"."AlbumId" = "Album"."AlbumId"
JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId"
JOIN "Genre" ON "Track"."GenreId" = "Genre"."GenreId"
WHERE "Artist"."Name" = 'Alanis Morissette'
LIMIT 5;


In [12]:
db.run(query)

"[('Rock',)]"

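Since the query is generated by an LLM, it can be worth confirming that it at least compiles before running it against the real database. In SQLite, one cheap check is to prefix the statement with EXPLAIN: SQLite then prepares the statement and returns its bytecode listing without actually executing it, so syntax errors and unknown tables or columns surface as exceptions. The sketch below assumes a toy in-memory Genre table; compiles_ok is a hypothetical helper.

```python
import sqlite3


def compiles_ok(conn: sqlite3.Connection, query: str) -> tuple[bool, str]:
    """Prepare the query under EXPLAIN (so it is not executed) and report any error."""
    try:
        conn.execute("EXPLAIN " + query)
        return True, ""
    except sqlite3.Error as exc:
        return False, str(exc)


# Toy stand-in for Chinook's Genre table
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Genre" ("GenreId" INTEGER PRIMARY KEY, "Name" TEXT)')

print(compiles_ok(conn, 'SELECT "Name" FROM "Genre" LIMIT 5;'))  # (True, '')
print(compiles_ok(conn, "SELECT Name FROM Genres"))  # fails: no such table
```

One caveat: in SQLite's default configuration, a double-quoted name that matches no column is accepted as a string literal, so a misspelled column written as "Year" would pass this check while the unquoted Year fails it. The check therefore catches less when identifiers are double-quoted, as the query-writing prompt requests.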
In [13]:
query = write_query_with_validation.invoke(
    {"question": "List all artists."}
)
print(query)
db.run(query)

SELECT "ArtistId", "Name" FROM "Artist" ORDER BY "ArtistId" LIMIT 5;


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]"

In [14]:

query = write_query_with_validation.invoke(
    {"question": "Find all albums for the artist 'AC/DC'."}
)
print(query)
db.run(query)

SELECT "Album"."Title" FROM "Album" JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId" WHERE "Artist"."Name" = 'AC/DC' LIMIT 5;


"[('For Those About To Rock We Salute You',), ('Let There Be Rock',)]"

In [15]:

query = write_query_with_validation.invoke(
    {"question": "List all tracks in the 'Rock' genre."}
)
print(query)
db.run(query)

SELECT "Track"."Name" FROM "Track" JOIN "Genre" ON "Track"."GenreId" = "Genre"."GenreId" WHERE "Genre"."Name" = 'Rock' LIMIT 5;


"[('For Those About To Rock (We Salute You)',), ('Balls to the Wall',), ('Fast As a Shark',), ('Restless and Wild',), ('Princess of the Dawn',)]"

In [16]:

query = write_query_with_validation.invoke(
    {"question": "Find the total duration of all tracks."}
)
print(query)
db.run(query)

SELECT SUM("Milliseconds") AS "TotalDuration" FROM "Track"


'[(1378778040,)]'
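The total is in milliseconds, because Chinook stores each track's length in the Milliseconds column. A small conversion (format_ms is a hypothetical helper) shows that it amounts to roughly 383 hours, or about 16 days, of audio:

```python
def format_ms(ms: int) -> str:
    """Convert milliseconds to an H:MM:SS string, dropping fractional seconds."""
    total_seconds, _ = divmod(ms, 1000)
    hours, rem = divmod(total_seconds, 3600)
    minutes, seconds = divmod(rem, 60)
    return f"{hours}:{minutes:02d}:{seconds:02d}"


print(format_ms(1378778040))  # 382:59:38
```

The same arithmetic gives the 300000 threshold in the five-minute query below: 5 × 60 × 1000 = 300000 ms.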

In [17]:

query = write_query_with_validation.invoke(
    {"question": "List all customers from Canada."}
)
print(query)
db.run(query)

SELECT "FirstName", "LastName", "Company", "City", "State", "Country" 
FROM "Customer" 
WHERE "Country" = 'Canada' 
LIMIT 5;


"[('François', 'Tremblay', None, 'Montréal', 'QC', 'Canada'), ('Mark', 'Philips', 'Telus', 'Edmonton', 'AB', 'Canada'), ('Jennifer', 'Peterson', 'Rogers Canada', 'Vancouver', 'BC', 'Canada'), ('Robert', 'Brown', None, 'Toronto', 'ON', 'Canada'), ('Edward', 'Francis', None, 'Ottawa', 'ON', 'Canada')]"

In [18]:

query = write_query_with_validation.invoke(
    {"question": "How many tracks are there in the album with ID 5?"}
)
print(query)
db.run(query)

SELECT COUNT("TrackId") AS "TrackCount" FROM "Track" WHERE "AlbumId" = 5


'[(15,)]'

In [19]:

query = write_query_with_validation.invoke(
    {"question": "Find the total number of invoices."}
)
print(query)
db.run(query)

SELECT COUNT("InvoiceId") AS "TotalInvoices" FROM "Invoice";


'[(412,)]'

In [20]:

query = write_query_with_validation.invoke(
    {"question": "List all tracks that are longer than 5 minutes."}
)
print(query)
db.run(query)

SELECT "TrackId", "Name", "Milliseconds" FROM "Track" WHERE "Milliseconds" > 300000 LIMIT 5;


"[(1, 'For Those About To Rock (We Salute You)', 343719), (2, 'Balls to the Wall', 342562), (5, 'Princess of the Dawn', 375418), (15, 'Go Down', 331180), (17, 'Let There Be Rock', 366654)]"

In [21]:

query = write_query_with_validation.invoke(
    {"question": "Who are the top 5 customers by total purchase?"}
)
print(query)
db.run(query)

SELECT "Customer"."FirstName", "Customer"."LastName", SUM("Invoice"."Total") AS "TotalSpent"
FROM "Customer"
JOIN "Invoice" ON "Customer"."CustomerId" = "Invoice"."CustomerId"
GROUP BY "Customer"."CustomerId"
ORDER BY "TotalSpent" DESC
LIMIT 5;


'[(\'Helena\', \'Holý\', 49.62), (\'Richard\', \'Cunningham\', 47.62), (\'Luis\', \'Rojas\', 46.62), (\'Ladislav\', \'Kovács\', 45.62), (\'Hugh\', "O\'Reilly", 45.62)]'

In [22]:

query = write_query_with_validation.invoke(
    {"question": "Which albums are from the year 2000?"}
)
print(query)
db.run(query)

SELECT "Title" FROM "Album" WHERE strftime('%Y', 'now') = '2000' LIMIT 5;


''
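This empty result should not be read as "there are no albums from 2000". Chinook's Album table has only AlbumId, Title and ArtistId columns, so there is no release year to filter on, and strftime('%Y', 'now') evaluates to the current year rather than to anything stored in a row. The WHERE clause is therefore the same constant for every row, so the query returns no rows in any year other than 2000. The stdlib sketch below reproduces this with a toy table that has Chinook's Album columns and two sample rows.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Same columns as Chinook's Album table: there is no release-year column
conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)")
conn.executemany(
    "INSERT INTO Album VALUES (?, ?, ?)",
    [(1, "For Those About To Rock We Salute You", 1), (4, "Let There Be Rock", 1)],
)

# PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk) per column
columns = [row[1] for row in conn.execute("PRAGMA table_info(Album)")]
print(columns)  # ['AlbumId', 'Title', 'ArtistId']

# The generated filter never looks at the row; it compares today's year with '2000'
current_year = conn.execute("SELECT strftime('%Y', 'now')").fetchone()[0]
rows = conn.execute(
    "SELECT Title FROM Album WHERE strftime('%Y', 'now') = '2000' LIMIT 5"
).fetchall()
print(current_year, rows)  # rows is empty unless the clock says it is the year 2000
```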

In [23]:

query = write_query_with_validation.invoke(
    {"question": "How many employees are there"}
)
print(query)
db.run(query)

SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;


'[(8,)]'