In [27]:
from dotenv import load_dotenv

# Load environment variables from a local .env file (the DB_USER / DB_PASS / DB_HOST /
# DB_PORT values read by execute_mysql_query, and presumably the OpenAI API key used by
# ChatOpenAI). override=True makes .env values take precedence over already-set variables.
load_dotenv(override=True)

True

# Prompt setup

In [28]:
# The following is just a sample semantic context used for this example. Replace it with a
# description of your own business and data.
semantic_context = """
This is an online store where the shop sells PC hardware. The customers are from various locations and the orders are being tracked in the source database attached here
"""

# System prompt for the SQL agent. Because this is an f-string, `semantic_context` is
# interpolated right here and `agent_prompt` is already the final prompt text -- it does
# not need (and should not get) another `.format()` pass.
agent_prompt = f"""
You are a helpful assistant that can generate a SQL query based on the user's question and semantic context.
You can also execute a MySQL query and return the result.

here is the semantic context you can use to understand the business and generate the query:
{semantic_context}

Take the following steps to provide the answer:
1. write reasoning steps to approach the question.
2. generate the sql query based on the reasoning steps.
3. if you want to profile any column before executing the query to add the correct filter, for example if there is a status column, you can run DISTINCT query on that column to get the unique values and then use that to add the correct filter.
4. execute the sql query and return the result (please add max limit of records as 100 to the query before executing it, otherwise it may go out of LLM context window).
5. if the result is not what you expected, please write the new reasoning steps and generate the new sql query and execute it again.

Expectation:
- try to give the data in a table format.
- don't hallucinate, just give the answer always based on the information available in the source.
- if the information is not present in the db or the result is empty, please inform that to the user rather than hallucinating.
- if you don't have enough information to generate the query, ask for more information from the user.
"""


# Retrieval

In [29]:
from vector_index import ModelVectorIndex
from pathlib import Path
import json

def get_db_context(
    reasoning_steps: str,
    index_path: str = "fs_cache/vector_index",
    relationships_dir: str = "fs_cache/relationships",
    k: int = 5,
    similarity_threshold: float = 1.6,
) -> list[dict]:
    """
    Retrieve the data models relevant to the reasoning steps, each paired with its relationships.

    Args:
        reasoning_steps: Natural-language reasoning about the question; used as the
            similarity-search query against the model vector index.
        index_path: Location of the saved vector index of model definitions.
        relationships_dir: Directory whose ``*.json`` files each hold one relationship
            definition (a dict with a ``models`` list, and typically a ``condition``).
        k: Number of nearest documents to fetch from the index.
        similarity_threshold: Documents with ``score <= similarity_threshold`` are kept.
            NOTE(review): this assumes the index returns a distance (lower = closer) --
            confirm for the vector store backing ModelVectorIndex.

    Returns:
        list[dict]: One ``{"model": <model dict>, "relationships": [<relationship dict>, ...]}``
        entry per retained document, where ``relationships`` are those listing the
        model's ``name`` in their ``models`` field. Empty if nothing passes the threshold.
    """
    index = ModelVectorIndex().load_index(index_path)

    # Sorted so relationship order (and therefore the generated DDL) is deterministic.
    relationships = []
    for file_path in sorted(Path(relationships_dir).glob("*.json")):
        with open(file_path, "r", encoding="utf-8") as f:
            relationships.append(json.load(f))

    docs_and_scores = index.similarity_search_with_score(reasoning_steps, k=k)

    # Each document's page_content is a JSON-serialized model definition.
    models = [
        json.loads(doc.page_content)
        for doc, score in docs_and_scores
        if score <= similarity_threshold
    ]

    return [
        {
            "model": model,
            "relationships": [
                rel for rel in relationships if model["name"] in rel.get("models", [])
            ],
        }
        for model in models
    ]

# Augmentation & Generation

### Generate DDL for models and relationships

In [30]:
def _one_line(text):
    """Collapse all whitespace (including newlines) so text fits inside a single `--` comment."""
    return " ".join(str(text).split()) if text else ""


def _create_table_ddl(model):
    """Build a commented CREATE TABLE statement for a single model dict."""
    table_name = model["name"]
    db_name = model["database"]
    columns = model["columns"]
    table_desc = _one_line(model.get("properties", {}).get("description", ""))

    # Heuristic primary key: the first NOT NULL column whose name ends in "id".
    pk_candidates = [
        col["name"]
        for col in columns
        if col["name"].lower().endswith("id") and col.get("notNull", 0)
    ]

    lines = [f"-- Table: {db_name}.{table_name}"]
    if table_desc:
        lines.append(f"-- Description: {table_desc}")
    lines.append(f"CREATE TABLE {db_name}.{table_name} (")

    # Collect (definition, description) pairs first and add separators afterwards, so
    # the comma always precedes the `--` comment instead of being swallowed by it.
    items = []
    for col in columns:
        not_null = " NOT NULL" if col.get("notNull", 0) else ""
        desc = _one_line(col.get("properties", {}).get("description", ""))
        items.append((f"    {col['name']} {col['type']}{not_null}", desc))
    if pk_candidates:
        items.append((f"    PRIMARY KEY ({pk_candidates[0]})", ""))

    last = len(items) - 1
    for i, (definition, desc) in enumerate(items):
        separator = "," if i < last else ""
        comment = f" -- {desc}" if desc else ""
        lines.append(f"{definition}{separator}{comment}")

    lines.append(");")
    return "\n".join(lines)


def _foreign_key_ddl(condition, db_by_model):
    """Turn a `left_table.col = right_table.col` join condition into an ALTER TABLE ... ADD FOREIGN KEY statement.

    Returns None when the condition is not a single two-sided equality of
    `table.column` references, or when either table is not among the retrieved models.
    """
    left, sep, right = condition.partition("=")
    # Skip anything that is not exactly one `a.x = b.y` (e.g. compound AND conditions).
    if not sep or "=" in right or "." not in left or "." not in right:
        return None
    left_table, left_col = (part.strip() for part in left.strip().rsplit(".", 1))
    right_table, right_col = (part.strip() for part in right.strip().rsplit(".", 1))
    if left_table not in db_by_model or right_table not in db_by_model:
        return None
    return (
        f"ALTER TABLE {db_by_model[left_table]}.{left_table}\n"
        f"    ADD FOREIGN KEY ({left_col}) REFERENCES {db_by_model[right_table]}.{right_table}({right_col});"
    )


def generate_ddl_for_models_and_relationships(model_relationships):
    """
    Render retrieved models and their relationships as MySQL-style DDL text for an LLM prompt.

    Args:
        model_relationships: list of ``{"model": {...}, "relationships": [...]}`` entries
            (the shape produced by ``get_db_context``). Each model dict needs ``name``,
            ``database`` and ``columns`` (each column with ``name`` and ``type``, optionally
            ``notNull`` and ``properties.description``); relationship dicts may carry a
            ``condition`` of the form ``"table1.col1 = table2.col2"``.

    Returns:
        str: one ``CREATE TABLE`` statement per model followed by de-duplicated
        ``ALTER TABLE ... ADD FOREIGN KEY`` statements, separated by blank lines.
        An empty input yields an empty string.
    """
    # Model name -> database, used to fully qualify both sides of a foreign key.
    db_by_model = {
        entry["model"]["name"]: entry["model"]["database"] for entry in model_relationships
    }

    ddls = [_create_table_ddl(entry["model"]) for entry in model_relationships]

    # A relationship is attached to every model it mentions, so the same condition can
    # appear under several entries; emit each resulting FK statement only once.
    fk_lines = []
    seen = set()
    for entry in model_relationships:
        for rel in entry.get("relationships", []):
            fk = _foreign_key_ddl(rel.get("condition") or "", db_by_model)
            if fk is not None and fk not in seen:
                seen.add(fk)
                fk_lines.append(fk)
            # For MANY_TO_MANY only one direction (left -> right) is emitted.

    return "\n\n".join(ddls + fk_lines)

### Generate SQL query from natural language

In [31]:
from langchain_openai import ChatOpenAI

def generate_sql_query_from_context_and_ddl(ddl: str, question: str, semantic_context: str, resoning_steps: str, feedback: str = None, model_name: str = "gpt-4o-mini") -> str:
    """
    Ask an LLM to write MySQL query text from DDL, a question, reasoning steps and semantic context.

    Args:
        ddl: The DDL of the relevant models (tables, columns, foreign keys)
        question: The user's question
        semantic_context: The business context and semantic information about the domain
        resoning_steps: The reasoning steps for the query (parameter name kept as-is for callers)
        feedback: Error message or other feedback from a previous attempt; appended to the
            prompt when the caller retries.
        model_name: OpenAI chat model used for generation.

    Returns:
        str: The model's reply with surrounding whitespace stripped -- normally the SQL
        query (or queries), or "No information found" if the model follows that instruction.

    Exceptions from the LLM client propagate to the caller unchanged.
    """
    prompt = f"""
        Generate a SQL query based on the following DDLs, question, and semantic context:

        DDLs:
        ----*****DDLs*****----
        {ddl}
        ----*****END OF DDLs*****----

        Question:
        ----*****Question*****----
        {question}
        ----*****END OF Question*****----

        Reasoning Steps to approach the question:
        ----*****Reasoning Steps*****----
        {resoning_steps}
        ----*****END OF Reasoning Steps*****----

        Semantic Context:
        ----*****Semantic Context*****----
        {semantic_context}
        ----*****END OF Semantic Context*****----

        Note: please include the database name in the query, and only use the table names and column names that are present in the DDLs and relationships. Please don't hallucinate or add any new table names or column names. If you don't have enough information to generate the query, please return "No information found"

        Expectation:
            - Please only return the SQL query, nothing else.
            - you may generate one or more queries to answer the question.
            - please try to use joins whenever possible
            - always write the query in mysql syntax.
        """

    if feedback:
        prompt += f"\n\nFeedback on previous attempt: {feedback}"

    # Temperature 0 keeps the generated SQL as stable as possible across retries.
    llm = ChatOpenAI(model=model_name, temperature=0.0)
    response = llm.invoke(prompt)
    return response.content.strip()
        

# Tools

In [32]:
from langchain_core.tools import tool

### generate_sql_query

In [38]:
@tool
def generate_sql_query(user_question: str, semantic_context: str, reasoning_steps: str, feedback:str = None) -> str:
    """
    Generate a SQL query based on the user's question and semantic context.

    Args:
        user_question: The user's question
        semantic_context: The business context and semantic information about the domain
        reasoning_steps: The reasoning steps for the query
        feedback: provide the error message or any feedback along with the previous attempt query to improve the query, if this tool is called again.
    Returns:
        sql_query: The generated SQL query
    """
    # NOTE: the docstring above is the tool description shown to the agent, so it is
    # left exactly as-is; implementation notes live in these comments.

    # Retrieval: models (and their relationships) relevant to the reasoning steps.
    retrieved_models = get_db_context(reasoning_steps)
    # Augmentation: render them as DDL so the LLM sees concrete table/column names.
    schema_ddl = generate_ddl_for_models_and_relationships(retrieved_models)
    # Generation: any exception propagates to the caller unchanged.
    return generate_sql_query_from_context_and_ddl(
        schema_ddl,
        user_question,
        semantic_context,
        reasoning_steps,
        feedback,
    )
    

### execute_mysql_query

In [39]:
import os
from sqlalchemy import create_engine, text

@tool
def execute_mysql_query(query: str, database_name: str) -> list:
    """
    This function will execute the mysql query and return the result.
    
    Args:
        query (str): The MySQL query to execute
        database_name (str): The specific database to connect to
    
    Returns:
        list: The query results
    """
    # NOTE: the docstring above is the tool description shown to the agent, so it is
    # left unchanged; implementation notes live in these comments.
    #
    # SECURITY(review): `query` comes from the LLM and is executed verbatim -- nothing
    # here restricts it to read-only statements. Point DB_USER at an account that only
    # has SELECT privileges.
    from urllib.parse import quote

    # Percent-encode credentials (safe="" also encodes '/', '+' and spaces) so characters
    # such as '@' or ':' in a password cannot corrupt the connection URL.
    username = quote(os.getenv("DB_USER", ""), safe="")
    password = quote(os.getenv("DB_PASS", ""), safe="")
    host = os.getenv("DB_HOST")
    port = os.getenv("DB_PORT")
    # DB_PORT used to be read but never placed in the URL; honour it when it is set.
    address = f"{host}:{port}" if port else f"{host}"
    url = f"mysql+mysqlconnector://{username}:{password}@{address}/{database_name}"

    engine = create_engine(url)
    try:
        print(f"Executing query: {query} in database: {database_name}")
        with engine.connect() as connection:
            # text() wraps the raw SQL string as an executable SQLAlchemy construct.
            result = connection.execute(text(query))
            return result.fetchall()
    finally:
        # One engine per call; release its connection pool before returning.
        engine.dispose()

# Agent

### Building the sql agent

In [55]:
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver

# In-memory checkpointer: conversation state is stored per `thread_id` (see the config
# passed to `sql_agent.stream`) and is not persisted anywhere else.
memory = MemorySaver()

# Orchestrating chat model for the ReAct agent; temperature 0 minimises sampling randomness.
model = ChatOpenAI(model="gpt-4.1", temperature=0.0)

# `agent_prompt` is an f-string that already has `semantic_context` interpolated, so it is
# passed as-is. Calling `.format()` on it again was redundant and would raise if the
# context text contained `{` or `}`.
sql_agent = create_react_agent(
    name="sql_agent",
    model=model,
    tools=[generate_sql_query, execute_mysql_query],
    prompt=agent_prompt,
    checkpointer=memory,
)

### Testing the sql agent

In [90]:
# Put the question you want the agent to answer here.
question_text = "\nWho is my best customer? and why?\n"
messages = [HumanMessage(content=question_text)]

In [91]:
# Run config: a fixed thread_id so the checkpointer keeps conversation history across
# runs, and a cap of 20 graph steps.
thread = {"configurable": {"thread_id": "123"}, "recursion_limit": 20}
final_message = ""

# Each streamed update maps a graph node name to the state update that node produced.
for update in sql_agent.stream({"messages": messages}, thread):
    for node_output in update.values():
        print(node_output)
        node_messages = node_output['messages']
        # Keep the content of the most recent message; after the loop this is the agent's answer.
        if node_messages and isinstance(node_messages, list):
            final_message = node_messages[-1].content

{'messages': [AIMessage(content='Reasoning Steps:\n1. "Best customer" typically refers to the customer who has contributed the most value to the business, usually by total order value.\n2. To determine this, I need to sum the total order value for each customer and identify the customer with the highest total.\n3. I will join the customer and order tables, group by customer, sum the order values, and order the result in descending order to get the top customer.\n4. I will also include the reason: the best customer is the one with the highest total order value.\n\nLet\'s generate and execute the query.', additional_kwargs={'tool_calls': [{'id': 'call_lfut6T9lwrGjkKbWC8xXptB9', 'function': {'arguments': '{"user_question":"Who is my best customer? and why?","semantic_context":"This is a online store where, the shop sell PC hardware. The customers are from various locations and the orders are being tracked in the source database attached here","reasoning_steps":"1. Sum the total order valu

### Format the output

In [92]:
from IPython.display import Markdown, display
import pandas as pd
import io
import re

# A markdown table row starts and ends with a pipe (surrounding whitespace allowed).
_MD_TABLE_ROW = re.compile(r'^\s*\|.*\|\s*$')
# A cell of the header/body separator row, e.g. `---`, `:---`, `---:` or `:---:`.
_MD_SEPARATOR_CELL = re.compile(r'^:?-+:?$')


def _split_markdown_row(line):
    """Split a `| a | b |` row into its stripped cell texts, dropping the outer pipes."""
    inner = line.strip()
    if inner.startswith('|'):
        inner = inner[1:]
    if inner.endswith('|'):
        inner = inner[:-1]
    return [cell.strip() for cell in inner.split('|')]


def _numeric_if_possible(column):
    """Return `column` converted to a numeric dtype, or unchanged if any value is non-numeric."""
    try:
        return pd.to_numeric(column)
    except (ValueError, TypeError):
        return column


def _markdown_table_to_dataframe(table_lines):
    """Parse contiguous markdown table rows into a DataFrame.

    The first non-separator row becomes the header and `|---|` alignment rows are
    dropped. Columns whose values are all numeric are converted to numbers. May raise
    if the rows cannot be aligned with the header.
    """
    rows = [_split_markdown_row(line) for line in table_lines]
    rows = [row for row in rows if not all(_MD_SEPARATOR_CELL.match(cell) for cell in row)]
    if not rows:
        raise ValueError("markdown table has no header row")
    header, body = rows[0], rows[1:]
    frame = pd.DataFrame(body, columns=header)
    if body:
        frame = frame.apply(_numeric_if_possible)
    return frame


def display_markdown_with_tables(markdown_text):
    """
    Display markdown text in a Jupyter notebook, rendering markdown tables as pandas DataFrames.

    Runs of non-table lines are shown with IPython's Markdown renderer. Each run of
    contiguous `| ... |` lines is parsed into a DataFrame; if parsing fails, that run is
    shown as plain markdown instead.
    """
    lines = markdown_text.split('\n')
    text_block = []

    def flush_text():
        # Emit any accumulated non-table lines as one markdown chunk.
        if text_block:
            display(Markdown('\n'.join(text_block)))
            text_block.clear()

    i = 0
    while i < len(lines):
        if not _MD_TABLE_ROW.match(lines[i]):
            text_block.append(lines[i])
            i += 1
            continue

        flush_text()
        start = i
        while i < len(lines) and _MD_TABLE_ROW.match(lines[i]):
            i += 1
        table_lines = lines[start:i]
        try:
            display(_markdown_table_to_dataframe(table_lines))
        except Exception:
            # Unparseable table (e.g. ragged rows): fall back to rendering it as markdown.
            display(Markdown('\n'.join(table_lines)))

    flush_text()


# Render the agent's final answer captured by the streaming loop; markdown tables in it
# are shown as DataFrames.
display_markdown_with_tables(final_message)

Your best customer is Emma Johnson (CustomerId: 2).

Reason: Emma Johnson has the highest total order value, having spent a total of 4634.33 in your store. This makes her your most valuable customer based on total purchases.