In [2]:
#!pip install streamlit 

In [3]:
#! pip install pypyodbc

In [4]:
#! pip install pyodbc

In [5]:
#!pip install openai

In [6]:
#pip install -U langchain-openai

In [7]:
#!pip install --upgrade --quiet langchain-community langchainhub langgraph

In [5]:
import getpass
import os
from urllib.parse import quote_plus

import pyodbc

from langchain_community.utilities import SQLDatabase

# Connection settings are read from the environment so that no credentials are
# stored in the notebook. The non-secret settings fall back to the local
# development values this notebook was written against.
sql_server = os.environ.get("MSSQL_SERVER", "127.0.0.1")
sql_database = os.environ.get("MSSQL_DATABASE", "AdventureWorksDW2022")
sql_user = os.environ.get("MSSQL_USER", "sa")

# The password is never hardcoded: take it from the environment, otherwise
# prompt for it without echoing (same pattern as the OpenAI key prompt).
sql_password = os.environ.get("MSSQL_PASSWORD") or getpass.getpass(
    f"Enter SQL Server password for {sql_user}: "
)

# ODBC values containing ';' or braces must be wrapped in {...}, with any
# closing brace doubled, or the driver mis-parses the connection string.
escaped_password = sql_password.replace("}", "}}")

odbc_connect = (
    "DRIVER={ODBC Driver 17 for SQL Server};"
    f"SERVER={sql_server};"
    f"DATABASE={sql_database};"
    f"uid={sql_user};"
    f"pwd={{{escaped_password}}};"
    # NOTE: trustServerCertificate=yes skips certificate validation; only
    # appropriate for a local/dev server.
    "encrypt=yes;trustServerCertificate=yes"
)

# The pass-through ODBC string must be URL-encoded before being placed in the
# SQLAlchemy URL, otherwise characters such as '+', '&' or '%' corrupt it.
# conn_str contains the password -- do not print or log it.
conn_str = "mssql+pyodbc:///?odbc_connect=" + quote_plus(odbc_connect)

db = SQLDatabase.from_uri(conn_str)

# Sanity check: show the detected SQL dialect and the tables the chain may use.
print(db.dialect)
print(db.get_usable_table_names())

mssql
['AdventureWorksDWBuildVersion', 'DatabaseLog', 'DimAccount', 'DimCurrency', 'DimCustomer', 'DimDate', 'DimDepartmentGroup', 'DimEmployee', 'DimGeography', 'DimOrganization', 'DimProduct', 'DimProductCategory', 'DimProductSubcategory', 'DimPromotion', 'DimReseller', 'DimSalesReason', 'DimSalesTerritory', 'DimScenario', 'FactAdditionalInternationalProductDescription', 'FactCallCenter', 'FactCurrencyRate', 'FactFinance', 'FactInternetSales', 'FactInternetSalesReason', 'FactProductInventory', 'FactResellerSales', 'FactSalesQuota', 'FactSurveyResponse', 'NewFactCurrencyRate', 'ProspectiveBuyer', 'sysdiagrams']


In [6]:
from typing_extensions import TypedDict


class State(TypedDict):
    """Shared state passed between the graph nodes.

    Each node reads the keys it needs and returns a dict holding the key it
    fills in.
    """

    # Natural-language question; read by write_query and generate_answer.
    question: str
    # SQL text produced by write_query and consumed by execute_query.
    query: str
    # Output of running the query, as stored by execute_query.
    result: str
    # Final natural-language reply produced by generate_answer.
    answer: str

In [None]:
import getpass
import os

# Only prompt when the key is missing (or empty) in the environment; the
# prompt does not echo what is typed.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain.chat_models import init_chat_model

# Chat model shared by the query-writing and answer-generation steps.
llm = init_chat_model(model="gpt-4o-mini", model_provider="openai")

In [11]:
from langchain import hub

# Fetch the shared SQL-generation prompt from the LangChain Hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# The rest of the notebook assumes the prompt consists of exactly one message;
# fail early if the hub version ever changes shape.
assert len(query_prompt_template.messages) == 1
# Display the message so the placeholders it expects ({dialect}, {top_k},
# {table_info}, {input}) are visible before write_query fills them in.
query_prompt_template.messages[0].pretty_print()




Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
{table_info}

Question: {input}


In [12]:
from typing_extensions import Annotated


# Schema handed to llm.with_structured_output() in write_query.
# NOTE(review): LangChain presumably forwards the class docstring and the
# Annotated description to the model as schema metadata -- treat both strings
# as prompt text and confirm before rewording them.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Annotated[type, default, description]; `...` is used here in the default
    # slot -- presumably marking the field as required (TODO confirm).
    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Turn the user's question into a SQL query using the LLM.

    Fills the hub prompt with the database dialect, a row limit of 10, the
    schema text from ``db.get_table_info()`` and ``state["question"]``, then
    asks the model for output shaped like ``QueryOutput``.

    Returns:
        A partial state update: ``{"query": <generated SQL text>}``.
    """
    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    filled_prompt = query_prompt_template.invoke(prompt_inputs)

    # Constrain the model's reply to the QueryOutput schema.
    generated = llm.with_structured_output(QueryOutput).invoke(filled_prompt)
    return {"query": generated["query"]}

In [13]:
#write_query({"question": "How many Employees are there?"})

In [14]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Run the SQL held in ``state["query"]`` against ``db``.

    The query text comes from write_query (i.e. it is model-generated) and is
    passed to the tool as-is.

    Returns:
        A partial state update: ``{"result": <tool output>}``.
    """
    sql_tool = QuerySQLDatabaseTool(db=db)
    query_output = sql_tool.invoke(state["query"])
    return {"result": query_output}

In [15]:
def generate_answer(state: State):
    """Produce a natural-language answer from the question, SQL and SQL result.

    Builds a single text prompt from ``state["question"]``, ``state["query"]``
    and ``state["result"]`` and sends it to the LLM.

    Returns:
        A partial state update: ``{"answer": <model reply text>}``.
    """
    instructions = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
    )
    context = (
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    reply = llm.invoke(instructions + context)
    return {"answer": reply.content}

In [16]:
from langgraph.graph import START, StateGraph

# Linear pipeline over State: write the SQL, run it, then phrase the answer.
graph_builder = StateGraph(State)
graph_builder.add_sequence([write_query, execute_query, generate_answer])

# Enter the pipeline at the first step.
graph_builder.add_edge(START, "write_query")

# Plain compile with no checkpointer and no interrupts.
graph = graph_builder.compile()

In [17]:
from langgraph.checkpoint.memory import MemorySaver

# In-memory checkpointer, so a paused run can be picked up again.
memory = MemorySaver()

# Recompile with persistence and stop before the SQL is executed, leaving room
# for a human to review the generated query.
graph = graph_builder.compile(
    interrupt_before=["execute_query"],
    checkpointer=memory,
)

# Now that we're using persistence, a thread ID is required so that the run
# can be continued after review.
config = dict(configurable=dict(thread_id="1"))

In [18]:
# First pass: the graph runs write_query and then pauses, because it was
# compiled with interrupt_before=["execute_query"].
for step in graph.stream(
    {"question": "Give me Top 3 Territories by highest total sale value?"},
    config,
    stream_mode="updates",
):
    print(step)

# Ask a human to approve the generated SQL before it is executed.
try:
    user_approval = input("Do you want to go to execute query? (yes/no): ")
except Exception as exc:
    # No usable stdin (e.g. a non-interactive run): fail safe by not executing,
    # but report why instead of silently swallowing the error.
    print(f"Could not read approval ({type(exc).__name__}: {exc}); treating as 'no'.")
    user_approval = "no"

# Normalise the reply so that stray whitespace or case ("Yes ", "y") does not
# turn an approval into a cancellation.
if user_approval.strip().lower() in {"yes", "y"}:
    # Approved: passing None as input resumes the paused run on the same thread.
    for step in graph.stream(None, config, stream_mode="updates"):
        print(step)
else:
    print("Operation cancelled by user.")

{'write_query': {'query': 'SELECT TOP 3 st.SalesTerritoryRegion, SUM(fs.SalesAmount) AS TotalSales\nFROM FactResellerSales fs\nJOIN DimSalesTerritory st ON fs.SalesTerritoryKey = st.SalesTerritoryKey\nGROUP BY st.SalesTerritoryRegion\nORDER BY TotalSales DESC;'}}
{'__interrupt__': ()}
{'execute_query': {'result': "[('Southwest', Decimal('18466458.7925')), ('Canada', Decimal('14377925.5965')), ('Northwest', Decimal('12435076.0001'))]"}}
{'generate_answer': {'answer': 'The top 3 territories by highest total sale value are:\n\n1. Southwest - $18,466,458.79\n2. Canada - $14,377,925.60\n3. Northwest - $12,435,076.00'}}
