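# Querying a SQL Database with an LLM

This notebook uses LangChain and an OpenAI chat model to do three things:

- turn a natural-language question into a SQL query;
- run that query against the Chinook SQLite sample database;
- phrase the query result as a natural-language answer.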

In [41]:
import os
from dotenv import load_dotenv

# Copy variables from a local .env file (if one exists) into os.environ.
# Without this call the import above is dead code and the key would have to be
# exported in the shell before starting Jupyter.
load_dotenv()

# ChatOpenAI is constructed later without an explicit key, so it relies on this
# environment variable. Check it here so a missing key fails fast with a clear message
# instead of deep inside the first model call.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; add it to a .env file or export it in your environment."
    )

In [42]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt for step 1: given the database schema and a user question, ask the
# model for a SQL query. The {schema} and {question} placeholders are filled
# by the chain defined further down.
template = (
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:"
)
prompt = ChatPromptTemplate.from_template(template)

In [43]:
from langchain_community.utilities import SQLDatabase

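The examples use the Chinook sample database. An overview is at https://database.guide/2-sample-databases-sqlite/.

Download the SQLite version of the sample database from https://www.sqlitetutorial.net/sqlite-sample-database/. Place the database file at `./Data/chinook.db`, the path the next cell expects.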

In [44]:
from pathlib import Path

# Local copy of the Chinook SQLite database, relative to the notebook's working directory.
DB_PATH = Path("./Data/chinook.db")

# By default SQLite creates a new, empty database when the file does not exist.
# That would turn a missing download into empty table lists and failing queries
# further down, so fail loudly here instead.
if not DB_PATH.is_file():
    raise FileNotFoundError(
        f"Chinook database not found at {DB_PATH.resolve()}; download it first (see link above)."
    )

db_uri = f"sqlite:///{DB_PATH.as_posix()}"
db = SQLDatabase.from_uri(db_uri)

In [45]:
# Sanity check: list the tables the SQLDatabase wrapper will expose to the model.
usable_tables = db.get_usable_table_names()
print(usable_tables)

['albums', 'artists', 'customers', 'employees', 'genres', 'invoice_items', 'invoices', 'media_types', 'playlist_track', 'playlists', 'tracks']


In [46]:
def get_schema(_):
    """Return the schema description of `db` for use in the prompts.

    The single argument is the chain input that `RunnablePassthrough.assign`
    passes in. It is ignored because the schema does not depend on the question.
    """
    schema_text = db.get_table_info()
    return schema_text

In [47]:
def run_query(query):
    """Execute a model-generated SQL string against `db` and return the result.

    The text comes straight from the LLM, so it is cleaned first. Surrounding
    whitespace is stripped, and a Markdown code fence (for example ```sql ... ```),
    which chat models often wrap around code, is removed before execution.

    NOTE(review): this executes model-generated SQL verbatim. Point `db` at a
    read-only connection or a copy of the data before exposing this to untrusted users.
    """
    import re

    cleaned = query.strip()
    # An optional language tag is only treated as such when it is followed by a
    # newline, so a single-line fence like ```SELECT 1``` keeps its first word.
    fenced = re.match(r"^```(?:[A-Za-z]*\n)?\s*(.*?)\s*```$", cleaned, re.DOTALL)
    if fenced:
        cleaned = fenced.group(1)
    return db.run(cleaned)

In [48]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# temperature=0 makes SQL generation (and the final answer) as repeatable as the
# API allows, so re-running the notebook gives more stable outputs.
model = ChatOpenAI(temperature=0)

# Step 1 chain: {"question": ...} -> add "schema" -> fill the SQL prompt
# -> LLM -> plain SQL string.
sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    # Stop generating if the model starts appending an invented "SQLResult:"
    # section after the query; only the query text is wanted.
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [50]:
# Try step 1 on its own: the output should be just the generated SQL.
artist_count_input = {"question": "How many artists are there?"}
sql_response.invoke(artist_count_input)

'SELECT COUNT(*) AS artist_count FROM artists'

In [51]:
# Prompt for step 2: show the model the schema, the question, the SQL it wrote
# and the raw database result, and ask for a natural-language answer.
template = "\n".join(
    [
        "Based on the table schema below, question, sql query, and sql response, write a natural language response:",
        "{schema}",
        "",
        "Question: {question}",
        "SQL Query: {query}",
        "SQL Response: {response}",
    ]
)
prompt_response = ChatPromptTemplate.from_template(template)

In [52]:
# End-to-end chain: question -> SQL (via sql_response) -> execute it -> ask the
# model to phrase the result. The first .assign adds "query". The second adds
# "schema" and "response", both computed from that enriched input. The chain
# ends at `model`, so invoking it returns the model's message object.
full_chain = (
    RunnablePassthrough.assign(query=sql_response).assign(
        schema=get_schema,
        # Reuse the notebook's run_query helper instead of calling db.run
        # directly, so query execution is defined in a single place.
        response=lambda x: run_query(x["query"]),
    )
    | prompt_response
    | model
)

In [53]:
# Run the whole pipeline end to end on the same sample question.
artist_question = {"question": "How many artists are there?"}
full_chain.invoke(artist_question)

AIMessage(content='There are 275 artists in the database.')