# Chinook DB

The sample database comes from https://github.com/arjunchndr/Analyzing-Chinook-Database-using-SQL-and-Python. The approach is inspired by the LangChain SQL quickstart: https://python.langchain.com/docs/use_cases/sql/quickstart/

In [1]:
# libraries and models setup
import os
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# Chat Model definition
llm = AzureChatOpenAI(
    openai_api_version="2023-09-01-preview",
    azure_endpoint=os.getenv('AZURE_API_ENDPOINT'),
    api_key=os.getenv('AZURE_OPENAI_KEY'),
    azure_deployment=os.getenv('OPENAI_DEPLOYMENT_NAME'),
    model_name=os.getenv('OPENAI_MODEL_NAME'),
    model_version=os.getenv('OPENAI_API_VERSION'),
    temperature=.7
)

# Embeddings model definition
embedding_model = AzureOpenAIEmbeddings(
    openai_api_version="2023-09-01-preview",
    azure_endpoint=os.getenv('AZURE_API_ENDPOINT'),
    api_key=os.getenv('AZURE_OPENAI_KEY'),
    azure_deployment=os.getenv('OPENAI_DEPLOYMENT_NAME_EMBEDDING')
)

## Creating the prompt template

In [2]:
from langchain.prompts import ChatPromptTemplate

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)


## Connect to the DB

In [5]:
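# SQLDatabase lives in langchain_community, the package used for the SQL tools further down
from langchain_community.utilities import SQLDatabase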

db = SQLDatabase.from_uri("sqlite:///./data/chinook.db")

def get_schema(_):
    return db.get_table_info()

def run_query(query):
    return db.run(query)
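
For intuition about what `get_schema` feeds into the prompt: the schema text is essentially the set of `CREATE TABLE` statements of the database (the agent trace at the end of this notebook shows one of them). The sketch below reads those statements with the standard `sqlite3` module. It assumes an in-memory database with two hypothetical, heavily simplified tables instead of `chinook.db`, and `read_schema` is an illustrative helper, not how `SQLDatabase.get_table_info()` is implemented (which may also append sample rows).

```python
import sqlite3


def read_schema(conn: sqlite3.Connection) -> str:
    """Return the CREATE TABLE statements stored in sqlite_master."""
    rows = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)


# Toy stand-in for chinook.db (hypothetical, simplified tables).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employee (employee_id INTEGER PRIMARY KEY, last_name TEXT)")
conn.execute("CREATE TABLE genre (genre_id INTEGER PRIMARY KEY, name TEXT)")

schema_text = read_schema(conn)
print(schema_text)
```

The more tables a database has, the longer this text gets, and all of it is sent to the model with every question.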

## Chain to output SQL query

In [7]:
# V1
# Same classes as below, imported from langchain_core (as in the V2 answer chain)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

sql_response.invoke({"question": "List all the tables in the database?"})

"SELECT name FROM sqlite_master WHERE type='table'; -- Assuming SQLite database, but syntax may vary for other database systems"
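
The output above ends with a `--` comment added by the model. SQLite ignores such comments, but stripping them keeps the query easier to log and compare. The sketch below is one possible post-processing step; `clean_sql` is a hypothetical helper, not part of LangChain, and it is deliberately naive (a `--` inside a string literal would also be cut).

```python
import re


def clean_sql(text: str) -> str:
    """Remove trailing `--` comments and surrounding whitespace.

    Naive on purpose: a `--` inside a string literal would also be cut.
    """
    # Remove everything from `--` to the end of each line.
    lines = [re.sub(r"--.*$", "", line).rstrip() for line in text.splitlines()]
    # Drop lines that became empty and trim the result.
    return "\n".join(line for line in lines if line).strip()


raw = (
    "SELECT name FROM sqlite_master WHERE type='table'; "
    "-- Assuming SQLite database, but syntax may vary for other database systems"
)
cleaned = clean_sql(raw)
print(cleaned)
```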

In [21]:
# V2 : Update
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
response

'SELECT COUNT(*) as num_employees FROM employee'

In [22]:
# The generated query can be executed directly against the database
db.run(response)

'[(8,)]'
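
Note that the result shown here is a string (`'[(8,)]'`), not a list. When the value is needed programmatically, one option is `ast.literal_eval`, which parses Python literals without executing code. This is a sketch that assumes the string is a valid Python literal, as it is above; results containing values such as dates may not parse this way.

```python
import ast

result_str = "[(8,)]"  # the string returned by db.run(response) above
rows = ast.literal_eval(result_str)  # parsed into a list of tuples
num_employees = rows[0][0]
print(num_employees)
```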

Chaining the 2 previous steps:

In [23]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there"})

'[(8,)]'
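
The `|` operator builds a pipeline in which the output of one step becomes the input of the next. The toy class below illustrates that idea in plain Python. `Step`, `fake_write_query`, and `fake_execute_query` are hypothetical stand-ins introduced for this sketch; this is not LangChain's actual Runnable implementation.

```python
class Step:
    """Wrap a function so that steps can be composed with `|`."""

    def __init__(self, func):
        self.func = func

    def __or__(self, other):
        # (a | b).invoke(x) is equivalent to b.invoke(a.invoke(x))
        return Step(lambda value: other.invoke(self.invoke(value)))

    def invoke(self, value):
        return self.func(value)


# Hypothetical stand-ins for the query-writing and query-executing steps.
fake_write_query = Step(lambda inputs: "SELECT COUNT(*) FROM employee")
fake_execute_query = Step(lambda sql: f"executed: {sql}")

toy_chain = fake_write_query | fake_execute_query
toy_result = toy_chain.invoke({"question": "How many employees are there"})
print(toy_result)
```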

In [24]:
# getting the completed prompt
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result
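
This format also explains the `stop=["\nSQLResult:"]` binding used in the V1 chain: only the query is wanted from the model, so generation is cut before it can invent a result. The sketch below mimics that truncation on a plain string. `apply_stop` is a hypothetical helper and the sample completion is made up for illustration; in practice the stop sequence is applied by the model API during generation.

```python
def apply_stop(text: str, stop_sequences: list[str]) -> str:
    """Cut `text` at the earliest occurrence of any stop sequence."""
    cut = len(text)
    for stop in stop_sequences:
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut]


# Made-up completion in which the model keeps going past the query.
completion = 'SELECT COUNT(*) FROM "employee"\nSQLResult: 42\nAnswer: 42'
truncated = apply_stop(completion, ["\nSQLResult:"])
print(truncated)
```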

## Chain to create natural answers

In [11]:
# V1
# template string
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}
"""

# prompt template
prompt_response = ChatPromptTemplate.from_template(template)

# chain
full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | llm
)

# invoke chain
results = full_chain.invoke({"question": "How many employees are there?"})
print(results.content)

There are 8 employees in the database.


In [25]:
# V2
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | llm | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "How many employees are there"})

'There are 8 employees.'

## Additional questions

In [14]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | llm
)

results = full_chain.invoke({"question": "Compute the number tracks sold per genre?"})
print(results.content)

The SQL query shows the number of tracks sold per genre. The query joins the genre, track, and invoice_line tables and groups the results by genre name. The response shows the name of each genre and the total quantity of tracks sold for that genre. For example, the 'Rock' genre has sold a total of 2635 tracks.


In [19]:
full_chain.invoke({"question": "what are the top 5 sales per country?"}).content

'The top 5 sales per country are as follows: USA with a total of 1040.49, Canada with 535.59, Brazil with 427.68, France with 389.07, and Germany with 334.62. These figures were obtained from the invoice line, invoice, and customer tables, and represent the total sales in each country.'

In [20]:
sql_response.invoke({"question": "what are the top 5 sales per country?"})

'SELECT c.country, SUM(il.unit_price * il.quantity) AS total_sales\nFROM invoice_line il\nJOIN invoice i ON il.invoice_id = i.invoice_id\nJOIN customer c ON i.customer_id = c.customer_id\nGROUP BY c.country\nORDER BY total_sales DESC\nLIMIT 5;'
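
To check what a generated query like this one computes, it can be run against a tiny in-memory database. The tables and rows below are hypothetical toy data containing only the columns the query touches (with `unit_price` simplified to `REAL`); they are not the Chinook data.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customer (customer_id INTEGER PRIMARY KEY, country TEXT);
CREATE TABLE invoice (invoice_id INTEGER PRIMARY KEY, customer_id INTEGER);
CREATE TABLE invoice_line (
    invoice_line_id INTEGER PRIMARY KEY,
    invoice_id INTEGER,
    unit_price REAL,
    quantity INTEGER
);
INSERT INTO customer VALUES (1, 'USA'), (2, 'France');
INSERT INTO invoice VALUES (10, 1), (11, 2), (12, 1);
INSERT INTO invoice_line VALUES
    (100, 10, 1.0, 2),
    (101, 11, 2.0, 1),
    (102, 12, 0.5, 4);
""")

# Same query text as the model output above.
query = """
SELECT c.country, SUM(il.unit_price * il.quantity) AS total_sales
FROM invoice_line il
JOIN invoice i ON il.invoice_id = i.invoice_id
JOIN customer c ON i.customer_id = c.customer_id
GROUP BY c.country
ORDER BY total_sales DESC
LIMIT 5;
"""
rows = conn.execute(query).fetchall()
print(rows)
```

On this toy data the query sums `unit_price * quantity` per country, which means it returns the five countries with the highest total sales rather than five sales within each country.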

## SQL Agent

See: https://python.langchain.com/docs/integrations/providers/cnosdb/#sql-database-chain

Prompting techniques are especially useful with agents. Take the example below: the first part of the question, which asks the agent to list the table names, is essential for reaching the correct result.

In [46]:
from langchain.agents import create_sql_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

agent.invoke(
    "List all the tables names in the database. Which country's customers spent the most?"
)



> Entering new SQL Agent Executor chain...
We need to first get a list of tables and then identify which table has customer spending data. We can use sql_db_list_tables to get the list of tables and then use sql_db_schema to identify which table has customer spending data. Once we know the table, we can use sql_db_query to write a query to get the country with the highest customer spending.
Action 1: sql_db_list_tables
Action Input: ""
album, artist, customer, employee, genre, invoice, invoice_line, media_type, playlist, playlist_track, track
We need to use sql_db_schema to identify which table has customer spending data.
Action 2: sql_db_schema
Action Input: invoice_line
CREATE TABLE invoice_line (
	invoice_line_id INTEGER NOT NULL, 
	invoice_id INTEGER NOT NULL, 
	track_id INTEGER NOT NULL, 
	unit_price NUMERIC(10, 2) NOT NULL, 
	quantity INTEGER NOT NULL, 
	PRIMARY KEY (invoice_line_id), 
	FOREIGN KEY(track_

{'input': "List all the tables names in the database. Which country's customers spent the most?",
 'output': 'USA.'}