In [34]:


import re
from langchain_ollama import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chat_models import init_chat_model
from typing_extensions import Annotated
from typing_extensions import TypedDict
from langchain_core.prompts import PromptTemplate

#TypedDicts like State define what fields are expected in the Python object 
#passed between steps, functions, or graph nodes.

class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

class QueryOutput(TypedDict):
    query: str

db = SQLDatabase.from_uri("sqlite:///finances.db")
 
 
llm = ChatOllama(model="llama3.1", temperature=0)
 
 

# 3) Prompt that shows schema + samples and enforces SQL-only
prompt = ChatPromptTemplate.from_messages([
    ("system",
     "You are a SQLite expert. Given an input question, create a syntactically correct {dialect} query to run.\n"
     "Unless otherwise specified, do not return more than {top_k} rows.\n"
     "Use only the following tables, columns, and sample rows:\n{table_info}\n\n"
     "Return only a single SQL statement. No explanations, no markdown, no code fences."),
    ("human", "{input}")
])



# 4) Chain with required variables (create_sql_query_chain injects table_info/top_k)
chain = create_sql_query_chain(llm, db, prompt=prompt)

# 5) Defensive extractor to strip any stray prose/markdown
SQL_PATTERN = re.compile(r'(?is)\\b(SELECT|WITH|INSERT|UPDATE|DELETE)\\b.*?;', re.DOTALL)

def extract_sql(text: str) -> str:
    cleaned = text.replace("``````", "").strip()
    m = SQL_PATTERN.search(cleaned.replace("\n", " "))
    return m.group(0).strip() if m else cleaned


system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Note: All dates in this database are stored as text strings in the exact format 'YYYY-MM-DD' (for example, '2025-07-14').
When filtering by month and year, use:
    strftime('%Y-%m', date_column) = 'YYYY-MM'
Do not use LIKE with '-MM-YYYY'; always use '%Y-%m' extracted from the date.

 
Only use the following tables:
{table_info}
"""

user_prompt = "Question: {input}"

query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)
    

 

class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}
 

In [35]:
sqlQuery = write_query({"question": "How much money did I spend in February 2024?"})

In [36]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}

In [37]:
execute_query({"query": sqlQuery})

{'result': '[(-1000.0,), (-6.99,), (-0.82,), (-30.47,), (-0.58,), (-21.51,), (-1102.32,), (-2.99,), (-2.0,), (-1024.8,), (-12.81,), (-4.13,), (-7.5,), (-0.36,), (-13.3,), (-1.15,), (-42.44,), (-16.24,), (-2.74,), (-9.29,), (-12.99,), (-12.0,), (-16.27,), (-23.52,), (-119.9,), (-36.31,), (-34.16,), (-9.38,), (-8.72,), (-8.99,), (-12.89,), (-61.79,), (-105.72,), (-8.99,), (-42.51,), (-3.83,), (1400.0,), (807.86,), (-0.18,), (-6.54,), (-0.09,), (-3.39,), (-3.0,), (-40.05,), (-21.63,), (-60.0,), (-11.9,), (-4.24,), (-21.52,), (-11.15,), (-0.6,), (-22.18,), (-100.0,), (-8.85,), (-10.0,), (-9.6,), (-6.5,), (-2.0,), (-10.99,), (105.98,), (-19.56,), (-4.87,), (-180.3,), (-184.66,), (-37.22,), (-58.93,), (-14.4,), (-2.85,), (-5.4,), (-836.3,), (-16.65,), (-25.0,), (836.3,), (-296.45,), (-16.99,), (-349.45,), (-7.99,), (-4.9,), (-8.1,), (-5.0,), (-82.0,), (-20.68,), (-2.0,), (-8.64,), (105.98,), (-58.0,), (-15.0,), (-42.15,), (-12.5,), (-70.0,), (-0.38,), (-13.91,), (-1.73,), (-64.0,), (-6.3,), 

In [38]:
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {state['question']}\n"
        f"SQL Query: {state['query']}\n"
        f"SQL Result: {state['result']}"
    )
    response = llm.invoke(prompt)
    return {"answer": response.content}

generate_answer(state)

KeyError: 'query'