In [2]:
import os
from typing import Dict, List, Optional, Any, Tuple
from pydantic import BaseModel, Field
from langchain_core.messages import AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
import psycopg2
import sqlalchemy
from sqlalchemy import create_engine, text, inspect

In [3]:
# Define state models
class DBConnection(BaseModel):
    """Database connection details"""
    type: str = Field(description="Type of database (PostgreSQL, SQL)")
    connection_string: str = Field(description="Connection string or parameters for the database")
    name: str = Field(description="A name to identify this connection")

class SchemaInfo(BaseModel):
    """Information about database schema"""
    tables: Dict[str, List[str]] = Field(default_factory=dict, description="Map of table names to column names")
    relationships: List[str] = Field(default_factory=list, description="Known foreign key relationships")
    description: str = Field(default="", description="Human-readable description of the schema")

class DBAssistantState(BaseModel):
    """State for the DB Assistant agent"""
    query: str = Field(description="The user's original query")
    connections: Dict[str, DBConnection] = Field(default_factory=dict, description="Available database connections")
    schema_info: Dict[str, SchemaInfo] = Field(default_factory=dict, description="Schema information for each connection")
    analysis_plan: Optional[List[str]] = Field(default=None, description="Steps to analyze and solve the query")
    current_step: int = Field(default=0, description="Current step in the analysis plan")
    results: List[Dict[str, Any]] = Field(default_factory=list, description="Results from database queries")
    messages: List[Any] = Field(default_factory=list, description="Conversation history")
    errors: List[str] = Field(default_factory=list, description="Any errors encountered")
    final_answer: Optional[str] = Field(default=None, description="Final answer to the user's query")

In [4]:
# Database connection and schema handling
def connect_to_db(connection: DBConnection):
    """Create a connection to the specified database"""
    if connection.type.lower() == "postgresql":
        return create_engine(connection.connection_string)
    elif connection.type.lower() == "sql" or connection.type.lower() == "mysql":
        # For general SQL databases
        return create_engine(connection.connection_string)
    else:
        raise ValueError(f"Unsupported database type: {connection.type}")

def get_schema_info(engine) -> SchemaInfo:
    """Extract schema information from a database"""
    inspector = inspect(engine)
    schema_info = SchemaInfo()
    
    # Get all tables and their columns
    for table_name in inspector.get_table_names():
        columns = [column['name'] for column in inspector.get_columns(table_name)]
        schema_info.tables[table_name] = columns
        
        # Get foreign key relationships
        for fk in inspector.get_foreign_keys(table_name):
            relationship = f"Table '{table_name}' column '{fk['constrained_columns'][0]}' references "
            relationship += f"table '{fk['referred_table']}' column '{fk['referred_columns'][0]}'"
            schema_info.relationships.append(relationship)
    
    # Create a human-readable description
    description_parts = ["Database schema:"]
    
    for table_name, columns in schema_info.tables.items():
        description_parts.append(f"- Table: {table_name}")
        description_parts.append(f"  Columns: {', '.join(columns)}")
    
    if schema_info.relationships:
        description_parts.append("\nRelationships:")
        for relationship in schema_info.relationships:
            description_parts.append(f"- {relationship}")
    
    schema_info.description = "\n".join(description_parts)
    return schema_info

def execute_query(engine, query_string):
    """Execute a SQL query and return results"""
    try:
        with engine.connect() as connection:
            result = connection.execute(text(query_string))
            # Convert to list of dictionaries
            columns = result.keys()
            return [dict(zip(columns, row)) for row in result.fetchall()]
    except Exception as e:
        return {"error": str(e)}

# LangGraph agent nodes
def parse_user_input(state: DBAssistantState) -> DBAssistantState:
    """Parse the user input to identify the type of query and required information"""
    llm = ChatOpenAI(model="gpt-4-turbo")
    
    # Prepare schema information for prompt
    schema_context = ""
    for conn_name, schema in state.schema_info.items():
        schema_context += f"\n\nConnection: {conn_name}\n{schema.description}"
    
    messages = [
        HumanMessage(content=f"""
        Given this user query: "{state.query}"
        
        Here is information about the available database schemas:
        {schema_context}
        
        The available database connections are: {list(state.connections.keys())}
        
        Based on the user query and the database schemas, create a step-by-step analysis plan.
        For each step, specify:
        1. Which database connection to use
        2. What information to retrieve or calculate
        3. Which tables and columns you'll need to query
        
        Make your plan as efficient as possible, using appropriate joins and aggregations.
        Format your response as a numbered list of steps, with each step clearly defined.
        """)
    ]
    
    response = llm.invoke(messages)
    # Parse the analysis plan from the LLM response
    plan_lines = [line.strip() for line in response.content.split('\n') if line.strip()]
    # Filter out any non-plan lines (intro text, etc.)
    plan = [line for line in plan_lines if line and (line.startswith("Step") or line.startswith("-") or line.startswith("*") or 
                                                    (len(line) > 0 and line[0].isdigit() and "." in line[:3]))]
    
    return DBAssistantState(
        **state.dict(),
        analysis_plan=plan,
        messages=state.messages + [HumanMessage(content=state.query), AIMessage(content=f"I've analyzed your query and created a plan based on the database schema:\n\n" + "\n".join(plan))]
    )

def execute_plan_step(state: DBAssistantState) -> DBAssistantState:
    """Execute the current step in the analysis plan"""
    if not state.analysis_plan or state.current_step >= len(state.analysis_plan):
        return state
    
    current_step = state.analysis_plan[state.current_step]
    llm = ChatOpenAI(model="gpt-4-turbo")
    
    # Prepare schema information for prompt
    schema_context = ""
    for conn_name, schema in state.schema_info.items():
        schema_context += f"\n\nConnection: {conn_name}\n{schema.description}"
    
    # Include previous results for context if available
    previous_results = ""
    if state.results:
        previous_results = "\n\nResults from previous steps:\n"
        for i, result in enumerate(state.results):
            previous_results += f"Step {i+1}: {result['description']}\n"
            previous_results += f"Results: {result['result']}\n"
    
    # Use LLM to generate SQL for the current step
    messages = [
        HumanMessage(content=f"""
        Based on the user query: "{state.query}"
        
        I need to execute this step in my analysis plan:
        "{current_step}"
        
        Here is information about the available database schemas:
        {schema_context}
        
        Available database connections: {list(state.connections.keys())}
        {previous_results}
        
        Generate the exact SQL query needed for this step.
        Make sure to:
        - Use the correct table and column names from the schema
        - Use appropriate joins if needed
        - Use proper syntax for the database type
        
        Format your response as:
        CONNECTION: [connection_name]
        SQL: [sql_query]
        """)
    ]
    
    response = llm.invoke(messages)
    
    # Parse the SQL query and connection from the response
    lines = response.content.split('\n')
    connection_name = None
    sql_query = ""
    
    for line in lines:
        if line.startswith("CONNECTION:"):
            connection_name = line.replace("CONNECTION:", "").strip()
        elif line.startswith("SQL:"):
            sql_query = line.replace("SQL:", "").strip()
        elif sql_query and line:  # For multi-line SQL
            sql_query += " " + line.strip()
    
    # Execute the SQL query
    try:
        if connection_name and connection_name in state.connections and sql_query:
            connection = state.connections[connection_name]
            engine = connect_to_db(connection)
            result = execute_query(engine, sql_query)
            
            # Store the results
            step_result = {
                "step": state.current_step,
                "description": current_step,
                "sql": sql_query,
                "connection": connection_name,
                "result": result
            }
            
            # Prepare message content
            message_content = f"Executed step {state.current_step + 1}:\n{current_step}\n\nSQL used:\n```sql\n{sql_query}\n```\n\nResults: {result}"
            
            # Update state
            return DBAssistantState(
                **state.dict(),
                current_step=state.current_step + 1,
                results=state.results + [step_result],
                messages=state.messages + [AIMessage(content=message_content)]
            )
        else:
            error = "Failed to extract valid connection or SQL query from LLM response."
            return DBAssistantState(
                **state.dict(),
                errors=state.errors + [error],
                messages=state.messages + [AIMessage(content=f"Error in step {state.current_step + 1}: {error}")]
            )
    except Exception as e:
        error = f"Error executing SQL: {str(e)}"
        return DBAssistantState(
            **state.dict(),
            errors=state.errors + [error],
            messages=state.messages + [AIMessage(content=f"Error in step {state.current_step + 1}: {error}")]
        )

def generate_final_answer(state: DBAssistantState) -> DBAssistantState:
    """Generate the final answer based on all the query results"""
    llm = ChatOpenAI(model="gpt-4-turbo")
    
    # Compile all results to create context for the final answer
    results_context = "\n\n".join([
        f"Step {r['step'] + 1}: {r['description']}\nSQL: {r['sql']}\nResults: {r['result']}"
        for r in state.results
    ])
    
    messages = [
        HumanMessage(content=f"""
        Based on the user query: "{state.query}"
        
        And these analysis results:
        {results_context}
        
        Generate a concise, natural language answer that directly addresses the user's query.
        Include relevant numbers and statistics when appropriate.
        """)
    ]
    
    response = llm.invoke(messages)
    
    return DBAssistantState(
        **state.dict(),
        final_answer=response.content,
        messages=state.messages + [AIMessage(content=response.content)]
    )

def should_continue_plan(state: DBAssistantState) -> str:
    """Determine if we should continue with the plan or move to final answer"""
    if not state.analysis_plan:
        return "generate_final_answer"
    
    if state.current_step < len(state.analysis_plan):
        return "execute_plan_step"
    else:
        return "generate_final_answer"

In [5]:
# Build the graph
def build_db_assistant_graph():
    """Build the LangGraph workflow for the DB Assistant agent"""
    graph = StateGraph(DBAssistantState)
    
    # Add nodes
    graph.add_node("parse_user_input", parse_user_input)
    graph.add_node("execute_plan_step", execute_plan_step)
    graph.add_node("generate_final_answer", generate_final_answer)
    
    # Add edges
    graph.add_edge("parse_user_input", should_continue_plan)
    graph.add_edge("execute_plan_step", should_continue_plan)
    graph.add_edge("generate_final_answer", END)
    
    # Set entrypoint
    graph.set_entry_point("parse_user_input")
    
    return graph.compile()

# Main handler function
def db_assistant(query: str, connections: List[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Main function to handle a DB Assistant query
    
    Args:
        query: The user's query
        connections: List of DB connection details
        
    Returns:
        Dict containing the response and any results
    """
    # Process connections and get schema information
    processed_connections = {}
    schema_info = {}
    
    if connections:
        for conn in connections:
            db_conn = DBConnection(
                type=conn["type"],
                connection_string=conn["connection_string"],
                name=conn["name"]
            )
            processed_connections[conn["name"]] = db_conn
            
            # Extract schema information
            try:
                engine = connect_to_db(db_conn)
                schema_info[conn["name"]] = get_schema_info(engine)
            except Exception as e:
                # If we can't get schema info, create an empty schema with an error note
                schema_info[conn["name"]] = SchemaInfo(
                    description=f"Error retrieving schema: {str(e)}"
                )
    
    # Initialize state
    initial_state = DBAssistantState(
        query=query,
        connections=processed_connections,
        schema_info=schema_info
    )
    
    # Build and run the graph
    graph = build_db_assistant_graph()
    final_state = graph.invoke(initial_state)
    
    # Return results
    return {
        "answer": final_state.final_answer,
        "steps": len(final_state.analysis_plan) if final_state.analysis_plan else 0,
        "results": final_state.results,
        "errors": final_state.errors,
        "messages": [{"role": m.type, "content": m.content} for m in final_state.messages]
    }

In [6]:
# Example usage
if __name__ == "__main__":
    # Example connection
    connections = [
        {
            "name": "sales_db",
            "type": "postgresql",
            "connection_string": "postgresql://username:password@localhost:5432/sales_database"
        }
    ]
    
    # Example query
    query = "How many potatoes did we sell in 2024?"
    
    # Run the assistant
    result = db_assistant(query, connections)
    print(f"Answer: {result['answer']}")

ValueError: Found edge ending at unknown node `<function should_continue_plan at 0x11720ae80>`

In [None]:
from db_assistant import db_assistant
from dotenv import load_dotenv
import os
import time

# Load environment variables
load_dotenv()

def main():
    # Set up database connections
    connections = [
        {
            "name": "sales_db",
            "type": "postgresql",
            "connection_string": f"postgresql://{os.getenv('PG_USER')}:{os.getenv('PG_PASSWORD')}@{os.getenv('PG_HOST')}:{os.getenv('PG_PORT')}/{os.getenv('PG_DATABASE')}"
        },
        {
            "name": "inventory_db",
            "type": "mysql",
            "connection_string": f"mysql+pymysql://{os.getenv('MYSQL_USER')}:{os.getenv('MYSQL_PASSWORD')}@{os.getenv('MYSQL_HOST')}:{os.getenv('MYSQL_PORT')}/{os.getenv('MYSQL_DATABASE')}"
        },
        {
            "name": "customer_db",
            "type": "postgresql",
            "connection_string": f"postgresql://{os.getenv('PG_USER')}:{os.getenv('PG_PASSWORD')}@{os.getenv('PG_HOST')}:{os.getenv('PG_PORT')}/{os.getenv('CUSTOMER_DATABASE', 'customer_database')}"
        }
    ]
    
    print("Connecting to databases and extracting schema information...")
    # You could add a loading spinner here for a better user experience
    
    # Example queries that leverage multiple databases
    queries = [
        "How many potatoes did we sell in 2024?",
        "What's our best-selling product in Q1 2024?",
        "Calculate the monthly sales growth percentage for 2024 so far.",
        "Compare our top 5 customers by sales volume with their current inventory levels.",
        "Which products are close to running out of stock based on current inventory and sales velocity?"
    ]
    
    # Run each query through the DB Assistant
    for query in queries:
        print(f"\n\n{'='*80}\nQUERY: {query}\n{'='*80}")
        
        result = db_assistant(query, connections)
        
        print(f"\nANSWER: {result['answer']}")
        print(f"\nSTEPS EXECUTED: {result['steps']}")
        
        if result['errors']:
            print(f"\nERRORS: {result['errors']}")
        
        print("\nRESULTS:")
        for i, res in enumerate(result['results']):
            print(f"\nStep {i+1}: {res['description']}")
            print(f"SQL: {res['sql']}")
            print(f"Results: {res['result']}")

if __name__ == "__main__":
    main()