# Dynamic few-shot examples

If we have enough examples, we may want to include only the most relevant ones in the prompt, either because they don't fit in the model's context window or because the long tail of examples distracts the model. Specifically, for any given input we want to include the examples most relevant to that input.

We can do this using an ExampleSelector.
In this case we'll use a SemanticSimilarityExampleSelector, which stores the examples in the vector database of our choosing. At runtime it performs a similarity search between the input and our examples and returns the most semantically similar ones.

https://python.langchain.com/v0.2/api_reference/core/example_selectors/langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.html

We default to OpenAI embeddings here, but you can swap them out for the model provider of your choice.

https://python.langchain.com/v0.2/docs/how_to/sql_prompting/#dynamic-few-shot-examples
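To make the selection step concrete without an API key, here is a toy sketch of "embed, then return the k nearest examples". Bag-of-words vectors and cosine similarity stand in for a real embedding model (an assumption for illustration only); SemanticSimilarityExampleSelector uses whatever embedding model and vector store you pass in. The helper names `embed`, `cosine` and `select_examples` are hypothetical.

```python
import math
from collections import Counter

def embed(text: str) -> Counter:
    """Toy embedding: lowercase word counts (a stand-in for a real embedding model)."""
    words = (w.strip("?.,'").lower() for w in text.split())
    return Counter(w for w in words if w)

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

def select_examples(examples: list[dict], user_input: str, k: int = 2, key: str = "input") -> list[dict]:
    """Return the k examples whose `key` field is most similar to user_input."""
    query_vec = embed(user_input)
    ranked = sorted(examples, key=lambda ex: cosine(query_vec, embed(ex[key])), reverse=True)
    return ranked[:k]

toy_examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {"input": "How many employees are there", "query": 'SELECT COUNT(*) FROM "Employee"'},
    {"input": "Find the total number of invoices.", "query": "SELECT COUNT(*) FROM Invoice;"},
]

for example in select_examples(toy_examples, "how many artists are there?", k=2):
    print(example["input"])
```

With these toy vectors, "How many employees are there" shares four words with the question and ranks first, ahead of "List all artists.". Real embeddings capture meaning rather than word overlap, so their ranking can differ.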



In [33]:
import os
from dotenv import load_dotenv

# Load keys such as OPENAI_API_KEY and LANGCHAIN_API_KEY from a local .env file into os.environ
load_dotenv()

# Enable LangSmith tracing and group the runs under one project name
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Q&A_over_SQL_data"
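Since `ChatOpenAI` and `OpenAIEmbeddings` read `OPENAI_API_KEY` from the environment and tracing needs `LANGCHAIN_API_KEY`, it can help to fail early if `load_dotenv()` did not find them. A small sketch; the helper `missing_env_vars` and the exact list of required keys are assumptions for this notebook.

```python
import os

# ASSUMPTION: these are the keys this notebook needs; adjust for your providers.
REQUIRED_KEYS = ("OPENAI_API_KEY", "LANGCHAIN_API_KEY")

def missing_env_vars(required=REQUIRED_KEYS, environ=None) -> list[str]:
    """Return the names of required variables that are unset or empty."""
    environ = os.environ if environ is None else environ
    return [name for name in required if not environ.get(name)]

missing = missing_env_vars()
if missing:
    print("Missing environment variables:", ", ".join(missing))
```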

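# SQL DB Connection
Connect to the Chinook SQLite database that we will query. This assumes `Chinook.db` is in the working directory; otherwise SQLite creates an empty database and no tables are listed.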



In [34]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

# Hook the DB up to the LLM

In [35]:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")

In [36]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there?"})
response

'SQLQuery: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee";'

# Look at the default prompt

This prompt has placeholders for `{table_info}` and `{input}`; `top_k` has already been filled in with 5.

In [37]:
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result
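To check which placeholders a template string still expects, Python's `string.Formatter` can parse it; LangChain's default f-string template format uses the same single-brace syntax. The template below is a hypothetical, shortened stand-in, not the exact default prompt, and `placeholders` is a hypothetical helper.

```python
from string import Formatter

# Hypothetical, shortened template using the same single-brace placeholder syntax.
toy_template = (
    "Unless the user specifies otherwise, query for at most {top_k} results.\n"
    "Only use the following tables:\n{table_info}\n\n"
    "Question: {input}"
)

def placeholders(template: str) -> list[str]:
    """List the field names in a str.format-style template, in order, without duplicates."""
    names = []
    for _literal, field_name, _spec, _conversion in Formatter().parse(template):
        if field_name and field_name not in names:
            names.append(field_name)
    return names

print(placeholders(toy_template))

# Filling every placeholder produces the final prompt text.
print(toy_template.format(top_k=5, table_info='CREATE TABLE "Artist" (...)', input="How many artists are there?"))
```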

# Table definitions, example rows, and dialect

We will add more placeholders:

1) `{dialect}`
2) `{top_k}`
3) `{table_info}`, whose value comes from the database


First, get the `table_info`:

In [38]:
context = db.get_context()
print(list(context))
print(context["table_info"])

['table_info', 'table_names']

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("S

# Prepare the prompt with the table_info context

When we don't have too many tables, or tables that are too wide, we can insert all of this information into the prompt:
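To see roughly what `table_info` contains, or to build a trimmed version for a handful of tables, the sketch below reproduces the format shown above (the CREATE statement plus three sample rows) with the standard-library `sqlite3` module and an in-memory stand-in for the `Artist` table. It is an approximation: SQLDatabase itself builds this text via SQLAlchemy, and the helper name `sketch_table_info` is hypothetical.

```python
import sqlite3

def sketch_table_info(conn: sqlite3.Connection, table: str, sample_rows: int = 3) -> str:
    """Approximate the table_info text: CREATE statement plus a few sample rows."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    if row is None:
        raise ValueError(f"unknown table: {table}")
    ddl = row[0]
    cursor = conn.execute(f'SELECT * FROM "{table}" LIMIT ?', (sample_rows,))
    header = "\t".join(col[0] for col in cursor.description)
    rows = ["\t".join(str(value) for value in r) for r in cursor.fetchall()]
    body = "\n".join([f"{sample_rows} rows from {table} table:", header, *rows])
    return f"{ddl}\n\n/*\n{body}\n*/"

# In-memory stand-in for Chinook's Artist table, with the first rows shown earlier.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Artist" ("ArtistId" INTEGER NOT NULL, "Name" NVARCHAR(120), PRIMARY KEY ("ArtistId"))')
conn.executemany(
    'INSERT INTO "Artist" VALUES (?, ?)',
    [(1, "AC/DC"), (2, "Accept"), (3, "Aerosmith"), (4, "Alanis Morissette")],
)
print(sketch_table_info(conn, "Artist"))
```

With the real database, `db.get_table_info(table_names=["Artist", "Album"])` limits the output to the listed tables, which helps when the full schema is too large for the prompt.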



In [52]:
system = """You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

"""


In [53]:
db.dialect

'sqlite'

In [54]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")


In [55]:
example_prompt.pretty_print()

User input: [33;1m[1;3m{input}[0m
SQL query: [33;1m[1;3m{query}[0m


In [56]:
examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of invoices.",
        "query": "SELECT COUNT(*) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
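        # NOTE: Chinook's Album table has no ReleaseDate column (see the schema above), so this example query is invalid.
        "input": "Which albums are from the year 2000?",
        "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",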
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]

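Next, build the example selector. It embeds the `input` field of each example with OpenAIEmbeddings, stores the vectors in a FAISS index (this requires the `faiss-cpu` package), and at runtime returns the k=5 examples closest to the new input.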

In [57]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)

Test the example selector

In [58]:
example_selector.select_examples({"input": "how many artists are there?"})

[{'input': 'List all artists.', 'query': 'SELECT * FROM Artist;'},
 {'input': 'How many employees are there',
  'query': 'SELECT COUNT(*) FROM "Employee"'},
 {'input': 'How many tracks are there in the album with ID 5?',
  'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'},
 {'input': 'Which albums are from the year 2000?',
  'query': "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';"},
 {'input': "List all tracks in the 'Rock' genre.",
  'query': "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');"}]

To use it, we can pass the ExampleSelector directly into our FewShotPromptTemplate:

In [59]:
prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=system,
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "dialect", "top_k", "table_info"],
)
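Conceptually, the few-shot prompt is laid out as the prefix, then one rendered `example_prompt` per selected example, then the suffix, separated by blank lines (`"\n\n"` is FewShotPromptTemplate's default `example_separator`). The pure-Python sketch below mirrors that layout with hypothetical, shortened templates; it is not LangChain's implementation, and `build_few_shot_prompt` is a hypothetical helper.

```python
# Hypothetical, shortened stand-ins for the real prefix (`system`), example_prompt and suffix.
TOY_PREFIX = "You are a {dialect} expert. Query for at most {top_k} results."
TOY_EXAMPLE_TEMPLATE = "User input: {input}\nSQL query: {query}"
TOY_SUFFIX = "User input: {input}\nSQL query: "

def build_few_shot_prompt(selected: list[dict], values: dict, separator: str = "\n\n") -> str:
    """Prefix, then one rendered block per selected example, then the suffix."""
    prefix = TOY_PREFIX.format(**values)   # str.format ignores unused keys such as "input"
    shots = [TOY_EXAMPLE_TEMPLATE.format(**example) for example in selected]
    suffix = TOY_SUFFIX.format(**values)
    return separator.join([prefix, *shots, suffix])

toy_selected = [
    {"input": "How many employees are there", "query": 'SELECT COUNT(*) FROM "Employee"'},
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
]
few_shot_text = build_few_shot_prompt(
    toy_selected, {"dialect": "sqlite", "top_k": 5, "input": "How many artists are there?"}
)
print(few_shot_text)
```

The prompt ends with an open `SQL query: ` line, so the model's natural continuation is the SQL for the new question.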

In [60]:
prompt.pretty_print()

You are a [33;1m[1;3m{dialect}[0m expert. Given an input question, create a syntactically correct [33;1m[1;3m{dialect}[0m query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most [33;1m[1;3m{top_k}[0m results using the LIMIT clause as per [33;1m[1;3m{dialect}[0m. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
[33;1m[1;3m{table_info}[0m



User input: How many employees are th

In [61]:
chain = create_sql_query_chain(llm, db, prompt=prompt)

In [62]:
response = chain.invoke({"question": "How many employees are there?"})
response

'```sql\nSELECT COUNT(*) FROM "Employee";\n```'
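The model wrapped the query in a markdown `sql` code fence, and the default prompt earlier produced a `SQLQuery:` label; neither string can be executed as is. The notebook handles this with an LLM validation step next. A deterministic alternative, sketched here as a hypothetical helper `clean_sql`, strips both with regular expressions:

```python
import re

def clean_sql(raw: str) -> str:
    """Strip a surrounding markdown code fence and a leading 'SQLQuery:' label from model output."""
    text = raw.strip()
    fenced = re.match(r"^```(?:sql)?\s*(.*?)\s*```$", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)
    text = re.sub(r"^SQLQuery:\s*", "", text, flags=re.IGNORECASE)
    return text.strip()

print(clean_sql('```sql\nSELECT COUNT(*) FROM "Employee";\n```'))
print(clean_sql('SQLQuery: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee";'))
```

In LCEL a plain Python function can be piped after a runnable (for example `chain | clean_sql`); it is wrapped in a RunnableLambda automatically.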

In [63]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Use new names so the few-shot `system` and `prompt` defined above are not overwritten
validation_system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
- Wrapping the query in backticks, like ```query```


If there are any of the above mistakes, rewrite the query.
Give only the SQL query and no other characters, not even a header such as "SQLQuery:".
Don't put backticks (```) around the query.
If there are no mistakes, just reproduce the original query with no further commentary.

Output the final SQL query only."""
validation_prompt = ChatPromptTemplate.from_messages(
    [("system", validation_system), ("human", "{query}")]
).partial(dialect=db.dialect)

# The LLM's reply is parsed into a plain string: the (possibly corrected) SQL query
validation_chain = validation_prompt | llm | StrOutputParser()
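The first item on that checklist is easy to reproduce. In SQL, `x NOT IN (...)` evaluates to NULL rather than TRUE whenever the list contains a NULL, so the filter silently drops every row. A sketch with the standard-library `sqlite3` module; the tables are named after Chinook's `Employee` and `Customer`, but the rows are made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, SupportRepId INTEGER);
INSERT INTO Customer VALUES (1, 3), (2, 4), (3, NULL);
CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY);
INSERT INTO Employee VALUES (3), (4), (5);
""")

# "Which employees support no customer?" The NULL SupportRepId makes
# `5 NOT IN (3, 4, NULL)` evaluate to NULL, so no row passes the filter.
not_in = conn.execute(
    "SELECT EmployeeId FROM Employee "
    "WHERE EmployeeId NOT IN (SELECT SupportRepId FROM Customer)"
).fetchall()

# NOT EXISTS is not affected by the NULL and finds employee 5.
not_exists = conn.execute(
    "SELECT EmployeeId FROM Employee AS e "
    "WHERE NOT EXISTS (SELECT 1 FROM Customer AS c WHERE c.SupportRepId = e.EmployeeId)"
).fetchall()

print(not_in)      # []
print(not_exists)  # [(5,)]
```

Adding `WHERE SupportRepId IS NOT NULL` to the subquery would also make the `NOT IN` version return employee 5.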

In [64]:
write_query_with_validation = {"query": chain} | validation_chain

In [65]:
response = write_query_with_validation.invoke({"question": "How many employees are there?"})
response

'SELECT COUNT(*) FROM "Employee";'

In [67]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
execute_query_chain = write_query_with_validation | execute_query
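QuerySQLDataBaseTool runs the query through the SQLDatabase object and returns the rows as a string, which is why the result below prints as `[(8,)]`. To our knowledge it calls `SQLDatabase.run_no_throw`, which returns an `Error: ...` message instead of raising, so a later LLM step can still see what went wrong. A minimal stdlib sketch of that behaviour, with a hypothetical helper `run_and_report` and an in-memory stand-in table:

```python
import sqlite3

def run_and_report(conn: sqlite3.Connection, query: str) -> str:
    """Run a query and return its rows as a string; on failure return an 'Error: ...' string."""
    try:
        return str(conn.execute(query).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"

# In-memory stand-in with 8 employees, matching the count returned by the chain.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Employee" ("EmployeeId" INTEGER PRIMARY KEY)')
conn.executemany('INSERT INTO "Employee" VALUES (?)', [(i,) for i in range(1, 9)])

print(run_and_report(conn, 'SELECT COUNT(*) FROM "Employee";'))  # [(8,)]
print(run_and_report(conn, 'SELECT ReleaseDate FROM "Employee";'))  # Error: no such column: ...
```

The bad column is deliberately left unquoted: SQLite may treat a double-quoted name that matches no column as a string literal, which would hide the mistake.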

In [68]:
query_execution_result = execute_query_chain.invoke(
    {
        "question": "How many employees are there."
    }
)
print(query_execution_result)

[(8,)]


# Answer the question in Natural Language

Now that we can automatically generate and execute queries, we just need to combine the original question and the SQL query result to generate a final answer. We can do this by passing the question, the query, and the result to the LLM once more:

In [70]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question. 

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

sql_qa_chain = (
    RunnablePassthrough.assign(query=write_query_with_validation).assign(
        result=itemgetter("query") | execute_query
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)
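The data flow of `sql_qa_chain` can be traced with plain dictionaries: each `.assign(...)` step keeps the incoming keys and adds one new key computed from them, and the final dictionary fills the answer prompt. In the sketch below, `fake_write_query` and `fake_execute` are hypothetical stubs that return values seen earlier in this notebook, standing in for the LLM and the database; `assign` is a simplified, hypothetical analogue of `RunnablePassthrough.assign`.

```python
from operator import itemgetter

def fake_write_query(inputs: dict) -> str:
    """Stub for write_query_with_validation (no LLM call)."""
    return 'SELECT COUNT(*) FROM "Employee";'

def fake_execute(query: str) -> str:
    """Stub for execute_query (no database call)."""
    return "[(8,)]"

def assign(inputs: dict, **steps) -> dict:
    """Keep all incoming keys and add one key per step, computed from the inputs."""
    return {**inputs, **{key: step(inputs) for key, step in steps.items()}}

get_query = itemgetter("query")
state = {"question": "How many employees are there?"}
state = assign(state, query=fake_write_query)
state = assign(state, result=lambda s: fake_execute(get_query(s)))

answer_prompt_text = (
    "Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n\n"
    "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer: "
).format(**state)
print(answer_prompt_text)
```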


In [71]:
sql_qa_chain.invoke({"question": "How many employees are there."})

'There are 8 employees.'

In [72]:
sql_qa_chain.invoke({"question": "List all artists."})


'The artists listed are:\n\n1. AC/DC\n2. Accept\n3. Aerosmith\n4. Alanis Morissette\n5. Alice In Chains'

In [73]:

sql_qa_chain.invoke({"question": "Find all albums for the artist 'AC/DC'."})


"The albums for the artist 'AC/DC' are:\n\n1. For Those About To Rock We Salute You\n2. Let There Be Rock"

In [74]:

sql_qa_chain.invoke({"question": "List all tracks in the 'Rock' genre."})


"Here are some tracks in the 'Rock' genre:\n\n1. For Those About To Rock (We Salute You) - TrackId: 1, AlbumId: 1, Unit Price: $0.99\n2. Balls to the Wall - TrackId: 2, AlbumId: 2, Unit Price: $0.99\n3. Fast As a Shark - TrackId: 3, AlbumId: 3, Unit Price: $0.99\n4. Restless and Wild - TrackId: 4, AlbumId: 3, Unit Price: $0.99\n5. Princess of the Dawn - TrackId: 5, AlbumId: 3, Unit Price: $0.99\n\nThese are the first five tracks listed in the 'Rock' genre."

In [75]:

sql_qa_chain.invoke({"question": "Find the total duration of all tracks."})


'The total duration of all tracks is 1,378,778,040 milliseconds.'
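The `Milliseconds` column stores durations in milliseconds (5 minutes is 300,000 ms, the threshold used in one of the example queries), so the raw total is hard to read. Converting it shows just under 383 hours, about 16 days of audio. A small sketch with a hypothetical helper `split_milliseconds`:

```python
def split_milliseconds(ms: int) -> tuple[int, int, int, int]:
    """Split a millisecond count into (hours, minutes, seconds, milliseconds)."""
    seconds, millis = divmod(ms, 1000)
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return hours, minutes, seconds, millis

h, m, s, ms = split_milliseconds(1_378_778_040)
print(f"{h} h {m} min {s}.{ms:03d} s")  # 382 h 59 min 38.040 s
```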

In [76]:

sql_qa_chain.invoke({"question": "List all customers from Canada."})


'Here are some customers from Canada:\n\n1. François Tremblay - ftremblay@gmail.com\n2. Mark Philips - mphilips12@shaw.ca\n3. Jennifer Peterson - jenniferp@rogers.ca\n4. Robert Brown - robbrown@shaw.ca\n5. Edward Francis - edfrancis@yachoo.ca'

In [77]:

sql_qa_chain.invoke({"question": "How many tracks are there in the album with ID 5?"})


'There are 15 tracks in the album with ID 5.'

In [78]:

sql_qa_chain.invoke({"question": "Find the total number of invoices."})


'The total number of invoices is 412.'

In [79]:

sql_qa_chain.invoke({"question": "List all tracks that are longer than 5 minutes."})


'The tracks that are longer than 5 minutes are:\n\n1. For Those About To Rock (We Salute You) - 343719 milliseconds\n2. Balls to the Wall - 342562 milliseconds\n3. Princess of the Dawn - 375418 milliseconds\n4. Go Down - 331180 milliseconds\n5. Let There Be Rock - 366654 milliseconds\n\n(Note: All tracks listed have a duration greater than 5 minutes, as 5 minutes is equivalent to 300,000 milliseconds.)'

In [80]:

sql_qa_chain.invoke({"question": "Who are the top 5 customers by total purchase?"})


'The top 5 customers by total purchase are:\n\n1. Customer ID 6 with a total purchase of $49.62\n2. Customer ID 26 with a total purchase of $47.62\n3. Customer ID 57 with a total purchase of $46.62\n4. Customer ID 45 with a total purchase of $45.62\n5. Customer ID 46 with a total purchase of $45.62'
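The answer only reports CustomerId values because the example query for this question aggregates the `Invoice` table alone. Joining `Customer`, which has `FirstName` and `LastName` according to the schema above, would let the answer name the customers. A sketch of such a query, run against a tiny in-memory stand-in with made-up rows; to affect the chain, the corresponding entry in `examples` would need the same change.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, FirstName TEXT, LastName TEXT);
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);
INSERT INTO Customer VALUES (1, 'Alice', 'Smith'), (2, 'Bob', 'Jones');
INSERT INTO Invoice VALUES (1, 1, 10.0), (2, 2, 3.5), (3, 1, 2.0);
""")

# Sum invoice totals per customer and return names instead of bare IDs.
top_customers = conn.execute("""
    SELECT c.FirstName, c.LastName, SUM(i.Total) AS TotalPurchase
    FROM Invoice AS i
    JOIN Customer AS c ON c.CustomerId = i.CustomerId
    GROUP BY c.CustomerId
    ORDER BY TotalPurchase DESC
    LIMIT 5;
""").fetchall()
print(top_customers)  # [('Alice', 'Smith', 12.0), ('Bob', 'Jones', 3.5)]
```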

In [81]:

sql_qa_chain.invoke({"question": "Which albums are from the year 2000?"})


'The SQL query appears to have an error because it is trying to extract the year from the "AlbumId" column instead of a date column that would contain the release year of the albums. To accurately answer the question regarding which albums are from the year 2000, the SQL query should reference a date column, such as "ReleaseDate" or an equivalent column that contains the year information.\n\nHowever, based on the provided SQL result, it seems there are no albums listed because the query is not correctly targeting the intended data. \n\nIf the query were corrected to reflect a proper date column, the answer would list the album titles and their corresponding artist IDs released in the year 2000. If you have a specific dataset or corrected query, please provide that for a more accurate answer.'
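This answer points at a schema problem: the `Album` table printed earlier only has `AlbumId`, `Title` and `ArtistId`, yet the few-shot example for this question queries a `ReleaseDate` column. A cheap guard is to have SQLite compile each generated query with `EXPLAIN` before running it, so references to nonexistent columns surface as errors. A sketch with a hypothetical helper `query_compiles` and an in-memory copy of the `Album` columns:

```python
import sqlite3

def query_compiles(conn: sqlite3.Connection, query: str) -> tuple[bool, str]:
    """Ask SQLite to compile (not run) the query via EXPLAIN; report any error message."""
    try:
        conn.execute(f"EXPLAIN {query}")
        return True, ""
    except sqlite3.Error as exc:
        return False, str(exc)

conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Album" ("AlbumId" INTEGER PRIMARY KEY, "Title" TEXT, "ArtistId" INTEGER)')

ok, err = query_compiles(conn, "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';")
print(ok, err)  # False no such column: ReleaseDate
```

A failed check could be fed back to the LLM for a rewrite, or the question answered with "not available in this database" instead of a guess.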

In [82]:

sql_qa_chain.invoke({"question": "How many employees are there."})


'There are 8 employees.'