In [149]:
import langchain
from dotenv import load_dotenv,find_dotenv
from langchain_core.messages import SystemMessage
from langchain_community.utilities import SQLDatabase
from langchain_core.runnables import RunnablePassthrough,RunnableLambda,RunnableParallel,Runnable
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (PromptTemplate, ChatPromptTemplate,FewShotPromptTemplate,
                            MessagesPlaceholder,SystemMessagePromptTemplate,HumanMessagePromptTemplate)
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain_google_genai import GoogleGenerativeAI,ChatGoogleGenerativeAI
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from langchain_core.pydantic_v1 import BaseModel,Field
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from typing import List

In [150]:
# Load API keys from the course's .env file.
# The path is a raw string: "\L", "\M" and "\." are not valid escape sequences,
# so the non-raw literal produces an invalid-escape warning on current Python.
# NOTE(review): this is a machine-specific absolute path; a relative lookup
# (plain `find_dotenv()`) would let the notebook run on other machines.
load_dotenv(find_dotenv(r"D:\LLM Courses\Master Langchain Udemy\.env"))

True

In [216]:
# Chat model shared by every chain in this notebook.
# Alternative backend: ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0.3)
llm = ChatOpenAI(model="gpt-3.5-turbo")

# Handle on the SQLite database, opened from its URI.
db = SQLDatabase.from_uri("sqlite:///db/chinook.db/chinook.db")

In [152]:
class Table(BaseModel):
    """Table in SQL Database."""

    # The docstring and the field description are forwarded to the model as the
    # tool schema (see the `tools` entry in the chain repr further down), so
    # their wording is effectively part of the prompt.
    name: str = Field(description="Name of Table in SQL Database")

In [153]:
# List the table names the database handle reports as usable.
db.get_usable_table_names()

['Album',
 'Artist',
 'Customer',
 'Employee',
 'Genre',
 'Invoice',
 'InvoiceLine',
 'MediaType',
 'Playlist',
 'PlaylistTrack',
 'Track']

In [154]:
# One table name per line, ready to be interpolated into the system prompt.
tableNames="\n".join(db.get_usable_table_names())
print(tableNames)

Album
Artist
Customer
Employee
Genre
Invoice
InvoiceLine
MediaType
Playlist
PlaylistTrack
Track


In [155]:
# System prompt for per-table selection. The f-string embeds the
# newline-separated `tableNames` at the moment this cell runs, so re-run this
# cell if the table list changes.
system = f"""
            Return the names of ALL the SQL tables that MIGHT be relevant to the user question.
            The tables are:

            {tableNames}

            Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

In [156]:
# Extraction chain whose output is parsed into `Table` objects.
tableChain = create_extraction_chain_pydantic(
    pydantic_schemas=Table,
    llm=llm,
    system_message=system,
)

In [157]:
# Artist name spelled "Morissette", matching the SQL-generation cells later on.
tableChain.invoke(input={"input":"What are all the genres of Alanis Morissette songs"})

[Table(name='Genre'), Table(name='Artist'), Table(name='Track')]

In [158]:
# System prompt for the coarser, category-level selection: the model chooses
# between the two table groups "Music" and "Business".
# NOTE: this rebinds `system`, replacing the per-table prompt defined earlier.
system="""
    Return the names of the SQL tables that are relevant to the user question. \
    The tables are:

    Music
    Business
"""

In [159]:
# Extraction chain that answers with categories, reusing the `Table` schema.
categoryChain = create_extraction_chain_pydantic(
    pydantic_schemas=Table,
    llm=llm,
    system_message=system,
)

In [160]:
# Artist name spelled "Morissette", matching the SQL-generation cells later on.
categoryChain.invoke(input={"input":"What are all the genres of Alanis Morissette songs"})

[Table(name='Music'), Table(name='Business')]

In [161]:
# Display the chain's composition: prompt -> tool-bound chat model -> Pydantic parser.
categoryChain

ChatPromptTemplate(input_variables=['input'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='\n    Return the names of the SQL tables that are relevant to the user question.     The tables are:\n\n    Music\n    Business\n')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}'))])
| RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x0000023E122A8690>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x0000023E122AB590>, openai_api_key=SecretStr('**********'), openai_proxy=''), kwargs={'tools': [{'type': 'function', 'function': {'name': 'Table', 'description': 'Table in SQL Database.', 'parameters': {'type': 'object', 'properties': {'name': {'description': 'Name of Table in SQL Database', 'type': 'string'}}, 'required': ['name']}}}]})
| PydanticToolsParser(tools=[<class '__main__.Table'>])

In [162]:
# Tables belonging to each category the model can return.
_CATEGORY_TABLES = {
    "Music": [
        "Album", "Artist", "Genre",
        "MediaType", "Playlist", "PlaylistTrack", "Track",
    ],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def getTables(categories: List[Table]) -> List[str]:
    """Expand category selections into concrete table names.

    Args:
        categories: Objects whose ``name`` attribute is a category
            ("Music" or "Business").

    Returns:
        The table names of every recognised category, in first-seen order and
        without duplicates. Unrecognised category names contribute nothing.
    """
    tables = []
    for category in categories:
        tables.extend(_CATEGORY_TABLES.get(category.name, []))
    # dict.fromkeys de-duplicates while preserving order, so a category that
    # appears more than once in the model's answer does not repeat its tables.
    return list(dict.fromkeys(tables))

In [163]:
# Pipe the category selection into getTables so the chain yields table names.
# NOTE: this rebinds `tableChain`, replacing the extraction chain built earlier.
tableChain=categoryChain | getTables
# Artist name spelled "Morissette", matching the SQL-generation cells later on.
tableChain.invoke(input={"input":"What are all the genres of Alanis Morissette songs"})

['Album',
 'Artist',
 'Genre',
 'MediaType',
 'Playlist',
 'PlaylistTrack',
 'Track',
 'Customer',
 'Employee',
 'Invoice',
 'InvoiceLine']

In [256]:
# Default SQL-generation chain (built-in SQLite prompt).
queryChain = create_sql_query_chain(llm=llm, db=db)

# create_sql_query_chain builds `input` and `table_info` itself from the
# "question" and optional "table_names_to_use" keys of its input. A
# caller-supplied `table_info` is overwritten, so the selected tables have to be
# passed as `table_names_to_use` for the filtering to take effect.
fullChain = (
    RunnablePassthrough.assign(
        table_names_to_use={"input": itemgetter("question")} | tableChain
    )
    | queryChain
)

In [260]:
# Inspect the prompt template(s) used by the query chain.
queryChain.get_prompts()

[PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question i

In [248]:
# Generate the SQL for the question and show it.
response = fullChain.invoke(
    {"question": "What are all the genres of Alanis Morissette songs"}
)
print(response)

SELECT DISTINCT "g"."Name"
FROM "Track" AS "t"
JOIN "Genre" AS "g" ON "t"."GenreId" = "g"."GenreId"
JOIN "Album" AS "a" ON "t"."AlbumId" = "a"."AlbumId"
JOIN "Artist" AS "ar" ON "a"."ArtistId" = "ar"."ArtistId"
WHERE "ar"."Name" = 'Alanis Morissette'
ORDER BY "g"."Name";


In [249]:
# Execute the generated SQL against the database.
# NOTE(review): this runs model-generated SQL as-is; fine for a local demo, but
# review or restrict the statement before doing this against real data.
db.run(command=response)

"[('Rock',)]"

<h3>Trying with a Custom Prompt</h3>

In [311]:
## Note: the raw 'question' key is not shown to the LLM directly; the query chain
## feeds the prompt through the {input}, {table_info} and {top_k} placeholders,
## the same variables the default prompt uses (see the get_prompts() output).
## The "SQL Query:" line is the marker the downstream chain splits on.
prompt=ChatPromptTemplate.from_template(
    template="""
        You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. 
        You can order the results to return the most informative data in the database.
        Never query for all columns from a table. 
        You must query only the columns that are needed to answer the question. 
        Wrap each column name in double quotes (") to denote them as delimited identifiers.
        Pay attention to use only the column names you can see in the tables below. 
        Be careful to not query for columns that do not exist. 
        Also, pay attention to which column is in which table.
        Pay attention to use date(\'now\') function to get the current date, if the question involves "today".
        Use the following format:\n\nQuestion:
        ######################
        User's Question: {input}: 
        Only Consider Tables from: {table_info}
        SQL Query:
        #####################     
    """
)

In [312]:
# Rebuild the SQL-generation chain so it uses the custom prompt.
queryChain = create_sql_query_chain(
    llm=llm,
    db=db,
    prompt=prompt,
)

In [316]:
# The selected tables are passed as `table_names_to_use`: create_sql_query_chain
# derives `table_info` itself and overwrites any caller-supplied value, so
# assigning `table_info` here would leave the table filtering without effect.
# The last step keeps only the text after the final "SQL Query:" marker; using
# [-1] instead of [1] avoids an IndexError when the model omits the marker, and
# strip() removes the surrounding blank space.
fullChain = (
    RunnablePassthrough.assign(
        table_names_to_use={"input": itemgetter("question")} | tableChain
    )
    | queryChain
    | RunnableLambda(lambda s: s.split("SQL Query:")[-1].strip())
)

In [317]:
# NOTE(review): create_sql_query_chain appears to bind top_k itself (see the
# partial_variables in the get_prompts() output), so the explicit "top_k" key
# may be redundant — confirm before removing it.
response = fullChain.invoke(
    {
        "question": "What are all the genres of Alanis Morissette songs",
        "top_k": 5,
    }
)

In [318]:
# Show the SQL text extracted from the model's answer.
print(response)


        SELECT DISTINCT "Genre"."Name"
        FROM "Track"
        JOIN "Genre" ON "Track"."GenreId" = "Genre"."GenreId"
        JOIN "Album" ON "Track"."AlbumId" = "Album"."AlbumId"
        JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId"
        WHERE "Artist"."Name" = 'Alanis Morissette';


In [319]:
# Execute the extracted SQL against the database.
# NOTE(review): this runs model-generated SQL as-is; fine for a local demo, but
# review or restrict the statement before doing this against real data.
db.run(command=response)

"[('Rock',)]"