In [1]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
import os

# Prompt text for the SQL-generation step. The {schema} and {question}
# placeholders are filled in from the chain input when the chain runs.
template = """
You are a SQL expert. Based on the table schema below, write just the SQL
query without the results that would answer the user's question.:
{schema}
Question: {question}
SQLQuery:
"""

# Chat prompt built from the template text above.
prompt = ChatPromptTemplate.from_template(template)
# Chat model used by the chains below.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# Handle to the Northwind SQLite database. The URI holds a relative path, so
# which file is opened depends on the process's current working directory.
db = SQLDatabase.from_uri("sqlite:///../database/northwind.db")

def get_schema(_):
    """Return the table info reported by the module-level ``db``.

    The single positional argument is accepted and ignored, so the function
    can be used as a chain step that receives the chain input.
    """
    table_info = db.get_table_info()
    return table_info

def run_query(query):
    """Run ``query`` on the module-level ``db`` and return what it yields."""
    outcome = db.run(query)
    return outcome

# SQL-generation chain, built step by step:
#   1. add a "schema" key (from get_schema) to the input dict,
#   2. render the SQL prompt,
#   3. call the model with a stop sequence bound,
#   4. parse the model message into a plain string.
schema_step = RunnablePassthrough.assign(schema=get_schema)
sql_llm = llm.bind(stop=["\nSQLResult:"])
sql_response = schema_step | prompt | sql_llm | StrOutputParser()

sql_response.invoke({"question": "How many employees are there?"})

'```sql\nSELECT COUNT(*) AS EmployeeCount FROM Employees;\n```'

In [3]:
# Generate SQL for a second question with the same SQL-generation chain.
sql_response.invoke({"question": "How many employees have been tenured for more than 11 years?"})

"```sql\nSELECT COUNT(*) AS TenuredEmployees\nFROM Employees\nWHERE (julianday('now') - julianday(HireDate)) / 365 > 11;\n```"

In [5]:
def extract_sql(text: str) -> str:
    """Extract the SQL statement from a model response.

    The response may wrap the query in a Markdown code fence. A fence tagged
    with an SQL dialect (``sql``, ``sqlite``, ``postgresql``, ... in any
    letter case) is preferred; otherwise the first fence of any kind is used;
    otherwise the whole text is returned. A fence with no closing marker
    (for example when generation was cut short) runs to the end of the text.

    Args:
        text: Raw response text, with or without Markdown fencing.

    Returns:
        The fenced content, or ``text`` itself when no fence is present,
        with leading and trailing whitespace removed.
    """
    import re  # local import keeps this block self-contained

    dialect_tag = r"(?:sqlite3?|mysql|postgres(?:ql)?|psql|tsql|plsql|sql)"
    fence_patterns = (
        # The tag must end at a word boundary so it is dropped as a whole:
        # matching only the "sql" prefix of "sqlite" would leave "ite" at
        # the start of the extracted query.
        r"```[ \t]*" + dialect_tag + r"\b(.*?)(?:```|\Z)",
        # Untagged (or unrecognised-tag) fence: take everything inside it.
        r"```(.*?)(?:```|\Z)",
    )
    for pattern in fence_patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE | re.DOTALL)
        if found is not None:
            return found.group(1).strip()
    return text.strip()

def run_query(query: str) -> str:
    """Run model-generated SQL on the module-level ``db``.

    ``query`` is first passed through ``extract_sql`` to remove any Markdown
    fencing; the resulting text is then executed as-is.
    """
    return db.run(extract_sql(query))

# Prompt text for the answer step. The {schema}, {question}, {query} and
# {sql_result} placeholders are filled in from the chain state at run time.
template_response = """Based on the table schema below, question, SQL query, and SQL result,
write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Result: {sql_result}"""

# Chat prompt built from the answer template above.
prompt_response = ChatPromptTemplate.from_template(template_response)

def _run_generated_query(state):
    """Execute the SQL stored under ``state["query"]`` via ``run_query``."""
    return run_query(state["query"])


# End-to-end chain, built step by step:
#   1. add "query" (the SQL produced by sql_response) to the input dict,
#   2. add "schema" and "sql_result" (the outcome of running that query),
#   3. render the answer prompt, call the model, and parse to a string.
query_step = RunnablePassthrough.assign(query=sql_response)
result_step = query_step.assign(
    schema=get_schema,
    sql_result=_run_generated_query,
)
full_chain = result_step | prompt_response | llm | StrOutputParser()

result = full_chain.invoke({"question": "How many employees have been tenured for more than 11 years?"})
print(result)

Based on the query executed, there are a total of 9 employees who have been tenured for more than 11 years.
