In [4]:
import os
from dotenv import load_dotenv

In [5]:
# Read the .env file and copy its entries into the process environment.
# After this call OPENAI_API_KEY is available in the environment, as if it
# had been set with: os.environ['OPENAI_API_KEY'] = "..."
load_dotenv()

# Look the API key up in the environment (the result is None when the variable is not defined)
api_key = os.environ.get('OPENAI_API_KEY')

## Step 3: Create SQL Chain Prompt

In [6]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt text for the SQL-writing step.
# {schema} and {question} are placeholders that get filled in when the prompt is formatted.
template = """
Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:
"""

# Build the prompt as a single human-role message. A bare string in the list
# passed to from_messages() is shorthand for the same ("human", template) pair;
# the role is spelled out here so it is visible to the reader.
# NOTE(review): an earlier note said from_template() is deprecated in favour of
# from_messages() - confirm against the installed langchain_core version.
prompt = ChatPromptTemplate.from_messages([("human", template)])

In [7]:
# Sanity check: render the prompt with sample values for both placeholders.
prompt.format(schema="my schema", question="how many users are there?")

"Human: \nBased on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there?\nSQL Query:\n"

## Step 4: Load MySQL Database in Python

In [8]:
from urllib.parse import quote

from langchain_community.utilities import SQLDatabase

# The password comes from the environment (populated from .env by load_dotenv()).
# Fail early with a clear message instead of silently connecting with the
# literal password "None" when the variable is missing.
mysql_password = os.getenv('MYSQL_PASSWORD')
if mysql_password is None:
    raise RuntimeError(
        "MYSQL_PASSWORD is not set; define it in the environment or in the .env file."
    )

# Percent-encode the password so characters that have a meaning inside a URL
# (for example '@', '/' or ':') cannot corrupt the connection string.
# Passwords made only of letters, digits, '_', '.', '-' and '~' are unchanged.
db_uri = f"mysql+mysqlconnector://root:{quote(mysql_password, safe='')}@localhost:3306/Chinook"
db = SQLDatabase.from_uri(db_uri)

In [9]:
# Sanity check: confirm the connection works by fetching a few rows from the Album table.
db.run("SELECT * FROM Album LIMIT 5")

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

## Step 5: Create SQL Chain
- The SQL chain takes two inputs: the user's question and the database schema.
- The SQL chain passes both inputs to the LLM.
- The SQL chain outputs the corresponding SQL query.

In [10]:
def get_schema(_):
    '''
    Return the table information of the connected database.

    The single positional parameter is deliberately ignored. It exists only
    because this function is handed to RunnablePassthrough.assign(), which
    supplies an argument when it calls the function; `_` absorbs it.

    Returns the value of db.get_table_info(): a string holding the
    CREATE TABLE statements plus a few sample rows for each table.
    '''
    table_info = db.get_table_info()
    return table_info

In [11]:
# Smoke test for get_schema(): the argument is ignored, so any placeholder value (here None) works.
get_schema(None)

'\nCREATE TABLE `Album` (\n\t`AlbumId` INTEGER NOT NULL, \n\t`Title` VARCHAR(160) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`ArtistId` INTEGER NOT NULL, \n\tPRIMARY KEY (`AlbumId`), \n\tCONSTRAINT `FK_AlbumArtistId` FOREIGN KEY(`ArtistId`) REFERENCES `Artist` (`ArtistId`)\n)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE `Artist` (\n\t`ArtistId` INTEGER NOT NULL, \n\t`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, \n\tPRIMARY KEY (`ArtistId`)\n)COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/\n\n\nCREATE TABLE `Customer` (\n\t`CustomerId` INTEGER NOT NULL, \n\t`FirstName` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`LastNa

In [12]:
# StrOutputParser: the chain hands back the generated SQL as a plain string
from langchain_core.output_parsers import StrOutputParser
# RunnablePassthrough.assign: lets get_schema() add a `schema` entry to the chain input
from langchain_core.runnables import RunnablePassthrough
# Chat model used inside the SQL chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

# Pipeline: {"question": ...} -> add "schema" -> fill the prompt -> LLM -> plain string.
# Note that `prompt` is captured here, when the chain is built; rebinding the
# name `prompt` in a later cell does not change this chain.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    # `stop` is passed as a list of stop sequences, which is the form the
    # chat-model interface documents (the original passed a bare string).
    | llm.bind(stop=["\nSQL Result:"])
    | StrOutputParser()
)

In [13]:
# Smoke test: ask the SQL chain for a query that answers a sample question.
# This sends a request to the OpenAI API, so it needs a valid key with available quota.
sql_chain.invoke({"question": "how many artists are there?"})

RateLimitError: Error code: 429 - {'error': {'message': 'You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.', 'type': 'insufficient_quota', 'param': None, 'code': 'insufficient_quota'}}

## Step 6: Create run_query Function And Final Prompt

In [None]:
# Prompt for the final step: turn the SQL result into a natural-language answer.
# Placeholders: {schema}, {question}, {sql_query}, {sql_response}.
#
# NOTE: this cell rebinds the names `template` and `prompt` that were used for
# the SQL-writing prompt. A chain that was already built keeps the prompt
# object it was built with, but re-running the cell that builds sql_chain
# after this cell would make it pick up this prompt instead.
template = """
Based on the table schema below, user question, sql query, and sql response, write a natural language response:
{schema}

User Question: {question}
SQL Query: {sql_query}
SQL Response: {sql_response}
"""

# Single human-role message; equivalent to passing the bare string in the list.
prompt = ChatPromptTemplate.from_messages([("human", template)])

In [14]:
def run_query(query):
    '''
    Execute a SQL statement against the connected database.

    query: Str, the SQL text to run.

    Returns whatever db.run() returns for the statement; for a SELECT the
    result rows come back rendered as a string, e.g. '[(275,)]'.
    '''
    result = db.run(query)
    return result

In [16]:
# Smoke test: run a hand-written query directly against the database.
run_query("SELECT COUNT(ArtistId) AS TotalArtists FROM Artist;")

'[(275,)]'

## Step 7: Create Full Chain

In [None]:
# End-to-end pipeline:
#   1. sql_chain writes a SQL query for the question  -> stored under "sql_query"
#   2. get_schema adds the table info                 -> stored under "schema"
#      run_query executes the generated SQL           -> stored under "sql_response"
#   3. the final prompt is filled in, sent to the LLM, and the reply is returned as a string.
#
# The key names must match the placeholders of the final prompt template
# ({schema}, {question}, {sql_query}, {sql_response}); with any other names the
# prompt step fails because its input variables are missing.
#
# NOTE(review): the SQL text produced by the model is executed as-is, using the
# root account configured in db_uri. Use a read-only account for anything
# beyond a local demo.
full_chain = (
    RunnablePassthrough.assign(sql_query=sql_chain).assign(
        schema=get_schema,
        sql_response=lambda inputs: run_query(inputs["sql_query"]),
    )
    | prompt
    | llm
    | StrOutputParser()
)

In [None]:
# End-to-end run: generate SQL for the question, execute it, and get a natural-language answer.
full_chain.invoke({"question": "how many artists are there?"})