In [26]:
# from langchain_community.utilities import SQLDatabase
# from pathlib import Path


# db_path = "nba_roster.db"
# # rel = db_path.relative_to(Path.cwd())
# db_string = f"sqlite:///{db_path}"
# db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=2)

In [40]:
from pathlib import Path

from langchain.memory import ConversationBufferMemory
from langchain.pydantic_v1 import BaseModel
from langchain_community.chat_models import ChatOllama, ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Chat model shared by the chains in this notebook; temperature=0 requests the
# least-random decoding so generated SQL is as repeatable as possible.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Notebooks have no __file__, so the database is located relative to the
# current working directory instead.
db_path = Path.cwd() / "nba_roster.db"
if not db_path.is_file():
    # SQLite creates an empty database when the file is missing, which would
    # only surface later as a schema with no tables; fail here instead.
    raise FileNotFoundError(f"SQLite database not found: {db_path}")
rel = db_path.relative_to(Path.cwd())
db_string = f"sqlite:///{rel}"
# sample_rows_in_table_info=2 asks LangChain to include two sample rows per
# table in the text returned by db.get_table_info().
db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=2)


In [42]:
from langchain_community.utilities import SQLDatabase

# Sanity-check the connection: SQL dialect on the first line, usable table
# names on the second.
print(db.dialect, db.get_usable_table_names(), sep="\n")


sqlite
['nba_roster']


In [43]:
from langchain_openai import ChatOpenAI

# Re-create the shared chat model with the same settings as the setup cell:
# omitting temperature here would silently switch every later chain back to
# the default sampling temperature.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

In [44]:
from langchain.chains import create_sql_query_chain

# Stock LangChain text-to-SQL chain, built from the shared model and database.
chain = create_sql_query_chain(llm=llm, db=db)
response = chain.invoke(dict(question="How many employees are there."))
response

'SQLQuery: SELECT COUNT(*) AS "EmployeeCount" FROM nba_roster;'

In [45]:
# Run a hand-written count query directly against the database.
db.run(command="SELECT COUNT(*) AS 'employee_count' FROM nba_roster;")

'[(600,)]'

In [49]:
def get_schema(_):
    """Return the database's table-info text, ignoring the chain input.

    The single positional argument exists only so the function can be used as
    a step inside a runnable chain; its value is never read.
    """
    table_info = db.get_table_info()
    return table_info


def run_query(query):
    """Execute ``query`` on the notebook's database and return what ``db.run`` gives back."""
    result = db.run(query)
    return result

In [57]:
# Prompt for the text-to-SQL step: the human message carries the schema and
# the question, the system message constrains the output format.
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""  # noqa: E501
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # "\\n" keeps the backslash-n visible to the model; a bare "\n"
            # would put a real line break in the middle of the instruction.
            "Given an input question, convert it to a SQL query. No pre-amble. "
            "Don't use \\n character. Don't print ```. "
            "Don't print anything but the sql query syntax.",
        ),
        ("human", template),
    ]
)

# Conversation buffer; not referenced by any chain visible in this part of the
# notebook.
memory = ConversationBufferMemory(return_messages=True)

In [58]:
# Question -> SQL text: add the schema to the input, render the prompt, call
# the model (stopping before any "SQLResult:" section), then keep only the text
# that precedes the first blank line.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
    | (lambda text: text.partition("\n\n")[0])
)


In [59]:
# Generate SQL for the sample question with the custom sql_chain.
response = sql_chain.invoke(dict(question="How many employees are there."))


In [60]:
# Display the value currently stored in `response` (set by the sql_chain call).
response

'SELECT COUNT(*) FROM nba_roster;'

In [54]:
# Run the count query by hand to see the raw result string from the database.
db.run(command="SELECT COUNT(*) AS number_of_employees FROM nba_roster;")

'[(600,)]'

In [55]:
# Invoke the stock `chain` on the same sample question.
response = chain.invoke(dict(question="How many employees are there."))


In [56]:
# Display the value currently stored in `response` (set by the `chain` call).
response

'SQLQuery: SELECT COUNT(*) AS "EmployeeCount" FROM nba_roster;'

In [62]:

# Prompt for the answer step: given the schema, the question, the generated SQL
# and its result, ask the model for a natural-language reply.
template = (
    "Based on the table schema below, question, sql query, and sql response, write a natural language response:\n"  # noqa: E501
    "{schema}\n\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Response: {response}"
)
prompt_response = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question and SQL response, convert it to a natural language answer. No pre-amble."),  # noqa: E501
        ("human", template),
    ]
)

In [63]:

# Input schema for the answer chain: callers supply a single string field,
# `question`. It is attached to the first chain step through `with_types`.
class InputType(BaseModel):
    question: str


# End-to-end pipeline: question -> SQL -> query result -> natural-language answer.
# NOTE(review): the SQL in x["query"] is model-generated and executed as-is;
# point `db` at a read-only connection before using this on data that matters.
sql_answer_chain = (
    # 1. Generate the SQL text for the question and store it under "query".
    RunnablePassthrough.assign(query=sql_chain).with_types(input_type=InputType)
    # 2. Attach the schema text and the result of running that SQL.
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: run_query(x["query"]),
    )
    # 3. Phrase the answer with the shared, explicitly configured `llm` rather
    #    than a bare ChatOpenAI(), whose model depends on the library default.
    | RunnablePassthrough.assign(answer=prompt_response | llm | StrOutputParser())
    # 4. Collapse the accumulated dict into a single printable string.
    | (lambda x: f"Question: {x['question']}\n\nAnswer: {x['answer']}")
)


In [64]:
# Ask the full pipeline a question; the cell output is the formatted string.
sql_answer_chain.invoke(dict(question="How many employees are there?"))

'Question: How many employees are there?\n\nAnswer: There are 600 players in the NBA roster.'