In [2]:
import os
from dotenv import load_dotenv

In [3]:
# Load the variables defined in the .env file into the process environment,
# then read the OpenAI key from it (None when the variable is not defined).
load_dotenv()
api_key = os.environ.get('OPENAI_API_KEY')

## Step 3: Create SQL Chain Prompt

In [4]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt text for the SQL-writing step; {schema} and {question} are filled in
# when the prompt is formatted or invoked.
template = """
Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:
"""

# Give the message an explicit "human" role. A bare string inside
# from_messages() is treated as a human message anyway (hence the "Human:"
# prefix in the formatted output), so this builds the same prompt while making
# the role visible to the reader.
# ChatPromptTemplate.from_template(template) is an equivalent shortcut.
prompt = ChatPromptTemplate.from_messages([("human", template)])

In [5]:
# Sanity check: render the prompt with placeholder values for both variables.
prompt.format(
    question="how many users are there?",
    schema="my schema",
)

"Human: \nBased on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there?\nSQL Query:\n"

## Step 4: Load MySQL Database in Python

In [6]:
from urllib.parse import quote

from langchain_community.utilities import SQLDatabase

# Read the MySQL password from the environment and fail fast when it is
# missing; otherwise the literal text "None" would be embedded in the URI and
# the connection would fail later with a confusing authentication error.
mysql_password = os.getenv('MYSQL_PASSWORD')
if mysql_password is None:
    raise RuntimeError(
        "MYSQL_PASSWORD is not set; add it to the .env file or the environment."
    )

# Percent-encode the password so characters that are special in a URL
# (e.g. '@', '/', ':') cannot corrupt the connection string. Passwords made of
# unreserved characters are left unchanged.
db_uri = (
    f"mysql+mysqlconnector://root:{quote(mysql_password, safe='')}"
    "@localhost:3306/Chinook"
)
db = SQLDatabase.from_uri(db_uri)

In [7]:
# Smoke-test the connection by fetching the first five rows of Album.
sample_query = "SELECT * FROM Album LIMIT 5"
db.run(sample_query)

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

## Step 5: Create SQL Chain
- The SQL chain takes two inputs: the user's question and the database schema.
- The SQL chain feeds both inputs to the LLM.
- The SQL chain outputs the corresponding SQL query.

In [10]:
def get_schema(_):
    """Return the table info reported by the connected database.

    The single positional argument is ignored. It exists only because this
    function is handed to RunnablePassthrough, which requires the callable
    to accept an argument; `_` serves as that placeholder.
    """
    table_info = db.get_table_info()
    return table_info

In [9]:
# Manual check of get_schema(); the argument value is irrelevant.
schema_preview = get_schema(None)
schema_preview

'\nCREATE TABLE `Album` (\n\t`AlbumId` INTEGER NOT NULL, \n\t`Title` VARCHAR(160) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`ArtistId` INTEGER NOT NULL, \n\tPRIMARY KEY (`AlbumId`), \n\tCONSTRAINT `FK_AlbumArtistId` FOREIGN KEY(`ArtistId`) REFERENCES `Artist` (`ArtistId`)\n)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE `Artist` (\n\t`ArtistId` INTEGER NOT NULL, \n\t`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, \n\tPRIMARY KEY (`ArtistId`)\n)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/\n\n\nCREATE TABLE `Customer` (\n\t`CustomerId` INTEGER NOT NULL, \n\t`FirstName` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, \n\t`LastNa

In [16]:
# sql query is passed through the chain as string
from langchain_core.output_parsers import StrOutputParser
# pass get_schema() as runnable in the chain
from langchain_core.runnables import RunnablePassthrough
# LLM model inside SQL Chain module
from langchain_openai import ChatOpenAI

# llm = ChatOpenAI()
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop="\nSQL Result:")
    | StrOutputParser()
)

ImportError: cannot import name 'InvalidToolCall' from 'langchain_core.messages' (/opt/anaconda3/envs/chat-with-mysql-practice/lib/python3.10/site-packages/langchain_core/messages/__init__.py)

In [None]:
# Ask a sample question; the schema is added to the input by the chain itself.
user_input = {"question": "how many artists are there?"}
sql_chain.invoke(user_input)