In [74]:
from langchain.sql_database import SQLDatabase
from urllib.parse import quote_plus

def connect_to_database(db_type, username, password, host, port, db_name):
    """
    Connects to a database using LangChain's SQLDatabase.from_uri.

    Args:
        db_type (str): Type of the database ('postgresql', 'mysql', or 'mssql').
        username (str): Database username.
        password (str): Database password.
        host (str): Database host (typically 'localhost').
        port (int): Database port number.
        db_name (str): Name of the database.

    Returns:
        SQLDatabase instance connected to the specified database.
    """
    db_type = db_type.lower()

    if db_type == "postgresql":
        uri = f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{db_name}"

    elif db_type == "mysql":
        uri = f"mysql+pymysql://{username}:{password}@{host}:{port}/{db_name}"

    elif db_type == "mssql":
        uri = f"mssql+pyodbc://{username}:{quote_plus(password)}@{host},{port}/{db_name}?driver=ODBC+Driver+17+for+SQL+Server"

    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    return SQLDatabase.from_uri(uri)

db = connect_to_database("mysql", "root", "root", "localhost", 3306, "classicmodels")
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x217d3ae5810>

In [75]:
from sqlalchemy import inspect

def get_compact_schema_summary(sql_db):
    inspector = inspect(sql_db._engine)
    summary_lines = []

    for table_name in inspector.get_table_names():
        columns = inspector.get_columns(table_name)
        column_defs = [
            f"{col['name']} ({col['type']})" for col in columns
        ]
        summary = f"Table: {table_name} → " + ", ".join(column_defs)
        summary_lines.append(summary)

    return "\n".join(summary_lines)

In [76]:
# len(db.get_table_info().split(' '))
get_compact_schema_summary(db)

'Table: customers → customerNumber (INTEGER), customerName (VARCHAR(50)), contactLastName (VARCHAR(50)), contactFirstName (VARCHAR(50)), phone (VARCHAR(50)), addressLine1 (VARCHAR(50)), addressLine2 (VARCHAR(50)), city (VARCHAR(50)), state (VARCHAR(50)), postalCode (VARCHAR(15)), country (VARCHAR(50)), salesRepEmployeeNumber (INTEGER), creditLimit (DECIMAL(10, 2))\nTable: employees → employeeNumber (INTEGER), lastName (VARCHAR(50)), firstName (VARCHAR(50)), extension (VARCHAR(10)), email (VARCHAR(100)), officeCode (VARCHAR(10)), reportsTo (INTEGER), jobTitle (VARCHAR(50))\nTable: offices → officeCode (VARCHAR(10)), city (VARCHAR(50)), phone (VARCHAR(50)), addressLine1 (VARCHAR(50)), addressLine2 (VARCHAR(50)), state (VARCHAR(50)), country (VARCHAR(50)), postalCode (VARCHAR(15)), territory (VARCHAR(10))\nTable: orderdetails → orderNumber (INTEGER), productCode (VARCHAR(15)), quantityOrdered (INTEGER), priceEach (DECIMAL(10, 2)), orderLineNumber (SMALLINT)\nTable: orders → orderNumber (I

In [77]:
from dotenv import load_dotenv

# This loads variables from .env into our environment
load_dotenv()

True

In [78]:
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
import os


llm = ChatGoogleGenerativeAI(
    model='gemini-2.0-flash',
    google_api_key=os.getenv("GOOGLE_API_KEY"),
    temperature=0.2
)

In [79]:
llm.invoke('Hii').content

'Hi there! How can I help you today?'

In [80]:
from typing_extensions import TypedDict
import pandas as pd
from functools import wraps
from typing import Callable

class State(TypedDict, total=False):
    status: bool
    user_request: str
    generated_query: str
    query_result: str
    answer: str
    query_result_df: pd.DataFrame
    graph_required: bool
    graph_code: str
    reason: str
    table_required: bool
    answer: str

def manage_state(func: Callable) -> Callable:
    """Handels state object updation, error handelling."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Fetch 'state' from either args[0] or kwargs['state']
        state = kwargs.get("state") or (args[0] if len(args) > 0 else None)

        # If state status is False we are not executing the function and just returning the existing state
        if state is None or not state.get("status", False):
            return state
        
        result = func(*args, **kwargs)

        # Updating the state
        for key, val in result.items():
            state[key] = result[key]
        
        return state
    return wrapper


In [81]:
from langchain_core.prompts import ChatPromptTemplate

system_message = """
Given an input question, generate a syntactically correct {dialect} SELECT query that can help retrieve the answer.

# Constraints and Expectations:

- You must generate only SELECT-type queries.  
  If the input asks for a non-SELECT operation (e.g., UPDATE, INSERT, DELETE, CREATE), or cannot be satisfied using data selection alone (e.g., training models, modifying schemas), set `status` to False and provide a clear explanation in `reason`. Do not return any query in such cases.

- If the question is about database metadata (e.g., listing tables, showing columns, describing schema), generate a SELECT query using standard catalog views such as `information_schema.tables` or `information_schema.columns` (if supported in {dialect}).

- If the input includes terms like "plot", "graph", "visualize", or "chart", assume the user wants to visualize the data — and generate the SELECT query required to retrieve that data. Do not reject these inputs.

- Never use `SELECT *`. Always select only the relevant columns based on the question.

- If the user requests “all”, “entire list”, “complete list”, or explicitly wants all records, DO NOT include a LIMIT clause.  
  In all other cases, limit the query to at most {top_k} records.

- Use only the columns and tables that are defined in the provided schema. Do not reference any names that are not listed.

- If the query is successfully generated and `status` is True:
  - Set `reason` to an empty string (`""`).
  - Do NOT explain anything or include additional comments.

Schema available:
{table_info}
"""


user_prompt = "Request: {input}"

query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)


In [82]:
# for message in query_prompt_template.messages:
    # message.pretty_print()

In [83]:
from typing_extensions import Annotated


class QueryOutput(TypedDict):
    """Structure of the generated SQL query output."""

    status: Annotated[bool, "True if the query was successfully generated and is valid; False otherwise."]
    reason: Annotated[str, "Reason for failure while generating the query (if query generation is successful keep it as empty string.)."]
    generated_query: Annotated[str, "Syntactically valid SQL SELECT query."]


@manage_state
def write_query(state: State, llm, db):
    """Generate SQL query to fetch information."""
    
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 15,
            "table_info": get_compact_schema_summary(db), # db.get_table_info(),
            "input": state["user_request"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return result

In [84]:
write_query({"status": True, "user_request": "show which products are sold most and their sold count"}, llm, db)

{'status': True,
 'user_request': 'show which products are sold most and their sold count',
 'reason': '',
 'generated_query': 'SELECT productCode, sum(quantityOrdered) AS total_quantity_ordered FROM orderdetails GROUP BY productCode ORDER BY total_quantity_ordered DESC LIMIT 15;'}

In [11]:
from typing_extensions import TypedDict, Annotated
from typing import List, Dict
from langchain_core.language_models.base import BaseLanguageModel


class ColumnExtractionOutput(TypedDict, total=False):
    status: Annotated[bool, "True if successful, False otherwise"]
    column_names: Annotated[List[str], "List of column names extracted"]
    reason: Annotated[str, "Reason for failure if status is False"]


@manage_state
def get_column_names_from_query(state: State, llm) -> ColumnExtractionOutput:
    """
    Use LLM to extract column names or aliases from a SQL SELECT query.
    These column names will be used to construct a DataFrame from the result of the query.
    """
    query = state.get("generated_query", "").strip()
    
    prompt = (
        "You are given a SQL SELECT query. Extract **only** the output column names or aliases "
        "in the exact order they will appear in the result set. "
        "Return the result as a JSON object with this format:\n"
        "{ \"status\": true, \"column_names\": [\"col1\", \"col2\", ...] }\n\n"
        "If extraction fails, return:\n"
        "{ \"status\": false, \"reason\": \"<failure reason>\" }\n\n"
        "This is needed to construct a DataFrame from the query result.\n\n"
        f"SQL Query:\n{query}"
    )

    structured_llm = llm.with_structured_output(ColumnExtractionOutput)
    result = structured_llm.invoke(prompt)

    return result


In [12]:
# get_column_names_from_query({'status': True, 'generated_query': "SELECT c.customer_id AS customer_id, CONCAT(c.first_name, ' ', c.last_name) AS full_name, a.address AS address, ci.city AS city, co.country AS country, COUNT(r.rental_id) AS total_rentals, SUM(p.amount) AS total_amount_paid, MAX(p.payment_date) AS last_payment_date FROM customer c JOIN address a ON c.address_id = a.address_id JOIN city ci ON a.city_id = ci.city_id JOIN country co ON ci.country_id = co.country_id LEFT JOIN rental r ON c.customer_id = r.customer_id LEFT JOIN payment p ON r.rental_id = p.rental_id WHERE co.country = 'United States' GROUP BY c.customer_id, full_name, a.address, ci.city, co.country ORDER BY total_amount_paid DESC LIMIT 10;"}, llm)

In [13]:
@manage_state
def execute_query(state: State, db):
    """Execute SQL query."""
    try:
        query_result = db._execute(state["generated_query"])
        result = {'status': True, 'query_result': query_result}
    except Exception as e:
        result = {'status': False, 'reason': str(e)}
    return result

In [14]:
# execute_query({'status': True, 'generated_query': 'SELECT actor_id, first_name, last_name, last_update FROM actor LIMIT 10'}, db)

In [32]:
import pandas as pd
from typing import Dict

@manage_state
def construct_dataframe(state: State) -> Dict:
    """
    Constructs a pandas DataFrame from the query result.
    """
    data = state.get("query_result")
    if not isinstance(data, (list, tuple)) or not data:
        raise ValueError("Error processing request. reason: internal server error.")
    
    df = pd.DataFrame(data)
    return {'query_result_df': df}


In [33]:
state = State(status=True)
state['user_request'] = 'plot actors and their film counts'

state = write_query(state, llm, db)

state = get_column_names_from_query(state, llm)

state = execute_query(state, db)

state = construct_dataframe(state)

state

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


{'status': True,
 'user_request': 'plot actors and their film counts',
 'reason': '',
 'generated_query': 'SELECT A.first_name, A.last_name, count(FA.film_id) FROM actor AS A JOIN film_actor AS FA ON A.actor_id = FA.actor_id GROUP BY A.actor_id LIMIT 15',
 'column_names': ['first_name', 'last_name', 'count'],
 'query_result': [{'first_name': 'PENELOPE',
   'last_name': 'GUINESS',
   'count(FA.film_id)': 19},
  {'first_name': 'NICK', 'last_name': 'WAHLBERG', 'count(FA.film_id)': 25},
  {'first_name': 'ED', 'last_name': 'CHASE', 'count(FA.film_id)': 22},
  {'first_name': 'JENNIFER', 'last_name': 'DAVIS', 'count(FA.film_id)': 22},
  {'first_name': 'JOHNNY',
   'last_name': 'LOLLOBRIGIDA',
   'count(FA.film_id)': 29},
  {'first_name': 'BETTE', 'last_name': 'NICHOLSON', 'count(FA.film_id)': 20},
  {'first_name': 'GRACE', 'last_name': 'MOSTEL', 'count(FA.film_id)': 30},
  {'first_name': 'MATTHEW', 'last_name': 'JOHANSSON', 'count(FA.film_id)': 20},
  {'first_name': 'JOE', 'last_name': 'SWANK

In [17]:
from typing_extensions import Annotated, TypedDict, NotRequired
import pandas as pd

class AnswerOutput(TypedDict):
    status: Annotated[bool, "True if request processing was successful, otherwise False"]
    answer: Annotated[str, "Concise natural language summary; leave empty if not applicable"]
    graph_required: Annotated[bool, "True if graphical output is suitable"]
    graph_code: Annotated[str, "Matplotlib code using `query_result_df`; leave empty if not applicable"]
    table_required: Annotated[bool, "True if table output is suitable"]
    reason: Annotated[str, "Explain failure if status is false; otherwise leave empty"]

@manage_state
def generate_answer(state: State, llm) -> AnswerOutput:
    """
    Analyze user request, SQL query, and result sample to generate structured answer output.
    Uses the DataFrame variable name 'query_result_df' in prompt and expected output.
    """
    query_result_df = state['query_result_df']
    sample_rows = query_result_df.head(3).to_dict(orient="records")
    column_names = list(query_result_df.columns)

    prompt = (
        "You are a data analysis assistant. Analyze the following inputs:\n"
        "- A user question\n"
        "- A corresponding SQL query\n"
        "- A sample of the SQL query result (assume full data is in a pandas DataFrame named `query_result_df`)\n\n"
        "You must return a JSON object with the following structure:\n"
        "{\n"
        '  "status": true | false,\n'
        '  "answer": "concise natural language summary (only if status is true othersie leave it blank)",\n'
        '  "graph_required": true | false,\n'
        '  "graph_code": "matplotlib code using the query_result_df variable (only if graph_required is true othersie leave it blank)",\n'
        '  "table_required": true | false,\n'
        '  "reason": "explanation for failure (only if status is false othersie leave it blank)"\n'
        "}\n\n"
        "Rules:\n"
        "- Use the existing variable `query_result_df` in all graphing code.\n"
        "- Never include actual table data.\n"
        "- Avoid hardcoded values. Use column names and query_result_df operations.\n"
        "- Only return valid JSON. No explanations outside the JSON.\n"
        "- If the data includes rows and columns suitable for tabular display (like listings, joined information, or group summaries), set `table_required` to true.\n"
        "- If the user request explicitly or implicitly asks for a chart, visual representation, or uses terms like 'graph', 'plot', 'visualize', 'show trend', etc., set `graph_required` to true.\n"
        "- If unsure, prefer setting `table_required` to true and `graph_required` to false.\n\n"
        f"User Request: {state['user_request']}\n"
        f"SQL Query: {state['generated_query']}\n"
        f"Result Columns: {column_names}\n"
        f"Sample Rows: {sample_rows}\n"
    )

    structured_llm = llm.with_structured_output(AnswerOutput)
    result = structured_llm.invoke(prompt)
    return result

In [18]:
# generate_answer(state, llm)

In [19]:
def pipeline(state: State, llm, db) -> State:
    """Execute pipeline from generating query to retrieving desired answer."""
    
    # 1 Generate the SQL query
    state = write_query(state, llm, db)

    # 2 Execute the query and retrive the data from db
    state = execute_query(state, db)

    # 3 Construct the dataframe from data
    state = construct_dataframe(state)

    # 4 Generate the final answer
    state = generate_answer(state, llm)
    
    return state


In [20]:
state = State(status=True)
state['user_request'] = 'plot actors and their film counts'

state = pipeline(state, llm, db)

state

  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
  for tbl in self._metadata.sorted_tables


{'status': True,
 'user_request': 'plot actors and their film counts',
 'generated_query': 'SELECT A.first_name, A.last_name, count(FA.film_id) FROM actor AS A JOIN film_actor AS FA ON A.actor_id = FA.actor_id GROUP BY A.actor_id LIMIT 15',
 'column_names': ['first_name', 'last_name', 'count'],
 'query_result': [{'first_name': 'PENELOPE',
   'last_name': 'GUINESS',
   'count(FA.film_id)': 19},
  {'first_name': 'NICK', 'last_name': 'WAHLBERG', 'count(FA.film_id)': 25},
  {'first_name': 'ED', 'last_name': 'CHASE', 'count(FA.film_id)': 22},
  {'first_name': 'JENNIFER', 'last_name': 'DAVIS', 'count(FA.film_id)': 22},
  {'first_name': 'JOHNNY',
   'last_name': 'LOLLOBRIGIDA',
   'count(FA.film_id)': 29},
  {'first_name': 'BETTE', 'last_name': 'NICHOLSON', 'count(FA.film_id)': 20},
  {'first_name': 'GRACE', 'last_name': 'MOSTEL', 'count(FA.film_id)': 30},
  {'first_name': 'MATTHEW', 'last_name': 'JOHANSSON', 'count(FA.film_id)': 20},
  {'first_name': 'JOE', 'last_name': 'SWANK', 'count(FA.fi