In [10]:
import os
import json
from datetime import datetime
from typing import TypedDict
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
import psycopg2
from psycopg2.extras import RealDictCursor
from langgraph.graph import StateGraph, START, END


In [11]:
# PostgreSQL connection setup
postgres_uri = "postgresql://postgres:password@localhost:5433/ecommerce"
conn = psycopg2.connect(postgres_uri)
cursor = conn.cursor(cursor_factory=RealDictCursor)

In [None]:
groq_api_key = "gsk_K9tWaPzTCwCNUqTHFNLhWGdyb3FYfmfsdfsfsfsfsf9TXXzl6xgNqBDmuUGiRPs"
llm = ChatGroq(model="llama3-70b-8192", groq_api_key=groq_api_key)

class State(TypedDict):
    question: str
    sql_query: str
    query_results: list
    answer: str
    error: str
    retry_count: int

In [13]:


def clean_sql_query(query: str) -> str:
    """Clean SQL query by removing markdown code fences."""
    return query.strip().replace("```sql", "").replace("```", "").strip()

def execute_sql_query(query: str) -> list:
    """Execute SQL query and return results as a list of dictionaries."""
    try:
        cursor.execute(query)
        results = cursor.fetchall()
        return [dict(row) for row in results]
    except Exception as e:
        raise ValueError(f"Error executing SQL query: {e}\nQuery: {query}")

sql_prompt = PromptTemplate.from_template(
    """You are an AI that generates PostgreSQL queries for a database with two tables:
    - sales (order_id, product_id, quantity, unit_price, order_date)
    - products (product_id, product_name, description)

Generate a PostgreSQL query to answer the question: {question}

IMPORTANT:
- Use valid SQL syntax compatible with PostgreSQL.
- Use ISO 8601 format for date literals (e.g., '2025-01-01 00:00:00').
- Return only the SQL query string — no markdown, no extra explanation.
"""
)

answer_prompt = PromptTemplate.from_template(
    """Based on the query results, provide a clear and concise answer to the question: {question}

Query Results: {results}

Instructions:
- For numerical results (e.g., revenue), format numbers to two decimal places.
- For trends, summarize totals, products, and relevant dates.
- If the results are empty, return: "No data found for the query."
- Return only the final answer without any prefix like "Answer:".
"""
)

def generate_query(state: State) -> State:
    try:
        sql_chain = sql_prompt | llm | StrOutputParser() | clean_sql_query
        sql_query = sql_chain.invoke({"question": state["question"]})
        print(f"\nGenerated SQL Query:\n{sql_query}")
        return {"sql_query": sql_query}
    except Exception as e:
        return {"error": f"Error generating query: {e}"}

def execute_query(state: State) -> State:
    if state.get("error"):
        return state
    try:
        results = execute_sql_query(state["sql_query"])
        print(f"\nPostgreSQL Results:\n{results}")
        return {"query_results": results}
    except Exception as e:
        return {"error": f"Error executing query: {e}"}

def format_answer(state: State) -> State:
    if state.get("error"):
        return {"answer": state["error"]}
    try:
        results_str = json.dumps(state["query_results"], default=str, indent=2)
        answer_chain = answer_prompt | llm | StrOutputParser()
        answer = answer_chain.invoke({
            "question": state["question"],
            "results": results_str
        })
        return {"answer": answer}
    except Exception as e:
        print(f"\nError in format_answer:\n{e}")
        return {"answer": f"Error formatting answer: {e}"}

# Define the state graph
graph = StateGraph(State)
graph.add_node("generate_query", generate_query)
graph.add_node("execute_query", execute_query)
graph.add_node("format_answer", format_answer)
graph.set_entry_point("generate_query")
graph.add_conditional_edges(
    "generate_query",
    lambda state: "execute_query" if not state.get("error") else "format_answer",
    {"execute_query": "execute_query", "format_answer": "format_answer"}
)
graph.add_conditional_edges(
    "execute_query",
    lambda state: "format_answer" if not state.get("error") else "format_answer",
    {"format_answer": "format_answer"}
)
graph.add_edge("format_answer", END)

agent = graph.compile()

def process_query(question: str) -> str:
    inputs = {"question": question, "retry_count": 0}
    final_state = None
    for output in agent.stream(inputs):
        final_state = output
    if 'format_answer' in final_state and 'answer' in final_state['format_answer']:
        return final_state['format_answer']['answer']
    return final_state.get('error', 'No answer generated')

def main():
    queries = [
        "What is the total sales revenue in 2025?",
        "Which product sold the most units?",
        "Summarize sales trends for March 2025."
    ]
    for query in queries:
        print(f"\nQuery: {query}")
        response = process_query(query)
        print(f"\nResponse: {response}")



In [15]:
def cleanup():
    cursor.close()
    conn.close()

if __name__ == "__main__":
    try:
        main()
    finally:
        cleanup()


Query: What is the total sales revenue in 2025?

Generated SQL Query:
SELECT SUM(quantity * unit_price) AS total_revenue
FROM sales
WHERE order_date >= '2025-01-01 00:00:00' AND order_date < '2026-01-01 00:00:00';

Response: Error executing query: Error executing SQL query: cursor already closed
Query: SELECT SUM(quantity * unit_price) AS total_revenue
FROM sales
WHERE order_date >= '2025-01-01 00:00:00' AND order_date < '2026-01-01 00:00:00';

Query: Which product sold the most units?

Generated SQL Query:
SELECT p.product_name, SUM(s.quantity) AS total_units_sold
FROM sales s
JOIN products p ON s.product_id = p.product_id
GROUP BY p.product_name
ORDER BY total_units_sold DESC
LIMIT 1;

Response: Error executing query: Error executing SQL query: cursor already closed
Query: SELECT p.product_name, SUM(s.quantity) AS total_units_sold
FROM sales s
JOIN products p ON s.product_id = p.product_id
GROUP BY p.product_name
ORDER BY total_units_sold DESC
LIMIT 1;

Query: Summarize sales trends