In [82]:
import os

from langchain.chat_models import init_chat_model
from langchain.output_parsers import PydanticOutputParser
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import START, StateGraph
from pydantic import BaseModel, Field

In [113]:
# Location of the DuckDB file. Set the ANALYTICS_DB_PATH environment variable
# to run this notebook on another machine; when it is unset (or empty) the
# original absolute path is used, so existing behavior is unchanged.
DB_PATH = (
    os.environ.get("ANALYTICS_DB_PATH")
    or "/Users/viswhavijay/Developer/copilot/analytics.duckdb"
)

# The URI is "duckdb:///" followed by the file path, so an absolute path
# produces four consecutive slashes, exactly as in the original literal.
db = SQLDatabase.from_uri(f"duckdb:///{DB_PATH}")

# Sanity check that the connection works and the SUPERSTORE table is readable;
# as the last expression of the cell, the row count is displayed as output.
db.run('SELECT COUNT(*) FROM SUPERSTORE')



'[(9994,)]'

In [133]:
class State(BaseModel):
    # State object handed from node to node in the graph. Only `question`
    # must be supplied by the caller; the other fields start empty and are
    # filled in by the node functions' returned updates.
    question: str     # the user's natural-language question
    query: str = ""   # SQL text returned by write_query
    result: str = ""  # query output returned by execute_query
    answer: str = ""  # final reply returned by generate_answer

class QueryOutput(BaseModel):
    # Shape of the JSON object the model is asked to return. write_query
    # parses the model reply into this class and keeps only `query`.
    # NOTE(review): documented with `#` comments rather than a docstring on
    # purpose -- a class docstring presumably surfaces as the schema
    # "description" and would then alter the generated format instructions;
    # confirm before converting this to a docstring.
    query: str = Field(
        description="Syntactically valid SQL query. Dates must be in YYYY-MM-DD format."
    )
    explanation: str = Field(
        description="Any additional explanation, apart from the queries"
    )

In [134]:
# Chat model used by both write_query and generate_answer, requested from the
# "ollama" provider.
# NOTE(review): no temperature or seed is set, so generated SQL and answers
# may differ between runs -- confirm whether that is acceptable.
model = init_chat_model(
    model="dolphin-llama3:8b",
    model_provider="ollama",
)

In [135]:
# System prompt for SQL generation. Placeholders filled in by write_query:
#   {dialect}             - SQL dialect reported by the database wrapper
#   {top_k}               - maximum number of rows the query should return
#   {table_info}          - schema description of the available tables
#   {format_instructions} - output-format instructions from the output parser
# NOTE(review): the prompt text contains the typo "a the few" and mixes
# "his"/"they"; it is left untouched here because it is runtime text sent to
# the model.
system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}

Return ONLY valid JSON. Do not include any explanations or text outside the JSON object.
{format_instructions}
"""

# User turn: carries the question through the {input} placeholder.
user_prompt = "Question: {input}"

# Two-message chat prompt (system instructions followed by the user question).
query_prompt = ChatPromptTemplate(
    [('system', system_message), ('user', user_prompt)]
)

In [136]:
def write_query(state: State):
    """Ask the chat model for a SQL query that answers ``state.question``.

    Fills ``query_prompt`` with the database dialect, a row limit of 10, the
    table schema, the question and the parser's format instructions, sends it
    to ``model`` and parses the reply into a ``QueryOutput``.

    Returns:
        dict: ``{"query": <sql>}`` -- only the ``query`` field of the parsed
        reply is kept; ``explanation`` is discarded.

    Any exception raised by the parser for a reply it cannot parse is not
    caught here.
    """
    output_parser = PydanticOutputParser(pydantic_object=QueryOutput)

    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state.question,
        "format_instructions": output_parser.get_format_instructions(),
    }
    messages = query_prompt.invoke(prompt_inputs)

    llm_reply = model.invoke(messages)
    parsed = output_parser.parse(llm_reply.content)
    return {"query": parsed.query}

def execute_query(state: State):
    """Run ``state.query`` against ``db`` and return what the tool gives back.

    Returns:
        dict: ``{"result": <tool output>}``.
    """
    # NOTE(review): the SQL comes straight from the model and is executed
    # without any validation -- consider a read-only connection or a
    # statement allow-list before pointing this at a database that matters.
    sql_runner = QuerySQLDataBaseTool(db=db)
    raw_result = sql_runner.invoke(state.query)
    return {"result": raw_result}

def generate_answer(state: State):
    """Have the chat model phrase a final answer from the question, SQL and result.

    Returns:
        dict: ``{"answer": <model reply text>}``.
    """
    question, query, result = state.question, state.query, state.result

    prompt_text = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {question}\n"
        f"SQL Query: {query}\n"
        f"SQL Result: {result}"
    )

    return {"answer": model.invoke(prompt_text).content}

In [137]:
# Linear pipeline: START -> write_query -> execute_query -> generate_answer.
graph_builder = StateGraph(State)
graph_builder.add_sequence([write_query, execute_query, generate_answer])

# add_sequence only links the three nodes to each other; the entry edge from
# START is added explicitly.
graph_builder.add_edge(START, "write_query")

graph = graph_builder.compile()

In [140]:
# Run the graph on a sample question and print each node's state update as it
# arrives (stream_mode="updates" yields one {node_name: update} dict per node).
initial_state = State(question="Which category of product gets sold most?")

for step in graph.stream(initial_state, stream_mode="updates"):
    print(step)

{'write_query': {'query': 'SELECT category, SUM(sales) AS total_sales FROM superstore GROUP BY category ORDER BY total_sales DESC LIMIT 1'}}
{'execute_query': {'result': "[('Technology', 836154.0999999908)]"}}
{'generate_answer': {'answer': "The category of product that gets sold the most is 'Technology'."}}
