In [None]:
import os
import re
from typing import List, Dict, Any, Optional
import pandas as pd                                           
from pprint import pprint

# LangChain imports
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_groq import ChatGroq
from langchain.sql_database import SQLDatabase
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel, Field

# LangGraph imports
from langgraph.graph import StateGraph, END ,MessagesState
# from langgraph.prebuilt import 

# Database connection
from sqlalchemy import create_engine, inspect, MetaData, text

# Groq API key: export GROQ_API_KEY before starting Jupyter. setdefault only fills in an
# empty placeholder when the variable is missing, so an exported key is never overwritten.
os.environ.setdefault("GROQ_API_KEY", "")

# Initialize the LLM. GROQ_MODEL overrides the model name (e.g. if this one is retired).
llm = ChatGroq(model=os.environ.get("GROQ_MODEL", "llama3-70b-8192"), temperature=0)

# Set up database connection. SQL_AGENT_DB_URI overrides the URI.
# NOTE(review): the fallback URI embeds credentials; remove it once the env var is in use.
db_uri = os.environ.get(
    "SQL_AGENT_DB_URI",
    "mysql+pymysql://root:somia@localhost:3306/customer_products",
)
engine = create_engine(db_uri)
# Wrap the same engine (SQLDatabase.from_uri would build a second engine and pool).
db = SQLDatabase(engine)

def read_msisdns_from_file(file_path: str) -> List[str]:
    """Read MSISDNs from a text file, one per line.

    Surrounding whitespace is stripped. Blank lines, and lines that are not made up
    solely of ASCII digits 0-9, are skipped.

    Args:
        file_path: Path to a UTF-8 text file (a leading BOM is tolerated).

    Returns:
        The accepted MSISDNs in file order. If the file cannot be opened or decoded,
        the error is printed and whatever was read before the failure is returned
        (an empty list when the file could not be opened at all).
    """
    msisdns: List[str] = []
    try:
        # utf-8-sig drops a BOM that would otherwise make the first line fail the digit check.
        with open(file_path, "r", encoding="utf-8-sig") as file:
            for raw_line in file:
                line = raw_line.strip()
                # str.isdigit() alone also accepts non-ASCII digits such as '²' or '٣'.
                # These values are later spliced into SQL text, so only plain 0-9 is allowed.
                if line and line.isascii() and line.isdigit():
                    msisdns.append(line)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading MSISDNs from file: {e}")
    return msisdns

# Location of the MSISDN list. Set MSISDN_FILE_PATH to point elsewhere instead of
# editing this machine-specific default.
msisdn_file_path = os.environ.get("MSISDN_FILE_PATH", "C:/Users/somia.kumari/Desktop/msisdn.txt")
msisdns = read_msisdns_from_file(msisdn_file_path)
print(f"MSISDNs read from file: {msisdns}")

# Extend MessagesState with additional fields
class ExtendedState(MessagesState):
    """LangGraph state: MessagesState's ``messages`` plus the text-to-SQL pipeline fields.

    MessagesState is a TypedDict, so class-level defaults are never applied at runtime.
    Pydantic ``Field(default=...)`` values placed here would be silently ignored. Each
    key below is therefore absent until the caller or a node writes it; read optional
    keys with ``state.get(...)``.
    """

    # List of available tables in the database
    available_tables: Optional[List[str]]
    # List of tables relevant to the query
    relevant_tables: Optional[List[str]]
    # DDL definitions of relevant tables (table name -> text)
    table_schemas: Optional[Dict[str, str]]
    # SQL query generated from the user question
    generated_sql: Optional[str]
    # Result of executing the SQL query
    execution_result: Optional[Any]
    # Error message if SQL execution failed
    execution_error: Optional[str]
    # Corrected SQL query if original query failed
    corrected_sql: Optional[str]
    # List of MSISDNs to filter the query
    msisdns: Optional[List[str]]

# Step 1: Fetch available tables
def fetch_available_tables(state: ExtendedState) -> ExtendedState:
    """Graph node: list every table in the connected database.

    ``state`` is not read. The names come from a SQLAlchemy inspector bound to the
    module-level ``engine``.

    Returns:
        The partial state update ``{"available_tables": [...]}``.
    """
    return {"available_tables": inspect(engine).get_table_names()}

# Step 2: Determine relevant tables
def determine_relevant_tables(state: ExtendedState) -> ExtendedState:
    """Graph node: ask the LLM which available tables the latest question needs.

    The LLM reply is scanned for identifier-like tokens. A token counts as a table when it
    names an available table: an exact match is preferred, otherwise the comparison is
    case-insensitive, because models often re-case names. Each table is kept once, in the
    order it first appears, using the database's own spelling. If nothing matches, every
    available table is used so later steps still have a schema.

    Returns:
        The partial state update ``{"relevant_tables": [...]}``.
    """
    messages = state.get("messages") or []
    latest_query = messages[-1].content if messages else ""
    available_tables = state.get("available_tables") or []
    
    prompt = PromptTemplate.from_template("""
    You are an expert SQL database analyst. Given the user's question and the list of available tables,
    determine which tables are most likely needed to answer the question.
    
    User question: {query}
    
    Available tables: {available_tables}
    
    Return only a comma-separated list of the relevant table names, nothing else.
    """)
    
    chain = prompt | llm | StrOutputParser()
    tables_string = chain.invoke({
        "query": latest_query,
        "available_tables": ", ".join(available_tables)
    })
    
    exact_names = set(available_tables)
    by_lower = {name.lower(): name for name in available_tables}
    relevant_tables: List[str] = []
    seen = set()
    for token in re.findall(r"\w+", tables_string):
        table = token if token in exact_names else by_lower.get(token.lower())
        if table is not None and table not in seen:
            seen.add(table)
            relevant_tables.append(table)
    
    if not relevant_tables and available_tables:
        relevant_tables = list(available_tables)
    
    return {"relevant_tables": relevant_tables}

# Step 3: Retrieve table DDL
def _describe_table_ddl(inspector, table: str) -> str:
    """Render a table's reflected columns, primary key and foreign keys as CREATE TABLE text."""
    column_details = []
    for column in inspector.get_columns(table):
        col_type = str(column['type'])
        nullable = "NULL" if column.get('nullable', True) else "NOT NULL"
        default = f"DEFAULT {column.get('default')}" if column.get('default') is not None else ""
        column_details.append(f"{column['name']} {col_type} {nullable} {default}".strip())

    pk_info = inspector.get_pk_constraint(table)
    if pk_info and pk_info.get('constrained_columns'):
        pk_cols = ", ".join(pk_info['constrained_columns'])
        column_details.append(f"PRIMARY KEY ({pk_cols})")

    for fk in inspector.get_foreign_keys(table):
        src_cols = ", ".join(fk['constrained_columns'])
        ref_cols = ", ".join(fk['referred_columns'])
        column_details.append(f"FOREIGN KEY ({src_cols}) REFERENCES {fk['referred_table']}({ref_cols})")

    return f"CREATE TABLE {table} (\n  " + ",\n  ".join(column_details) + "\n);"


def retrieve_table_schemas(state: ExtendedState) -> ExtendedState:
    """Graph node: build a CREATE TABLE-style description of each relevant table.

    Each entry lists columns (type, NULL/NOT NULL, DEFAULT), the primary key and foreign
    keys, followed by up to 5 sample rows so the LLM can see real value formats.
    If reflecting a table fails, its entry is an ``"Error retrieving schema: ..."`` message.
    If only the sample query fails, the DDL is kept and a note replaces the sample rows.

    Returns:
        The partial state update ``{"table_schemas": {table_name: text}}``.
    """
    relevant_tables = state.get("relevant_tables") or []
    table_schemas: Dict[str, str] = {}
    if not relevant_tables:
        return {"table_schemas": table_schemas}

    try:
        # A single inspector is shared by every table in this run.
        inspector = inspect(engine)
    except Exception as e:
        error_text = f"Error retrieving schema: {str(e)}"
        return {"table_schemas": {table: error_text for table in relevant_tables}}

    # Dialect-aware identifier quoting (backticks on MySQL) for reserved words / odd names.
    quote = engine.dialect.identifier_preparer.quote

    for table in relevant_tables:
        try:
            table_schemas[table] = _describe_table_ddl(inspector, table)
        except Exception as e:
            table_schemas[table] = f"Error retrieving schema: {str(e)}"
            continue

        # Sample rows are a best-effort extra; a failure here must not discard the DDL.
        # NOTE(review): these rows are sent verbatim to the LLM and may contain subscriber data.
        try:
            with engine.connect() as connection:
                sample_df = pd.read_sql(text(f"SELECT * FROM {quote(table)} LIMIT 5"), connection)
        except Exception as e:
            table_schemas[table] += f"\n\n-- Sample data from {table} unavailable: {e}"
            continue

        if not sample_df.empty:
            table_schemas[table] += f"\n\n-- Sample data from {table}:\n" + sample_df.to_string(index=False)

    return {"table_schemas": table_schemas}

# Step 4: Generate SQL query with conversation history
def generate_sql_query(state: ExtendedState) -> ExtendedState:
    """Graph node: ask the LLM for a MySQL query answering the latest user message.

    The prompt contains the whole conversation, so follow-ups such as "the rest" can be
    resolved. It also contains the relevant table definitions and, when MSISDNs were
    supplied, an instruction to restrict rows to them. Without MSISDNs it states that no
    filter applies, because an empty list invites the model to write ``IN ()``, which is a
    MySQL syntax error.

    The reply is reduced to bare SQL:
    - the first fenced code block is used if present;
    - conversational preamble before the first SELECT/WITH line is dropped;
    - anything after a line-ending ';' is discarded.

    Returns:
        The partial state update ``{"generated_sql": sql}``.
    """
    messages = state.get("messages") or []
    latest_query = messages[-1].content if messages else ""
    table_schemas = state.get("table_schemas") or {}
    msisdns = state.get("msisdns") or []
    
    print(f"MSISDNs being used: {msisdns}")
    
    # Quote each MSISDN as a SQL string literal, doubling any embedded single quote.
    msisdns_str = ", ".join("'" + str(msisdn).replace("'", "''") + "'" for msisdn in msisdns)
    if msisdns:
        msisdn_filter = f"Filter the results to only include rows where the MSISDN is in: {msisdns_str}"
    else:
        msisdn_filter = "No MSISDN filter applies to this question."
    
    conversation_history = "\n".join(f"{m.type}: {m.content}" for m in messages)
    
    prompt = PromptTemplate.from_template("""
    You are an expert SQL developer specializing in MySQL syntax. Generate a SQL query to answer this question based on the conversation history:
    
    Conversation history:
    {conversation_history}
    
    Latest user question: {query}
    
    Relevant database table definitions:
    {table_schemas}
    
    {msisdn_filter}
    
    Guidelines:
    1. Use table joins when needed (prefer INNER JOIN)
    2. Use appropriate WHERE clauses
    3. Include ORDER BY for sorting
    4. Add LIMIT clauses for "top N" results
    5. Use GROUP BY with aggregates when appropriate
    6. Fully qualify column references (table.column)
    7. Consider previous queries/results to interpret phrases like "other than those" or "the rest"
    
    Return only the SQL query.
    """)
    
    schemas_text = "\n\n".join([f"-- Table: {t}\n{schema}" for t, schema in table_schemas.items()])
    
    chain = prompt | llm | StrOutputParser()
    raw_reply = chain.invoke({
        "query": latest_query,
        "conversation_history": conversation_history,
        "table_schemas": schemas_text,
        "msisdn_filter": msisdn_filter
    })
    
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", raw_reply, re.IGNORECASE | re.DOTALL)
    sql_query = fenced.group(1) if fenced else raw_reply
    # Drop preamble such as "Here is the SQL query ..." (see the recorded output of this cell).
    start = re.search(
        r"^[ \t]*\(?\s*(?:SELECT\b|WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\s*\()",
        sql_query,
        re.IGNORECASE | re.MULTILINE,
    )
    if start:
        sql_query = sql_query[start.start():]
    # Drop trailing explanation that follows the statement terminator.
    end = re.search(r";[ \t]*$", sql_query, re.MULTILINE)
    if end:
        sql_query = sql_query[:end.end()]
    sql_query = sql_query.strip()
    print(f"Generated SQL Query: {sql_query}")
    
    return {"generated_sql": sql_query}

# Step 5: Validate and optimize SQL
def validate_sql(state: ExtendedState) -> ExtendedState:
    """Graph node: have the LLM review the generated SQL and return a corrected version.

    The reviewer sees the question, the relevant table definitions, the query and the
    MSISDN list it must keep filtering on. Its reply is reduced to bare SQL:
    - the first fenced code block is used if present;
    - preamble before the first SELECT/WITH line is dropped;
    - anything after a line-ending ';' is dropped.
    Unlike a line-prefix filter, this never removes SQL lines such as a lone ``*`` column
    list, and it keeps the query's indentation.

    Returns:
        ``{"generated_sql": reviewed_sql}`` when the review produced a non-empty query that
        differs from the current one; ``{}`` otherwise, so the current query is kept.
    """
    messages = state.get("messages") or []
    latest_query = messages[-1].content if messages else ""
    sql_query = state.get("generated_sql") or ""
    table_schemas = state.get("table_schemas") or {}
    msisdns = state.get("msisdns") or []
    
    # Quote each MSISDN as a SQL string literal, doubling any embedded single quote.
    msisdns_str = ", ".join("'" + str(msisdn).replace("'", "''") + "'" for msisdn in msisdns)
    
    prompt = PromptTemplate.from_template("""
    You are an expert SQL reviewer. Review this SQL query:
    
    User question: {query}
    Relevant Tables Schemas:
    {table_schemas}
    
    Generated SQL Query: 
    {sql_query}
    
    MSISDN filter list: {msisdns_str}
    
    Check for:
    1. Syntax errors
    2. Missing JOIN conditions
    3. Inefficient patterns
    4. Logic errors
    5. Proper aggregation
    6. MSISDN filter preservation
    
    Return only the corrected SQL query with no additional explanatory text or comments.
    """)
    
    schemas_text = "\n\n".join([f"-- Table: {t}\n{schema}" for t, schema in table_schemas.items()])
    
    chain = prompt | llm | StrOutputParser()
    reply = chain.invoke({
        "query": latest_query,
        "sql_query": sql_query,
        "table_schemas": schemas_text,
        "msisdns_str": msisdns_str or "(none - no MSISDN filter applies)"
    })
    
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", reply, re.IGNORECASE | re.DOTALL)
    validated_sql = fenced.group(1) if fenced else reply
    start = re.search(
        r"^[ \t]*\(?\s*(?:SELECT\b|WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\s*\()",
        validated_sql,
        re.IGNORECASE | re.MULTILINE,
    )
    if start:
        validated_sql = validated_sql[start.start():]
    end = re.search(r";[ \t]*$", validated_sql, re.MULTILINE)
    if end:
        validated_sql = validated_sql[:end.end()]
    validated_sql = validated_sql.strip()
    
    print(f"Validated SQL Query: {validated_sql}")
    
    if not validated_sql:
        # The reviewer returned no recognisable SQL; keep the query we already have.
        return {}
    return {"generated_sql": validated_sql} if validated_sql != sql_query else {}

# Step 6: Execute SQL query
def execute_sql_query(state: ExtendedState) -> ExtendedState:
    """Graph node: run ``generated_sql`` through the module-level ``db`` and record the outcome.

    The SQL is written by an LLM from free-text questions, so it is only executed when it
    is a single read-only statement. Otherwise the refusal is recorded as an execution
    error, which routes the graph to the correction step.
    NOTE(review): this guard is a safety net, not a sandbox; the real protection is
    connecting with a read-only database user instead of root.

    Returns:
        On success: ``{"execution_result": text, "execution_error": None}``, with a
        placeholder message when the query returned nothing.
        On failure: ``{"execution_result": None, "execution_error": message}``.
    """
    def _read_only_violation(sql: str) -> Optional[str]:
        """Return why ``sql`` is not a single read-only statement, or None if it is."""
        # Blank out string literals and `quoted` identifiers so words inside them
        # (e.g. WHERE status = 'DELETE') do not trip the keyword checks.
        code = re.sub(r"'(?:[^'\\]|\\.|'')*'|\"(?:[^\"\\]|\\.|\"\")*\"|`[^`]*`", "''", sql)
        if not re.match(r"\s*\(*\s*(?:SELECT|WITH|SHOW|DESCRIBE|DESC|EXPLAIN)\b", code, re.IGNORECASE):
            return "only SELECT, WITH, SHOW, DESCRIBE and EXPLAIN statements may be executed"
        if re.search(r";\s*\S", code):
            return "multiple SQL statements are not allowed"
        if re.search(
            r"\b(?:INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE|GRANT|REVOKE|RENAME|OUTFILE|DUMPFILE)\b",
            code,
            re.IGNORECASE,
        ):
            return "data-modifying or file-writing keywords are not allowed"
        return None

    sql_query = (state.get("generated_sql") or "").strip()
    print(f"Executing SQL Query: {sql_query}")

    violation = _read_only_violation(sql_query)
    if violation:
        error_message = f"Refused to execute SQL: {violation}."
        print(f"Execution Error: {error_message}")
        return {"execution_result": None, "execution_error": error_message}

    try:
        result = db.run(sql_query)
        print(f"Execution Result: {result}")
        if not result or str(result).strip() == "":
            result = "Query executed successfully, but returned no results."
        return {"execution_result": result, "execution_error": None}
    except Exception as e:
        error_message = str(e)
        print(f"Execution Error: {error_message}")
        return {"execution_result": None, "execution_error": error_message}

# Step 7: Correct SQL on error
def correct_sql_on_error(state: ExtendedState) -> ExtendedState:
    """Graph node: ask the LLM to repair a failed query, then run the repaired query.

    Returns ``{}`` when ``execution_error`` is empty. Otherwise the prompt carries the
    question, the failed SQL, the database error, the table definitions and the MSISDN
    list, so the repair keeps the same row filter. The repaired SQL is only executed when
    it is a single read-only statement.

    Returns:
        On success: ``execution_result``, ``execution_error=None`` and ``corrected_sql``.
        On failure: ``execution_result=None``, ``corrected_sql``, and an ``execution_error``
        that combines the original error with the reason the correction failed.
    """
    def _read_only_violation(sql: str) -> Optional[str]:
        """Return why ``sql`` is not a single read-only statement, or None if it is."""
        # Blank out string literals and `quoted` identifiers so words inside them
        # do not trip the keyword checks.
        code = re.sub(r"'(?:[^'\\]|\\.|'')*'|\"(?:[^\"\\]|\\.|\"\")*\"|`[^`]*`", "''", sql)
        if not re.match(r"\s*\(*\s*(?:SELECT|WITH|SHOW|DESCRIBE|DESC|EXPLAIN)\b", code, re.IGNORECASE):
            return "only SELECT, WITH, SHOW, DESCRIBE and EXPLAIN statements may be executed"
        if re.search(r";\s*\S", code):
            return "multiple SQL statements are not allowed"
        if re.search(
            r"\b(?:INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE|GRANT|REVOKE|RENAME|OUTFILE|DUMPFILE)\b",
            code,
            re.IGNORECASE,
        ):
            return "data-modifying or file-writing keywords are not allowed"
        return None

    error_message = state.get("execution_error")
    if not error_message:
        return {}

    messages = state.get("messages") or []
    latest_query = messages[-1].content if messages else ""
    original_sql = state.get("generated_sql") or ""
    table_schemas = state.get("table_schemas") or {}
    msisdns = state.get("msisdns") or []
    # Quote each MSISDN as a SQL string literal, doubling any embedded single quote.
    msisdns_str = ", ".join("'" + str(msisdn).replace("'", "''") + "'" for msisdn in msisdns)

    prompt = PromptTemplate.from_template("""
    You are an expert SQL debugger. Fix this failed SQL query:
    
    User question: {query}
    Failed SQL query: {original_sql}
    Error: {error_message}
    Schema: {table_schemas}
    MSISDN filter list (the corrected query must keep this filter): {msisdns_str}
    
    Return only the corrected SQL query with no additional explanatory text or comments.
    """)

    schemas_text = "\n\n".join([f"-- Table: {t}\n{schema}" for t, schema in table_schemas.items()])

    chain = prompt | llm | StrOutputParser()
    reply = chain.invoke({
        "query": latest_query,
        "original_sql": original_sql,
        "error_message": error_message,
        "table_schemas": schemas_text,
        "msisdns_str": msisdns_str or "(none - no MSISDN filter applies)"
    })

    # Reduce the reply to bare SQL without dropping legitimate lines (e.g. a lone "*").
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", reply, re.IGNORECASE | re.DOTALL)
    corrected_sql = fenced.group(1) if fenced else reply
    start = re.search(
        r"^[ \t]*\(?\s*(?:SELECT\b|WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\s*\()",
        corrected_sql,
        re.IGNORECASE | re.MULTILINE,
    )
    if start:
        corrected_sql = corrected_sql[start.start():]
    end = re.search(r";[ \t]*$", corrected_sql, re.MULTILINE)
    if end:
        corrected_sql = corrected_sql[:end.end()]
    corrected_sql = corrected_sql.strip()

    print(f"Corrected SQL Query: {corrected_sql}")

    def _failure(reason: str) -> dict:
        """Build the state update for a correction that could not be executed."""
        return {
            "execution_result": None,
            "execution_error": f"Original error: {error_message}\nCorrection failed: {reason}",
            "corrected_sql": corrected_sql,
        }

    if not corrected_sql:
        return _failure("the model did not return a SQL query")

    violation = _read_only_violation(corrected_sql)
    if violation:
        return _failure(f"Refused to execute SQL: {violation}.")

    try:
        result = db.run(corrected_sql)
        if not result or str(result).strip() == "":
            result = "Query executed successfully after correction, but returned no results."
        return {"execution_result": result, "execution_error": None, "corrected_sql": corrected_sql}
    except Exception as e:
        return _failure(str(e))

# Step 8: Formulate response
def formulate_response(state: ExtendedState) -> ExtendedState:
    """Graph node: turn the SQL outcome into a conversational reply.

    When ``execution_error`` is set, the LLM is asked to acknowledge the problem, explain
    it simply and suggest rephrasing. Otherwise it formats ``execution_result`` as an
    answer to the latest question.

    Returns:
        ``{"messages": [AIMessage(...)]}``. MessagesState's ``messages`` channel uses the
        ``add_messages`` reducer, so this single new message is appended to the history.
    """
    messages = state.get("messages") or []
    latest_query = messages[-1].content if messages else ""
    # `or` rather than a .get() default: the key may be present but None, and then the
    # originally generated query is the one that was attempted.
    sql_query = state.get("corrected_sql") or state.get("generated_sql")
    result = state.get("execution_result")
    error = state.get("execution_error")
    
    if error:
        prompt = PromptTemplate.from_template("""
        The query failed with errors:
        
        User question: {query}
        SQL attempted: {sql_query}
        Error: {error}
        
        Provide a response that:
        1. Acknowledges the issue
        2. Explains simply what might be wrong
        3. Suggests rephrasing
        """)
        chain = prompt | llm | StrOutputParser()
        response = chain.invoke({"query": latest_query, "sql_query": sql_query, "error": error})
    else:
        prompt = PromptTemplate.from_template("""
        Format these SQL results into a clear answer:
        
        User question: {query}
        Results: {result}
        
        Provide a conversational response that:
        1. Answers the question
        2. Presents data clearly
        3. Highlights insights
        """)
        chain = prompt | llm | StrOutputParser()
        response = chain.invoke({"query": latest_query, "result": result})
    
    return {"messages": [AIMessage(content=response)]}

# Define conditional edges
def should_correct_sql(state):
    """Routing function for the edge leaving ``execute_sql``.

    Returns ``"correct_sql"`` when the state carries a truthy ``execution_error``,
    and ``"formulate_response"`` otherwise.
    """
    if state["execution_error"]:
        return "correct_sql"
    return "formulate_response"

# Build the LangGraph
def build_sql_agent_graph():
    """Assemble and compile the text-to-SQL LangGraph workflow.

    Flow: fetch_tables -> determine_tables -> retrieve_schemas -> generate_sql
    -> validate_sql -> execute_sql. From there, ``should_correct_sql`` routes either
    to correct_sql (which then goes to formulate_response) or straight to
    formulate_response, which ends the run.
    """
    graph = StateGraph(ExtendedState)

    node_table = [
        ("fetch_tables", fetch_available_tables),
        ("determine_tables", determine_relevant_tables),
        ("retrieve_schemas", retrieve_table_schemas),
        ("generate_sql", generate_sql_query),
        ("validate_sql", validate_sql),
        ("execute_sql", execute_sql_query),
        ("correct_sql", correct_sql_on_error),
        ("formulate_response", formulate_response),
    ]
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Unconditional linear pipeline up to execution.
    pipeline = ["fetch_tables", "determine_tables", "retrieve_schemas",
                "generate_sql", "validate_sql", "execute_sql"]
    for source, target in zip(pipeline, pipeline[1:]):
        graph.add_edge(source, target)

    route_targets = {name: name for name in ("correct_sql", "formulate_response")}
    graph.add_conditional_edges("execute_sql", should_correct_sql, route_targets)

    graph.add_edge("correct_sql", "formulate_response")
    graph.add_edge("formulate_response", END)

    graph.set_entry_point("fetch_tables")

    return graph.compile()

# Create the agent: compile the workflow once; query_sql_agent below invokes this shared graph.
sql_agent = build_sql_agent_graph()

def query_sql_agent(question: str, msisdn_file_path: str, previous_state: Optional[ExtendedState] = None) -> tuple[str, ExtendedState]:
    """Ask the SQL agent one question, optionally continuing an earlier conversation.

    The MSISDN list is re-read from ``msisdn_file_path`` on every call. When
    ``previous_state`` is non-empty and has ``messages``, the new question is appended
    to that history so follow-ups can refer to earlier answers.

    Returns:
        ``(answer_text, final_state)``. ``answer_text`` is the content of the last
        AIMessage in the final state, or "No response found". ``final_state`` can be
        passed back as ``previous_state`` on the next turn.
    """
    msisdns = read_msisdns_from_file(msisdn_file_path)
    print(f"Previous state messages: {previous_state['messages'] if previous_state else 'None'}")

    has_history = bool(previous_state) and "messages" in previous_state
    history = previous_state["messages"] if has_history else []
    initial_state = ExtendedState(messages=history + [HumanMessage(content=question)], msisdns=msisdns)

    result = sql_agent.invoke(initial_state)

    ai_replies = [message.content for message in result["messages"] if isinstance(message, AIMessage)]
    final_response = ai_replies[-1] if ai_replies else "No response found"

    return final_response, result

if __name__ == "__main__":
    # Two-turn demo: the second question only makes sense given the first turn's history.
    sample_questions = [
        "What are the top 3 products by sales?",
        "tell me the top 2 from the remaining.",
    ]

    # Each turn's final state is fed back in so the agent sees the conversation so far.
    previous_state = None
    for question in sample_questions:
        print(f"\nQuestion: {question}")
        print("-" * 50)
        answer, previous_state = query_sql_agent(question, msisdn_file_path, previous_state)
        print(f"Answer: {answer}")
        print("=" * 80)
        


MSISDNs read from file: ['9123456789', '9087654321', '9555555555', '9444444444']

Question: What are the top 3 products by sales?
--------------------------------------------------
MSISDNs being used: ['9123456789', '9087654321', '9555555555', '9444444444']
Generated SQL Query: Here is the SQL query to answer the question:
SELECT 
  p.product_name, 
  COUNT(cp.product_id) AS sales_count
FROM 
  customer_purchases cp 
  INNER JOIN products p ON cp.product_id = p.product_id
WHERE 
  cp.msisdn IN ('9123456789', '9087654321', '9555555555', '9444444444')
GROUP BY 
  p.product_name
ORDER BY 
  sales_count DESC
LIMIT 3;
Validated SQL Query: SELECT
p.product_name,
COUNT(cp.product_id) AS sales_count
FROM
customer_purchases cp
INNER JOIN products p ON cp.product_id = p.product_id
WHERE
cp.msisdn IN ('9123456789', '9087654321', '9555555555', '9444444444')
GROUP BY
p.product_name
ORDER BY
sales_count DESC
LIMIT 3;
Executing SQL Query: SELECT
p.product_name,
COUNT(cp.product_id) AS sales_count
FRO

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled workflow graph as a PNG. MermaidDrawMethod.API renders remotely
# (via the mermaid.ink service), so this cell needs network access.
graph_png = sql_agent.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(graph_png))