In [81]:
import re

import langchain
from dotenv import load_dotenv,find_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.runnables import RunnablePassthrough,RunnableLambda,RunnableParallel
from langchain_core.output_parsers import StrOutputParser,JsonOutputParser
from langchain_core.prompts import (PromptTemplate, ChatPromptTemplate,FewShotPromptTemplate,
                            MessagesPlaceholder,SystemMessagePromptTemplate,HumanMessagePromptTemplate)
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain_google_genai import GoogleGenerativeAI,ChatGoogleGenerativeAI
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from langchain.chains import create_sql_query_chain

In [2]:
# Load API keys and other settings from the .env file into the process environment.
# The path is a raw string: it contains backslashes (\L, \M, \.) that are not valid
# escape sequences, which a normal string literal only tolerates with a warning.
# NOTE(review): this is an absolute, machine-specific path — consider a relative
# or configurable location so the notebook runs on other machines.
load_dotenv(find_dotenv(r"D:\LLM Courses\Master Langchain Udemy\.env"))

True

In [4]:
# Connection to the Chinook SQLite database used by every chain below.
CHINOOK_URI = "sqlite:///db/chinook.db/chinook.db"
db = SQLDatabase.from_uri(database_uri=CHINOOK_URI)

# Chat model shared by the query-writing chain and the validation step.
llm = ChatGoogleGenerativeAI(
    temperature=0.3,
    model="gemini-1.5-flash-001",
)

In [5]:
# Show the SQL dialect reported by the connection; the same value is later bound
# into the prompts through `.partial(dialect=db.dialect)`.
db.dialect

'sqlite'

<h3>Validate the Query with a Second Chain</h3>

In [113]:
# Instructions for the validation pass: the model receives a finished query and
# must either repair the listed mistakes or echo the query unchanged.
# {dialect} is the only placeholder in this text.
QUERY_CHECK_INSTRUCTIONS = """Double check the user's {dialect} query for common mistakes, including:
            - Using NOT IN with NULL values
            - Using UNION when UNION ALL should have been used
            - Using BETWEEN for exclusive ranges
            - Data type mismatch in predicates
            - Properly quoting identifiers
            - Using the correct number of arguments for functions
            - Casting to the correct data type
            - Using the proper columns for joins
            
            If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
            
            Output the final SQL query only."""

system = SystemMessagePromptTemplate.from_template(template=QUERY_CHECK_INSTRUCTIONS)

# The human turn carries nothing but the query to be checked.
human = HumanMessagePromptTemplate.from_template(template="{query}")

In [114]:
# NOTE(review): this cell repeats the definitions of the cell directly before it
# verbatim; running it only rebinds `system` and `human` to equal templates.
# The human turn is the raw query to validate.
human = HumanMessagePromptTemplate.from_template("{query}")

# The system turn lists the mistakes to look for and demands a bare SQL answer.
system = SystemMessagePromptTemplate.from_template(
    """Double check the user's {dialect} query for common mistakes, including:
            - Using NOT IN with NULL values
            - Using UNION when UNION ALL should have been used
            - Using BETWEEN for exclusive ranges
            - Data type mismatch in predicates
            - Properly quoting identifiers
            - Using the correct number of arguments for functions
            - Casting to the correct data type
            - Using the proper columns for joins
            
            If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
            
            Output the final SQL query only."""
)

In [115]:
# Assemble the validation prompt (system instructions followed by the query) and
# bind {dialect} now, leaving {query} as the only variable to supply later.
validation_messages = [system, human]
prompt = ChatPromptTemplate.from_messages(messages=validation_messages).partial(
    dialect=db.dialect
)

In [116]:
# Inspect the first message of the prompt (the system message template).
prompt.messages[0]

SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['dialect'], template="Double check the user's {dialect} query for common mistakes, including:\n            - Using NOT IN with NULL values\n            - Using UNION when UNION ALL should have been used\n            - Using BETWEEN for exclusive ranges\n            - Data type mismatch in predicates\n            - Properly quoting identifiers\n            - Using the correct number of arguments for functions\n            - Casting to the correct data type\n            - Using the proper columns for joins\n            \n            If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n            \n            Output the final SQL query only."))

In [117]:
# Step 1: turn the natural-language question into a draft SQL query.
chain = create_sql_query_chain(llm=llm, db=db)

# Step 2: pipe the draft through the validation prompt and the model, then clean
# up the reply. Only a surrounding Markdown code fence (``` with an optional
# sql/sqlite tag) is stripped; the body of the query is left untouched, so
# identifiers such as sqlite_master survive and a "```sql" tag leaves no residue.
fullChain = (
    chain
    | prompt
    | llm
    | StrOutputParser()
    | RunnableLambda(
        lambda text: re.sub(
            r"^```(?:sqlite|sql)?\s*|\s*```$",
            "",
            text.strip(),
            flags=re.IGNORECASE,
        ).strip()
    )
)

In [118]:
# Natural-language question used as the running example for every chain variant.
question = "What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010"

# Run the two-step chain: draft the query, then have it validated.
chain_input = {"question": question}
responseQuery = fullChain.invoke(input=chain_input)

In [119]:
# Print (rather than echo) the query so it is shown without repr quoting.
print(responseQuery)

SELECT AVG(T1.Total) FROM "Invoice" AS T1 INNER JOIN "Customer" AS T2 ON T1.CustomerId = T2.CustomerId WHERE T2.Country = 'USA' AND T2.Fax IS NULL AND T1.InvoiceDate BETWEEN '2003-01-01' AND '2010-01-01'


In [120]:
# Execute the validated query against the database and show the result.
db.run(command=responseQuery)

'[(6.633,)]'

<h3>Using RunnablePassthrough</h3>

In [121]:
# Same two-step pipeline, but the draft query is attached to the input dict under
# the "query" key via RunnablePassthrough.assign before it reaches the prompt.
# Post-processing strips only a surrounding Markdown code fence (``` with an
# optional sql/sqlite tag); the body of the query is left untouched, so
# identifiers such as sqlite_master survive and a "```sql" tag leaves no residue.
fullChain = (
    RunnablePassthrough.assign(query=chain)
    | prompt
    | llm
    | StrOutputParser()
    | RunnableLambda(
        lambda text: re.sub(
            r"^```(?:sqlite|sql)?\s*|\s*```$",
            "",
            text.strip(),
            flags=re.IGNORECASE,
        ).strip()
    )
)

In [122]:
# Ask the same question through the RunnablePassthrough variant of the chain
# and print the validated SQL it returns.
passthrough_input = {"question": question}
responseQuery = fullChain.invoke(input=passthrough_input)
print(responseQuery)

SELECT AVG(T1.Total) FROM "Invoice" AS T1 INNER JOIN "Customer" AS T2 ON T1.CustomerId = T2.CustomerId WHERE T2.Country = 'USA' AND T2.Fax IS NULL AND T1.InvoiceDate BETWEEN '2003-01-01' AND '2010-01-01'


In [123]:
# Execute the query produced by the RunnablePassthrough variant of the chain.
db.run(command=responseQuery)

'[(6.633,)]'

<h3>Validating Query with a Single Chain</h3>

In [320]:
# Single-pass prompt: the model drafts a query and self-checks it in one call.
# Placeholders: {dialect}, {top_k}, {table_info} here and {input} in the human turn.
# The reply format ("First draft: ..." / "Final answer: ...") is relied on by the
# parsing step further down, which looks for the "Final answer:" marker.
# The examples quote identifiers with double quotes and string literals with
# single quotes, in line with the instructions given in the prompt itself.
system=SystemMessagePromptTemplate.from_template(
    template="""
            You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.
            Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.
            Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
            
            Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
            Pay attention to use date('now') function to get the current date, if the question involves "today".
            Properly quote identifiers (in where clause) with double quotes (") and string literals with single quotes ('), ex., <<WHERE C."Country" = 'Canada'>>
            Only use the following tables:
            {table_info}
            
            Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:
            - Do Not use any preamble like : {{```}} , {{<<}}, {{>>}} or {{SQL}}: 
            - Using NOT IN with NULL values
            - Using UNION when UNION ALL should have been used
            - Using BETWEEN for exclusive ranges
            - Data type mismatch in predicates
            - Using the correct number of arguments for functions
            - Casting to the correct data type
            - Using the proper columns for joins
            
            Example Query:
                SELECT c."Country"
                FROM "Invoice" i
                JOIN "Customer" c ON i."CustomerId" = c."CustomerId"
                WHERE c."Country" = 'Canada'
            
            Use format:
            
            First draft: <<FIRST_DRAFT_QUERY>>
            Final answer: <<FINAL_ANSWER_QUERY>>
            """
)
# The human turn is filled through the {input} variable.
human=HumanMessagePromptTemplate.from_template(
    template="{input}"
)

In [321]:
# Show the table names exposed by the database context.
db.get_context()['table_names']

'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'

In [322]:
# Build the single-pass prompt and pre-fill the variables that stay constant.
# NOTE(review): table_info is bound to the table *names* only, while the prompt
# text refers to column names "in the tables below" — confirm whether
# create_sql_query_chain supplies the full schema for {table_info} at invoke time.
single_pass_messages = [system, human]
prompt = ChatPromptTemplate.from_messages(messages=single_pass_messages).partial(
    dialect=db.dialect,
    top_k=5,
    table_info=db.get_context()['table_names'],
)

In [323]:
# Pretty-print the system message template to review the prompt text.
prompt.messages[0].pretty_print()



            You are a [33;1m[1;3m{dialect}[0m expert. Given an input question, creat a syntactically correct [33;1m[1;3m{dialect}[0m query to run.
            Unless the user specifies in the question a specific number of examples to obtain, query for at most [33;1m[1;3m{top_k}[0m results using the LIMIT clause as per [33;1m[1;3m{dialect}[0m. You can order the results to return the most informative data in the database.
            Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
            
            Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
            Pay attention to use date('now') function to get the current date, if the question involves "today".
            Properly quote i

In [324]:
# Model for the single-pass approach. This rebinds `llm` (and `chain` below),
# replacing the objects created earlier in the notebook.
# Alternative backend that was tried: ChatOpenAI(model="gpt-3.5-turbo")
llm = GoogleGenerativeAI(model="gemini-1.5-pro")

# Query-writing chain driven by the custom draft-then-verify prompt.
chain = create_sql_query_chain(
    db=db,
    llm=llm,
    prompt=prompt,
)

In [327]:
def _extract_final_answer(model_reply: str) -> str:
    """Return the SQL that follows the "Final answer:" marker in the model reply.

    The prompt asks the model to answer as
    ``First draft: <<...>>`` / ``Final answer: <<...>>``. Only the text after the
    last "Final answer:" marker is kept; a wrapping ``<<`` / ``>>`` pair and any
    surrounding whitespace are removed.

    Raises:
        ValueError: if the reply contains no "Final answer:" marker, so that a
            model ignoring the requested format is reported with its raw reply
            rather than as a bare IndexError.
    """
    _, marker, tail = model_reply.rpartition("Final answer:")
    if not marker:
        raise ValueError(
            f"Model reply does not contain a 'Final answer:' section: {model_reply!r}"
        )
    answer = tail.strip()
    # Remove only the wrapping delimiters, not occurrences inside the query.
    if answer.startswith("<<"):
        answer = answer[2:]
    if answer.endswith(">>"):
        answer = answer[:-2]
    return answer.strip()


# The chain's string output is parsed directly; input is still {'question': ...}.
fullChain = chain | RunnableLambda(_extract_final_answer)

In [328]:
# Run the single-pass chain on the example question and print the final query.
single_pass_input = {'question': question}
response = fullChain.invoke(input=single_pass_input)
print(response)

SELECT AVG("Total") FROM "Invoice" AS T1 INNER JOIN "Customer" AS T2 ON T1."CustomerId" = T2."CustomerId" WHERE T2."Country" = 'USA' AND T2."Fax" IS NULL AND T1."InvoiceDate" BETWEEN '2003-01-01' AND '2009-12-31' LIMIT 5


In [329]:
# Problem persists: the single-pass query returns the same value as the
# two-step chains above.
# NOTE(review): presumably "the problem" is the BETWEEN date range generated for
# "since 2003 but before 2010" — confirm and state the expected result here.
db.run(command=response)

'[(6.633,)]'