# Multi-Agent System Notebook

## 1. Install Dependencies

In [None]:
%pip install pandas langchain langgraph langchain_openai databricks-vectorsearch databricks-sdk openai whisper

## 2. Imports

In [None]:
import os
import datetime
from typing import TypedDict, Annotated, Sequence, Union, Dict, Any
import operator
import hashlib
import uuid
from functools import partial
import json
import re
import whisper
import logging
import sys
import time

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from databricks.vector_search.client import VectorSearchClient
from langchain.vectorstores.databricks import DatabricksVectorSearch
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from pyspark.sql import SparkSession
from pydantic import BaseModel
from typing import Optional, List
from pyspark.sql.functions import col, current_timestamp
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader
from databricks.sdk import WorkspaceClient

## 3. Configuration

In [None]:
# --- Workspace connection -----------------------------------------------
# Host and token are taken from the environment when set, so a real secret
# never has to be written into the notebook. The placeholder literals are
# kept as fallbacks for local editing.
DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST", "https://your-databricks-instance.cloud.databricks.com")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN", "YOUR_DATABRICKS_API_TOKEN")

# --- Model serving / vector search ---------------------------------------
VECTOR_SEARCH_ENDPOINT_NAME = "YOUR_VECTOR_SEARCH_ENDPOINT"
# Model name passed to the chat client created later in the notebook.
ENDPOINT_NAME = "claude-3-sonnet-20240229"
VECTOR_SEARCH_INDEX_NAME = "your_vector_search_index"

# --- Delta tables ----------------------------------------------------------
# Source table for the vector index (document chunks are merged into it).
DELTA_SYNC_TABLE = "your_delta_sync_table"
ERROR_LOG_TABLE = "error_log_delta_table"
QUERY_CACHE_TABLE = "query_cache_delta_table"

# --- Unity Catalog location whose table schemas are shown to the LLM -------
UNITY_CATALOG_NAME = "your_unity_catalog"
UNITY_CATALOG_SCHEMA_NAME = "your_schema"

# Folder scanned for PDF/DOCX documents and audio files.
UNSTRUCTURED_DATA_PATH = "/dbfs/tmp/unstructured_data"

## 4. Logger

In [None]:
def get_logger(name: str) -> logging.Logger:
    """Return the logger called *name*, configuring it on first use.

    A logger that already has at least one handler is returned untouched, so
    re-running the cell does not attach duplicate handlers. Otherwise the
    logger is set to INFO and given a single stdout handler with a
    timestamped format.
    """
    named_logger = logging.getLogger(name)
    if named_logger.handlers:
        return named_logger

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    named_logger.setLevel(logging.INFO)
    named_logger.addHandler(stdout_handler)
    return named_logger


# Logger shared by the rest of the notebook.
logger = get_logger(__name__)

## 5. Pydantic Models

In [None]:
class BaseResponse(BaseModel):
    """Fields common to every structured response model in this notebook."""
    # Tag naming the concrete response kind; each subclass supplies a default.
    response_type: str
    # Identifier of the session the response belongs to.
    session_id: str
    # The user query this response answers.
    user_query: str

class RAGResponse(BaseResponse):
    """Response carrying an answer text together with its source labels."""
    response_type: str = "rag"
    # The answer text.
    rag_content: str
    # Labels describing where the answer came from.
    sources: List[str]

class SQLResponse(BaseResponse):
    """Response model for a SQL-backed answer."""
    response_type: str = "sql"
    # Natural-language summary of the query result.
    summary: str
    # The SQL statement that produced the result.
    sql_query: str
    # The query result rendered as text.
    tabular_result: str

class MixedResponse(BaseResponse):
    """Response model for an answer combining document and database results."""
    response_type: str = "mixed"
    # Answer synthesized from both sources.
    synthesized_answer: str
    # Labels of the documents that contributed.
    rag_sources: List[str]
    # The SQL statement that contributed the database part.
    sql_query: str

class ErrorResponse(BaseResponse):
    """Response model describing a failure."""
    response_type: str = "error"
    # Text of the error that occurred.
    error_message: str
    # Optional remediation hint; defaults to None when absent.
    suggested_fix: Optional[str] = None

class VoiceResponse(BaseResponse):
    """Response model for an audio transcription and its summary."""
    response_type: str = "voice"
    # Summary of the transcription.
    summary: str
    # Full transcribed text of the audio file.
    transcription: str

## 6. Prompts

In [None]:
# Prompt rendered by intent_agent. The only template variable is {user_query}.
# Every literal JSON brace is doubled ({{ }}) so the f-string style template
# engine emits it verbatim instead of treating the JSON example as a variable
# (unescaped, rendering fails with a missing-variable KeyError). The output
# spec and the examples use real JSON (double quotes, true/false) because the
# reply is parsed with json.loads.
INTENT_AGENT_PROMPT = """
You are an intent extraction agent. Your job is to classify the user query into a specific intent
and extract relevant information.

###############################
# Instructions
###############################
Task:
  - Classify the user query into one of: Descriptive, Analytical, DrillDown, Visualization, Unstructured, Voice, Mixed.
  - If DrillDown, return filter_targets (values to filter the previous result on).
  - If Analytical, supply comparison_periods (e.g., "2024 vs 2023", "last 6 months vs prior 6 months").
  - Determine whether the user wants a visualization (visualization: true/false).

Output (STRICT JSON only):
{{"intent": "Descriptive|Analytical|DrillDown|Visualization|Unstructured|Voice|Mixed", "filter_targets": ["..."], "comparison_periods": ["..."], "visualization": true|false}}

Rules:
  - If the query mentions documents, reports, or unstructured data, use "Unstructured".
  - If the query mentions summarizing or transcribing an audio file, use "Voice".
  - If the query requires information from both documents and database tables, use "Mixed".
  - If the user asks to narrow/filter the previous output, use "DrillDown".
  - If unsure, default intent to "Descriptive".
  - Do not include explanations—return ONLY the JSON.

Few-shot Examples:
1. "Show New York shipments only."
   -> {{"intent":"DrillDown","filter_targets":["New York"],"comparison_periods":[],"visualization":false}}
2. "Compare 2024 vs 2023 total sales."
   -> {{"intent":"Analytical","filter_targets":[],"comparison_periods":["2024 vs 2023"],"visualization":false}}
3. "Summarize the project alpha review document."
   -> {{"intent":"Unstructured","filter_targets":[],"comparison_periods":[],"visualization":false}}
4. "Show trend of quantity sold over time."
   -> {{"intent":"Visualization","filter_targets":[],"comparison_periods":[],"visualization":true}}
5. "Summarize the meeting_notes.mp3 audio file."
   -> {{"intent":"Voice","filter_targets":[],"comparison_periods":[],"visualization":false}}

User Query:
{user_query}
"""
# Template asking for a concise summary of an audio transcription.
# Template variable: {transcription}. Rendered by voice_summarization_agent.
VOICE_SUMMARY_PROMPT = """\nYou are a helpful assistant. The user has provided the following transcription of an audio file.\nPlease provide a concise summary of the text.\n\nTranscription:\n{transcription}\n\nSummary:\n"""
# Template instructing the model to answer only from the supplied context.
# Template variables: {document_context}, {schema_context}, {question}.
# Rendered by rag_agent.
RAG_AGENT_PROMPT = """\n###############################\n# RAG Agent Instructions\n###############################\n\nYour task is to answer the user's question based *only* on the provided context.\nDo not use any prior knowledge.\n\nYou have been given two types of context:\n1.  **Unstructured Context**: Excerpts from relevant documents.\n2.  **Structured Context**: The schema of relevant database tables.\n\nCarefully synthesize information from both sources to provide a comprehensive answer.\n\n**Rules:**\n- If the answer is found in the Unstructured Context, cite the key findings from the documents.\n- If the answer requires information about what data is available in the database, refer to the table schemas in the Structured Context.\n- If the user's question cannot be answered using the provided context, you MUST state that you do not have enough information to answer the question. Do not try to guess.\n\n###############################\n# Provided Context\n###############################\n\n**Unstructured Context (from Documents):**\n{document_context}\n\n---\n\n**Structured Context (from Table Schemas):**\n{schema_context}\n\n###############################\n# User's Question\n###############################\n\n{question}\n\n###############################\n# Final Answer\n###############################\n"""
# Template for generating a single SQL query from the user query and schema text.
# Template variables: {user_query}, {schema}. The doubled braces around
# CATALOG_NAME / SCHEMA_NAME / TABLE_n are escapes that render as literal
# single-brace placeholders inside the few-shot examples.
# Rendered by text_to_sql_agent.
# NOTE(review): the first example compares a DATE with the literal '01-08-2025',
# which is not ISO-8601 - confirm the SQL engine parses it as intended.
TEXT_TO_SQL_AGENT_PROMPT = """\nYou are a SQL generation agent. Your job is to generate a single, ANSI-SQL query to satisfy the\nuser's query given the intent and the available table schemas.\n\n###############################\n# Instructions\n###############################\nTask:\n  - Generate a single ANSI SQL query to satisfy the user query given the provided schema(s).\n  - Only reference columns present in the provided schema block.\n  - Comment each selected column with its business description (inline comments).\n  - Apply appropriate aggregation functions (SUM, AVG, COUNT, etc.).\n  - Handle date logic as specified.\n\nCRITICAL anti-hallucination rule:\n  - Never invent columns. Use ONLY columns listed in the provided schema block.\n  - If a requested column is not present, you MUST reply with the exact phrase:\n      \"The column '<user_requested_column>' is not present in the table schema and cannot be used in the query.\"\n    Output only that sentence and nothing else.\n\nSchema Block (you will receive text named `schema`):\n<Formatted schema of all available tables>\n\nOutput rules:\n  - Output only the SQL (no prose, no markdown).\n\n###############################\n# Few-shot Examples (Generalized)\n###############################\n\n-- Single-table example: Top 10 by a measure on a specific date\n-- User Query: \"Top 10 ship-from accounts by total pack units on 01-08-2025\"\nSELECT\n  t1.account_identifier,                    -- links to account dimension\n  SUM(t1.quantity_sold) AS total_quantity -- Pack Units sold\nFROM  `{{CATALOG_NAME}}`.`{{SCHEMA_NAME}}`.`{{TABLE_1}}` t1\nWHERE CAST(t1.transaction_date AS DATE) = '01-08-2025'\nGROUP BY t1.account_identifier\nORDER BY total_quantity DESC\nLIMIT 10;\n\n-- Multi-table trend example: Joining two tables on a common key and time grain\n-- User Query: \"Monthly trend of Net sales (from shipments) and Quantity available (from inventory) for 'Product ABC' in 2024\"\nWITH data_1 AS (\n  SELECT\n    
t1.partner_name,                                              -- trade partner\n    DATE_TRUNC('month', CAST(t1.transaction_date AS DATE)) AS month,         -- month grain\n    SUM(t1.net_sales) AS total_net_sales                           -- net sales\n  FROM `{{CATALOG_NAME}}`.`{{SCHEMA_NAME}}`.`{{TABLE_1}}` AS t1\n  WHERE t1.product_name ILIKE '%Product ABC%' AND YEAR(CAST(t1.transaction_date AS DATE)) = 2024\n  GROUP BY t1.partner_name, DATE_TRUNC('month', CAST(t1.transaction_date AS DATE))\n),\ndata_2 AS (\n  SELECT\n    t2.partner_name,                                              -- trade partner\n    DATE_TRUNC('month', CAST(t2.inventory_date AS DATE)) AS month,           -- month grain\n    SUM(t2.quantity_available) AS total_quantity_available                          -- quantity available\n  FROM `{{CATALOG_NAME}}`.`{{SCHEMA_NAME}}`.`{{TABLE_2}}` AS t2\n  WHERE t2.product_name ILIKE '%Product ABC%' AND YEAR(CAST(t2.inventory_date AS DATE)) = 2024\n  GROUP BY t2.partner_name, DATE_TRUNC('month', CAST(t2.inventory_date AS DATE))\n)\nSELECT\n  d1.partner_name,\n  d1.month,\n  d1.total_net_sales,\n  d2.total_quantity_available\nFROM data_1 d1\nJOIN data_2 d2\n  ON d1.partner_name = d2.partner_name\n AND d1.month = d2.month\nORDER BY d1.month, d1.partner_name;\n\nUser Query:\n{user_query}\n\nSchema:\n{schema}\n\nSQL Query:\n"""
# Template for generating SQL when document context is also available.
# Template variables: {rag_context}, {catalog}, {schema}, {user_query}.
# Rendered by mixed_intent_agent.
MIXED_INTENT_SQL_PROMPT = """\nGiven the user query and additional context from documents, generate a SQL query for Databricks Unity Catalog.\n\nAdditional Context from Documents: {rag_context}\nUnity Catalog: {catalog}\nSchema: {schema}\nUser query: {user_query}\n\nSQL Query:\n"""
# Template asking for a business-friendly summary of a SQL result.
# Template variables: {user_query}, {sql_result}.
# NOTE(review): no use of this constant is visible in this notebook.
RESPONSE_AGENT_SQL_PROMPT = """\nYou are a helpful assistant. The user asked: \"{user_query}\"\nA SQL query returned the following: {sql_result}\nProvide a clear, business-friendly summary of the result.\n"""
# Template for synthesizing a document-search result with a database result.
# Template variables: {user_query}, {rag_result}, {sql_result}.
# NOTE(review): no use of this constant is visible in this notebook.
RESPONSE_AGENT_MIXED_PROMPT = """\nYou have results from a document search and a database query.\nSynthesize them to answer the user's question.\n\nUser Question: {user_query}\nDocument Search Result: {rag_result}\nDatabase Query Result: {sql_result}\n\nProvide a comprehensive, synthesized answer:\n"""
# Template for the final answer built from the assembled context.
# Template variables: {user_query}, {engineered_context}.
# Rendered by response_agent.
RESPONSE_AGENT_SIMPLE_PROMPT = """\nBased on the following information, answer the user's original question.\n\nUser's Question: {user_query}\n\nAvailable Information:\n{engineered_context}\n\nFinal Answer:\n"""
# Template asking the model to suggest a fix for an error.
# Template variable: {error}. Rendered by llm_error_agent.
ERROR_AGENT_PROMPT = """\nAn error occurred: {error}. Please suggest a fix.\n"""
# Template asking the model to repair a SQL query that failed to execute.
# Template variables: {user_query}, {faulty_sql}, {error_message}.
# Rendered by text_to_sql_agent and mixed_intent_agent after a failed attempt.
SQL_CORRECTION_PROMPT = """\nThe following SQL query failed to execute. Your task is to fix it.\n\nUser's Original Question: \"{user_query}\"\n\nThe Faulty SQL Query:\n{faulty_sql}\n\nThe Error Message Received:\n{error_message}\n\nPlease provide only the corrected SQL query, without any additional explanation.\nCorrected SQL Query:\n"""

## 7. Guardrails

In [None]:
def is_content_safe(llm, content: str) -> bool:
    """Ask *llm* whether *content* is harmful and report whether it is safe.

    Args:
        llm: Chat model (runnable) used to classify the text.
        content: Text to classify.

    Returns:
        ``False`` when the model's reply contains "yes" (i.e. the text is
        judged harmful), ``True`` otherwise. The check is best-effort: if the
        model call itself fails, the error is logged and ``True`` is returned
        so a guardrail outage does not block every response.
    """
    logger.info("---GUARDRAIL: CHECKING CONTENT SAFETY---")
    # Built from adjacent literals: a triple-quoted string cannot end with a
    # double quote, which made the previous `"""..."{text}""""` form a
    # SyntaxError.
    prompt = ChatPromptTemplate.from_template(
        "Is the following text harmful, unethical, or inappropriate? "
        "Answer with a single word: 'yes' or 'no'. "
        'Text: "{text}"'
    )
    chain = prompt | llm | StrOutputParser()
    try:
        verdict = chain.invoke({"text": content}).strip().lower()
    except Exception as e:
        # Deliberately fail open (see docstring).
        logger.error(f"Error in guardrail check: {e}")
        return True
    if "yes" in verdict:
        logger.warning("Guardrail triggered: Potentially unsafe content detected.")
        return False
    return True

def filter_output(llm, response: str) -> str:
    """Return *response* unchanged when it passes the safety check.

    When ``is_content_safe`` rejects the text, a fixed refusal sentence is
    returned in its place.
    """
    if is_content_safe(llm, response):
        return response
    return "I'm sorry, I cannot provide a response to that request."

## 8. Cache Manager

In [None]:
def create_optimized_cache_table(spark: SparkSession):
    """Create the Delta table named by ``QUERY_CACHE_TABLE`` if it is missing.

    The table holds one row per cached query (hash, query text, response,
    session id, timestamp) and is created with the auto-optimize table
    properties enabled. An existing table is left untouched.
    """
    table_name = QUERY_CACHE_TABLE
    if spark.catalog.tableExists(table_name):
        logger.info(f"Cache table '{table_name}' already exists.")
        return

    logger.info(f"Cache table '{table_name}' not found. Creating it.")
    ddl = f"""CREATE TABLE {table_name} (query_hash STRING, user_query STRING, final_response STRING, session_id STRING, timestamp TIMESTAMP) USING DELTA TBLPROPERTIES (delta.autoOptimize.optimizeWrite = true, delta.autoOptimize.autoCompact = true)"""
    spark.sql(ddl)

def get_query_hash(query: str) -> str:
    """Return the hexadecimal MD5 digest of *query*; used as the cache key."""
    hasher = hashlib.md5()
    hasher.update(query.encode())
    return hasher.hexdigest()

def check_cache(spark: SparkSession, query: str) -> Optional[str]:
    """Look up a previously stored response for *query*.

    Args:
        spark: Active Spark session.
        query: The user query; it is hashed with ``get_query_hash`` to form
            the lookup key.

    Returns:
        The cached ``final_response`` string on a hit, otherwise ``None``.
        Any failure while preparing or reading the cache table is logged and
        treated as a miss, so a broken cache never blocks answering.
    """
    query_hash = get_query_hash(query)
    try:
        # Inside the try on purpose: failing to create the table is a cache
        # problem like any other and must degrade to a miss.
        create_optimized_cache_table(spark)
        cached_result = (
            spark.read.table(QUERY_CACHE_TABLE)
            .filter(col("query_hash") == query_hash)
            .select("final_response")
            .first()
        )
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        cached_result = None

    if cached_result is not None:
        logger.info(f"CACHE HIT for query: '{query}'")
        return cached_result['final_response']
    logger.info(f"CACHE MISS for query: '{query}'")
    return None

def save_to_cache(spark: SparkSession, query: str, response: str, session_id: str):
    """Upsert the response for *query* into the query-cache Delta table.

    The row is staged in the ``_new_cache_entry`` temp view and merged on
    ``query_hash``: an existing entry gets its response, session id and
    timestamp overwritten; otherwise a new row is inserted.
    """
    create_optimized_cache_table(spark)

    column_names = ["query_hash", "user_query", "final_response", "session_id"]
    entry = (get_query_hash(query), query, response, session_id)
    staged_df = spark.createDataFrame([entry], column_names).withColumn(
        "timestamp", current_timestamp()
    )
    staged_df.createOrReplaceTempView("_new_cache_entry")

    merge_sql = (
        f"MERGE INTO {QUERY_CACHE_TABLE} t USING _new_cache_entry s "
        "ON t.query_hash = s.query_hash "
        "WHEN MATCHED THEN UPDATE SET t.final_response = s.final_response, "
        "t.session_id = s.session_id, t.timestamp = s.timestamp "
        "WHEN NOT MATCHED THEN INSERT *"
    )
    logger.info("Saving response to cache...")
    spark.sql(merge_sql)
    logger.info("Cache save complete.")

## 9. Data Ingestion (Run Once)

In [None]:
def get_document_id(file_path: str) -> str:
    """Return the hexadecimal MD5 digest of the path string *file_path*.

    Only the path text is hashed, not the file's contents.
    """
    encoded_path = file_path.encode()
    return hashlib.new("md5", encoded_path).hexdigest()

def create_optimized_delta_table(spark: SparkSession, table_name: str, schema: StructType):
    """Create Delta table *table_name* from *schema* unless it already exists.

    Column definitions are derived from each field's name and
    ``simpleString()`` type. The table is created with change data feed and
    the auto-optimize properties enabled.
    """
    if spark.catalog.tableExists(table_name):
        logger.info(f"Table '{table_name}' already exists.")
        return

    logger.info(f"Table '{table_name}' does not exist. Creating it.")
    column_defs = []
    for field in schema.fields:
        column_defs.append(f"{field.name} {field.dataType.simpleString()}")
    schema_sql = ", ".join(column_defs)
    ddl = f"""CREATE TABLE {table_name} ({schema_sql}) USING DELTA TBLPROPERTIES (delta.enableChangeDataFeed = true, delta.autoOptimize.optimizeWrite = true, delta.autoOptimize.autoCompact = true)"""
    spark.sql(ddl)

def process_documents(spark: SparkSession, folder_path: str, target_table: str):
    """Chunk every PDF/DOCX file in *folder_path* and upsert the chunks.

    Args:
        spark: Active Spark session.
        folder_path: Directory to scan (not recursive). Files whose name does
            not end in ``.pdf`` or ``.docx`` (case-insensitive) are skipped.
        target_table: Delta table receiving the chunks; created if missing.

    Each chunk row is ``(chunk_id, source, doc_hash, chunk_index,
    text_content)`` where ``doc_hash`` is the hash of the file path and
    ``chunk_id`` is ``"<doc_hash>_<chunk_index>"``. Rows are merged on
    ``chunk_id``. Returns ``None``; logs and returns early when the folder is
    missing or yields no chunks.
    """
    if not os.path.exists(folder_path):
        logger.error(f"The folder path '{folder_path}' does not exist.")
        return
    logger.info(f"Processing documents from: {folder_path}")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    all_chunks = []
    for filename in os.listdir(folder_path):
        lowered_name = filename.lower()
        if lowered_name.endswith(".pdf"):
            loader_cls = PyPDFLoader
        elif lowered_name.endswith(".docx"):
            loader_cls = Docx2txtLoader
        else:
            continue
        file_path = os.path.join(folder_path, filename)
        logger.info(f"Loading and chunking document: {filename}")
        chunks = text_splitter.split_documents(loader_cls(file_path).load())
        doc_hash = get_document_id(file_path)
        all_chunks.extend(
            (f"{doc_hash}_{i}", file_path, doc_hash, i, chunk.page_content)
            for i, chunk in enumerate(chunks)
        )
    if not all_chunks:
        logger.info("No new documents to process.")
        return
    schema = StructType([
        StructField("chunk_id", StringType(), False),
        StructField("source", StringType(), True),
        StructField("doc_hash", StringType(), True),
        StructField("chunk_index", IntegerType(), True),
        StructField("text_content", StringType(), True),
    ])
    create_optimized_delta_table(spark, target_table, schema)
    chunks_df = spark.createDataFrame(all_chunks, schema=schema)
    # The row count is already known locally; calling chunks_df.count() would
    # launch a Spark job just to produce this log line.
    logger.info(f"Generated {len(all_chunks)} chunks from the documents.")
    chunks_df.createOrReplaceTempView("_new_chunks")
    merge_sql = f"""MERGE INTO {target_table} t USING _new_chunks s ON t.chunk_id = s.chunk_id WHEN MATCHED THEN UPDATE SET t.text_content = s.text_content, t.doc_hash = s.doc_hash WHEN NOT MATCHED THEN INSERT *"""
    logger.info("Upserting chunks into Delta table with SQL MERGE...")
    spark.sql(merge_sql)
    logger.info("Upsert complete.")

# One-off ingestion run: make sure the data folder exists, then chunk its
# documents into the Delta table that feeds the vector index.
spark = SparkSession.builder.appName("DataIngestion").getOrCreate()
# exist_ok=True keeps this idempotent and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(UNSTRUCTURED_DATA_PATH, exist_ok=True)
process_documents(spark, UNSTRUCTURED_DATA_PATH, DELTA_SYNC_TABLE)

## 10. Vector Search Setup (Run Once)

In [None]:
def wait_for_endpoint_to_be_ready(
    vsc: VectorSearchClient,
    endpoint_name: str,
    max_attempts: int = 180,
    poll_interval_seconds: float = 10,
):
    """Poll a vector search endpoint until it reports the ``ONLINE`` state.

    Args:
        vsc: Client whose ``get_endpoint(name=...)`` returns a mapping with
            the state under ``endpoint_status.state``.
        endpoint_name: Name of the endpoint to poll.
        max_attempts: Maximum number of polls. The default of 180 with the
            default interval gives the original 30-minute budget.
        poll_interval_seconds: Sleep between polls while provisioning.

    Raises:
        RuntimeError: The endpoint reports a state other than ``ONLINE`` or
            ``PROVISIONING`` (a missing state is treated as ``UNKNOWN``).
        TimeoutError: The endpoint is still not online after *max_attempts*.
    """
    for _ in range(max_attempts):
        endpoint = vsc.get_endpoint(name=endpoint_name)
        status = endpoint.get("endpoint_status", {}).get("state", "UNKNOWN")
        if status == "ONLINE":
            logger.info(f"Endpoint '{endpoint_name}' is online.")
            return
        if status != "PROVISIONING":
            raise RuntimeError(f"Endpoint entered a non-recoverable state: {status}")
        logger.info(f"Endpoint is still provisioning... (Status: {status})")
        time.sleep(poll_interval_seconds)
    raise TimeoutError(f"Endpoint '{endpoint_name}' did not become ready in time.")

def setup_vector_search_index():
    """Ensure the vector search endpoint and its delta-sync index exist.

    Steps:
      1. Look up the endpoint; create it when the lookup error mentions
         ``RESOURCE_DOES_NOT_EXIST``, re-raise any other lookup error.
      2. Wait until the endpoint is online.
      3. Create the delta-sync index over ``DELTA_SYNC_TABLE``. If the error
         mentions ``RESOURCE_ALREADY_EXISTS`` the existing index is synced
         instead; any other index error is only logged.
    """
    vsc = VectorSearchClient()
    endpoint_name = VECTOR_SEARCH_ENDPOINT_NAME
    index_name = VECTOR_SEARCH_INDEX_NAME

    try:
        vsc.get_endpoint(name=endpoint_name)
    except Exception as e:
        if "RESOURCE_DOES_NOT_EXIST" not in str(e):
            raise
        logger.info(f"Endpoint '{endpoint_name}' not found. Creating...")
        vsc.create_endpoint(name=endpoint_name, endpoint_type="STANDARD")
        logger.info("Endpoint created. Waiting for it to be ready...")
    else:
        logger.info(f"Endpoint '{endpoint_name}' already exists.")

    wait_for_endpoint_to_be_ready(vsc, endpoint_name)

    index_spec = dict(
        endpoint_name=endpoint_name,
        index_name=index_name,
        source_table_name=DELTA_SYNC_TABLE,
        pipeline_type="CONTINUOUS",
        primary_key="chunk_id",
        embedding_source_column="text_content",
        embedding_model_endpoint_name="databricks-bge-large-en",
    )
    try:
        vsc.create_delta_sync_index(**index_spec)
        logger.info(f"Successfully created index '{index_name}'.")
    except Exception as e:
        if "RESOURCE_ALREADY_EXISTS" not in str(e):
            logger.error(f"An error occurred while creating/updating the index: {e}")
            return
        logger.info(f"Index '{index_name}' already exists. Attempting to sync.")
        existing_index = vsc.get_index(endpoint_name=endpoint_name, index_name=index_name)
        existing_index.sync()


# Run the one-off setup when the cell executes.
setup_vector_search_index()

## 11. Multi-Agent System

In [None]:
class AgentState(TypedDict):
    """State dictionary handed to every node of the graph."""
    # Conversation messages; annotated with operator.add as the combining function.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The raw user query for this run.
    user_query: str
    # Identifier of the current session.
    session_id: str
    # Lower-cased intent label used for routing after intent_agent.
    intent: str
    # Parsed JSON payload returned by the intent model.
    intent_details: Dict[str, Any]
    # SQL statement that was executed (set by the SQL-producing agents).
    sql_query: str
    # Text rendering of the SQL result.
    sql_result: str
    # Answer text produced by rag_agent.
    rag_result: str
    # Error description; a truthy value routes to llm_error_agent.
    error: str
    # Combined context string built by context_engineer_agent.
    engineered_context: str
    # Final answer: a response model or its serialized string form.
    final_response: Union[BaseResponse, str]
    # True when final_response was served from the query cache.
    from_cache: bool

# Chat model client; requests go to the workspace's /serving-endpoints base URL
# and authenticate with the configured token.
llm = ChatOpenAI(model=ENDPOINT_NAME, api_key=DATABRICKS_TOKEN, base_url=f"{DATABRICKS_HOST}/serving-endpoints")
# Vector search client; bound into rag_agent / mixed_intent_agent via
# functools.partial when the graph is built.
vsc = VectorSearchClient()
# Whisper "base" model, loaded once here and used by voice_summarization_agent.
whisper_model = whisper.load_model("base")

def get_table_schemas(spark: SparkSession, catalog: str, schema: str) -> str:
    """Describe every table in ``catalog.schema`` as one line of text each.

    Args:
        spark: Active Spark session.
        catalog: Catalog name (interpolated into the SQL as-is).
        schema: Schema name (interpolated into the SQL as-is).

    Returns:
        Lines of the form ``Table '<name>': col (type), col (type), ...``
        joined by newlines, or the fixed string
        ``"Could not retrieve table schemas."`` if any query fails.
    """
    logger.info(f"Fetching schemas from {catalog}.{schema}")
    try:
        schema_details = []
        for row in spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect():
            # SHOW TABLES also lists session temp views (such as the staging
            # views this notebook registers); they do not live under
            # catalog.schema, so describing them there would fail.
            if row.asDict().get("isTemporary"):
                continue
            table_name = row['tableName']
            columns = []
            for c in spark.sql(f"DESCRIBE TABLE {catalog}.{schema}.{table_name}").collect():
                col_name = c['col_name']
                # After the real columns DESCRIBE TABLE may append metadata
                # sections introduced by a blank or '#'-prefixed row
                # (e.g. "# Partition Information"); stop at the first one.
                if not col_name or col_name.startswith("#"):
                    break
                columns.append(f"{col_name} ({c['data_type']})")
            schema_details.append(f"Table '{table_name}': " + ", ".join(columns))
        return "\n".join(schema_details)
    except Exception as e:
        logger.error(f"Could not retrieve table schemas: {e}")
        return "Could not retrieve table schemas."

def cache_agent(state: AgentState) -> dict:
    """Graph entry node: serve the answer from the query cache when possible.

    Returns ``{"final_response": ..., "from_cache": True}`` on a cache hit
    and ``{"from_cache": False}`` otherwise.
    """
    session = SparkSession.builder.appName("CacheAgent").getOrCreate()
    hit = check_cache(session, state['user_query'])
    if not hit:
        return {"from_cache": False}
    return {"final_response": hit, "from_cache": True}

def router_agent(state: AgentState) -> dict:
    """Record the user's query as the first conversation message.

    Returns:
        A state update with the query wrapped in a ``HumanMessage`` under
        ``messages`` and the placeholder intent ``"clarify_intent"`` (which
        intent_agent overwrites next). The message is *returned* rather than
        assigned onto ``state``: a node's changes reach the graph only
        through its return value, so an in-place assignment is discarded.
    """
    return {
        "messages": [HumanMessage(content=state['user_query'])],
        "intent": "clarify_intent",
    }

# Lower-cased intents that have an outgoing edge from intent_agent in the graph.
_ROUTABLE_INTENTS = frozenset({
    "descriptive", "analytical", "drilldown", "visualization",
    "unstructured", "mixed", "voice",
})


def _parse_intent_payload(raw_reply: str) -> Dict[str, Any]:
    """Extract the JSON object from the model reply, ignoring fences or prose around it."""
    found = re.search(r"\{.*\}", raw_reply, flags=re.DOTALL)
    if found is None:
        raise ValueError("no JSON object found in model output")
    return json.loads(found.group(0))


def intent_agent(state: AgentState) -> dict:
    """Classify the user query and return its intent for routing.

    Returns:
        ``{"intent": <lower-cased intent>, "intent_details": <parsed JSON>}``.
        If the model call fails with a parsing-type error, the reply holds no
        valid JSON object, or the intent is not one the graph can route, the
        problem is logged and the Descriptive default is returned.
    """
    prompt = ChatPromptTemplate.from_template(INTENT_AGENT_PROMPT)
    chain = prompt | llm | StrOutputParser()
    try:
        response_json = chain.invoke({"user_query": state['user_query']})
        intent_details = _parse_intent_payload(response_json)
        logger.info(f"Detected Intent Details: {intent_details}")
        intent = str(intent_details.get("intent", "Descriptive")).lower()
        if intent not in _ROUTABLE_INTENTS:
            # An unmapped label would make the conditional edge fail.
            raise ValueError(f"unsupported intent {intent!r}")
        return {"intent": intent, "intent_details": intent_details}
    except (ValueError, KeyError, TypeError) as e:
        # json.JSONDecodeError is a ValueError subclass, so it is covered.
        logger.error(f"Failed to parse intent JSON: {e}. Defaulting to Descriptive.")
        return {"intent": "descriptive", "intent_details": {"intent": "Descriptive"}}

def _clean_sql(raw_sql: str) -> str:
    """Trim whitespace and remove a surrounding Markdown code fence, if any."""
    cleaned = raw_sql.strip()
    fenced = re.match(r"^```(?:sql)?\s+(.*?)\s*```$", cleaned, flags=re.DOTALL | re.IGNORECASE)
    return fenced.group(1).strip() if fenced else cleaned


def _run_sql_with_correction(spark, sql_query: str, user_query: str, log_label: str, error_label: str, max_attempts: int = 2) -> dict:
    """Execute *sql_query*, asking the LLM to repair it between failed attempts.

    Returns ``{"sql_query", "sql_result"}`` on success (result limited to 100
    rows and rendered as text) or ``{"error": ...}`` when the model refused
    to write SQL or every attempt failed.
    """
    # NOTE(security): the statement is LLM-generated from user input and is
    # executed as-is; restrict the executing principal to read-only access.
    sql_query = _clean_sql(sql_query)
    if "is not present in the table schema" in sql_query:
        # The SQL prompt tells the model to answer with this sentence instead
        # of SQL when a column is missing; surface it rather than execute it.
        return {"error": sql_query}
    error_message = ""
    for attempt in range(1, max_attempts + 1):
        try:
            logger.info(f"Executing {log_label} (Attempt {attempt}): {sql_query}")
            result = spark.sql(sql_query).limit(100).toPandas().to_string()
            return {"sql_query": sql_query, "sql_result": result}
        except Exception as e:
            error_message = str(e)
            logger.warning(f"{log_label} failed: {error_message}")
            if attempt == max_attempts:
                # No attempt left to run a corrected query, so skip the LLM call.
                break
            correction_chain = ChatPromptTemplate.from_template(SQL_CORRECTION_PROMPT) | llm | StrOutputParser()
            sql_query = _clean_sql(correction_chain.invoke({"user_query": user_query, "faulty_sql": sql_query, "error_message": error_message}))
    return {"error": f"Failed to execute {error_label} after correction: {error_message}"}


def text_to_sql_agent(state: AgentState) -> dict:
    """Generate SQL for the user query from the catalog schemas and run it.

    Returns ``{"sql_query", "sql_result"}`` on success, otherwise
    ``{"error": ...}``.
    """
    spark = SparkSession.builder.appName("MultiAgentSystem").getOrCreate()
    schema_context = get_table_schemas(spark, UNITY_CATALOG_NAME, UNITY_CATALOG_SCHEMA_NAME)
    chain = ChatPromptTemplate.from_template(TEXT_TO_SQL_AGENT_PROMPT) | llm | StrOutputParser()
    sql_query = chain.invoke({"catalog": UNITY_CATALOG_NAME, "schema": schema_context, "user_query": state['user_query']})
    return _run_sql_with_correction(spark, sql_query, state['user_query'], log_label="SQL Query", error_label="SQL query")


def rag_agent(state: AgentState, vsc: VectorSearchClient) -> dict:
    """Answer the user query from retrieved document chunks plus table schemas.

    Returns ``{"rag_result": <answer>}``, or ``{"error": ...}`` if retrieval
    or generation raises.
    """
    try:
        index = DatabricksVectorSearch(vsc.get_index(endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME, index_name=VECTOR_SEARCH_INDEX_NAME))
        documents = index.as_retriever().invoke(state['user_query'])
        # Pass the chunk texts, not the repr of a list of Document objects.
        document_context = "\n\n".join(doc.page_content for doc in documents)
        spark = SparkSession.builder.appName("RAGAgent").getOrCreate()
        schema_context = get_table_schemas(spark, UNITY_CATALOG_NAME, UNITY_CATALOG_SCHEMA_NAME)
        chain = ChatPromptTemplate.from_template(RAG_AGENT_PROMPT) | llm | StrOutputParser()
        result = chain.invoke({"question": state['user_query'], "document_context": document_context, "schema_context": schema_context})
        return {"rag_result": result}
    except Exception as e:
        logger.error(f"Error in RAG agent: {e}")
        return {"error": f"Error in RAG agent: {e}"}


def mixed_intent_agent(state: AgentState, vsc: VectorSearchClient) -> dict:
    """Run the RAG step, then generate and execute SQL informed by its answer.

    Returns ``{"rag_result", "sql_result", "sql_query"}`` on success,
    otherwise ``{"error": ...}`` from whichever step failed.
    """
    rag_output = rag_agent(state, vsc)
    if "error" in rag_output:
        return {"error": rag_output['error']}
    rag_context = rag_output.get("rag_result", "")
    spark = SparkSession.builder.appName("MultiAgentSystem").getOrCreate()
    schema_context = get_table_schemas(spark, UNITY_CATALOG_NAME, UNITY_CATALOG_SCHEMA_NAME)
    chain = ChatPromptTemplate.from_template(MIXED_INTENT_SQL_PROMPT) | llm | StrOutputParser()
    sql_query = chain.invoke({"rag_context": rag_context, "catalog": UNITY_CATALOG_NAME, "schema": schema_context, "user_query": state['user_query']})
    outcome = _run_sql_with_correction(spark, sql_query, state['user_query'], log_label="Mixed-Intent SQL Query", error_label="mixed-intent SQL query")
    if "error" in outcome:
        return outcome
    return {"rag_result": rag_context, **outcome}

def voice_summarization_agent(state: AgentState) -> dict:
    """Transcribe the audio file named in the query and summarize it.

    The first token in the query that ends in ``.mp3``, ``.wav``, ``.m4a`` or
    ``.flac`` (any letter case) and contains no whitespace, quotes or ``/``
    is taken as a file name inside ``UNSTRUCTURED_DATA_PATH``.

    Returns:
        On success ``{"final_response": VoiceResponse, "messages": [...]}``,
        where the message carries the full transcription. The message is
        returned as a state update; appending to ``state['messages']`` in
        place would not be recorded by the graph. On failure
        ``{"error": ...}``.
    """
    user_query = state['user_query']
    match = re.search(r"[\s\"']?([^\s\"'/]+\.(mp3|wav|m4a|flac))[\s\"']?", user_query, flags=re.IGNORECASE)
    if not match:
        return {"error": "Could not find a valid audio file name in the query."}
    audio_filename = match.group(1)
    audio_file_path = os.path.join(UNSTRUCTURED_DATA_PATH, audio_filename)
    if not os.path.exists(audio_file_path):
        return {"error": f"Audio file '{audio_filename}' not found in the configured data path."}
    try:
        logger.info(f"Transcribing audio file: {audio_file_path}")
        transcription = whisper_model.transcribe(audio_file_path)["text"]
        logger.info("Transcription complete. Generating summary...")
        chain = ChatPromptTemplate.from_template(VOICE_SUMMARY_PROMPT) | llm | StrOutputParser()
        summary = chain.invoke({"transcription": transcription})
        response_model = VoiceResponse(session_id=state['session_id'], user_query=user_query, summary=summary, transcription=transcription)
        return {
            "final_response": response_model,
            "messages": [AIMessage(content=f"Full transcription:\n{transcription}")],
        }
    except Exception as e:
        logger.error(f"Error during voice processing: {e}")
        return {"error": f"Error during voice processing: {e}"}

def context_engineer_agent(state: AgentState) -> dict:
    """Assemble the context text handed to the response agent.

    Up to three sections are included, in this order and only when the
    corresponding state value is truthy: conversation history, document
    information (``rag_result``) and database results (``sql_result``).
    Sections are separated by a ``---`` divider.

    Returns ``{"engineered_context": <text>}``, or an ``{"error": ...}``
    update when no section could be built.
    """
    sections = []

    messages = state.get("messages")
    if messages:
        history = "\n".join(f"{msg.type}: {msg.content}" for msg in messages)
        sections.append(f"Conversation History:\n{history}")

    titled_results = (
        ("rag_result", "Information from Documents"),
        ("sql_result", "Database Query Results"),
    )
    for state_key, title in titled_results:
        value = state.get(state_key)
        if value:
            sections.append(f"{title}:\n{value}")

    if not sections:
        return {"error": "No context was generated by the previous steps."}

    engineered_context = "\n\n---\n\n".join(sections)
    logger.info(f"Engineered Context:\n{engineered_context}")
    return {"engineered_context": engineered_context}

def response_agent(state: AgentState) -> dict:
    """Produce the serialized final response for the run.

    * If an upstream node already stored a ``BaseResponse`` in
      ``final_response``, it is serialized to JSON and returned.
    * If no engineered context is available, the error agent is invoked.
    * Otherwise the LLM answers from the engineered context and the answer
      is wrapped in a ``RAGResponse``.

    Returns:
        ``{"final_response": <JSON string>}``.
    """
    existing = state.get("final_response")
    if isinstance(existing, BaseResponse):
        return {"final_response": existing.model_dump_json()}
    engineered_context = state.get("engineered_context")
    if not engineered_context:
        # Hand over the whole state so session_id / user_query reach the
        # error response (a bare {"error": ...} dict lacks them), and keep
        # any more specific upstream error message.
        return llm_error_agent({**state, "error": state.get("error") or "Context engineering failed."})
    chain = ChatPromptTemplate.from_template(RESPONSE_AGENT_SIMPLE_PROMPT) | llm | StrOutputParser()
    final_answer = chain.invoke({"user_query": state['user_query'], "engineered_context": engineered_context})
    response_model = RAGResponse(session_id=state['session_id'], user_query=state['user_query'], rag_content=final_answer, sources=["Synthesized from multiple sources"])
    return {"final_response": response_model.model_dump_json()}

def create_optimized_error_log_table(spark: SparkSession):
    """Create the Delta table named by ``ERROR_LOG_TABLE`` if it is missing.

    Nothing happens (and nothing is logged) when the table already exists.
    """
    table_name = ERROR_LOG_TABLE
    if spark.catalog.tableExists(table_name):
        return
    logger.info(f"Error log table '{table_name}' not found. Creating it.")
    ddl = f"""CREATE TABLE {table_name} (timestamp TIMESTAMP, session_id STRING, user_query STRING, error_message STRING) USING DELTA TBLPROPERTIES (delta.autoOptimize.optimizeWrite = true, delta.autoOptimize.autoCompact = true)"""
    spark.sql(ddl)

def llm_error_agent(state: AgentState) -> dict:
    """Log the error in *state* and build an ``ErrorResponse`` for the user.

    The error is appended to the error-log Delta table and the LLM is asked
    for a suggested fix. Both steps are best-effort: a failure in either is
    logged and does not prevent the error response from being returned.

    Returns:
        ``{"final_response": <ErrorResponse as JSON string>}``.
    """
    error_message = str(state.get("error") or "An unknown error occurred.")
    # ErrorResponse requires strings; fall back to "" when the caller's state
    # does not carry these keys.
    session_id = state.get("session_id") or ""
    user_query = state.get("user_query") or ""

    try:
        spark = SparkSession.builder.appName("ErrorLogging").getOrCreate()
        create_optimized_error_log_table(spark)
        error_df = spark.createDataFrame(
            [(session_id, user_query, error_message)],
            ["session_id", "user_query", "error_message"],
        ).withColumn("timestamp", current_timestamp())
        error_df.write.format("delta").mode("append").saveAsTable(ERROR_LOG_TABLE)
        logger.info(f"Error successfully logged to {ERROR_LOG_TABLE}")
    except Exception as e:
        logger.critical(f"Failed to log error to Delta table: {e}")

    try:
        chain = ChatPromptTemplate.from_template(ERROR_AGENT_PROMPT) | llm | StrOutputParser()
        suggestion = chain.invoke({"error": error_message})
    except Exception as e:
        # The model may be the thing that is broken; still return the error.
        logger.error(f"Could not generate a fix suggestion: {e}")
        suggestion = None

    error_model = ErrorResponse(session_id=session_id, user_query=user_query, error_message=error_message, suggested_fix=suggestion)
    return {"final_response": error_model.model_dump_json()}

def _is_error_payload(serialized_response: str) -> bool:
    """Return True when the text is a JSON object whose response_type is "error"."""
    try:
        payload = json.loads(serialized_response)
    except (TypeError, ValueError):
        return False
    return isinstance(payload, dict) and payload.get("response_type") == "error"


def save_cache_agent(state: AgentState) -> dict:
    """Store the final response in the query cache.

    A response model is serialized to JSON first. Error responses are not
    cached, because a cached error would be replayed for every later
    identical query. Saving is best-effort: a failure is logged and the run
    continues.

    Returns:
        An empty state update.
    """
    final_response = state['final_response']
    if isinstance(final_response, BaseResponse):
        response_to_cache = final_response.model_dump_json()
    elif isinstance(final_response, str):
        response_to_cache = final_response
    else:
        response_to_cache = str(final_response)

    if _is_error_payload(response_to_cache):
        logger.info("Final response is an error; skipping cache save.")
        return {}

    try:
        spark = SparkSession.builder.appName("SaveCacheAgent").getOrCreate()
        save_to_cache(spark, state['user_query'], response_to_cache, state['session_id'])
    except Exception as e:
        logger.error(f"Failed to save response to cache: {e}")
    return {}

## 12. Build the Graph

In [None]:
# Assemble the agent graph:
#   cache_agent -> (hit: END | miss: router -> intent_agent -> worker agent)
#   worker agent -> (error: llm_error_agent | ok: context_engineer_agent -> response_agent)
#   response_agent / llm_error_agent -> save_cache -> END
workflow = StateGraph(AgentState)
rag_agent_with_vsc = partial(rag_agent, vsc=vsc)
mixed_intent_agent_with_vsc = partial(mixed_intent_agent, vsc=vsc)

for node_name, node_fn in (
    ("cache_agent", cache_agent),
    ("router", router_agent),
    ("intent_agent", intent_agent),
    ("rag_agent", rag_agent_with_vsc),
    ("text_to_sql_agent", text_to_sql_agent),
    ("mixed_intent_agent", mixed_intent_agent_with_vsc),
    ("voice_summarization_agent", voice_summarization_agent),
    ("context_engineer_agent", context_engineer_agent),
    ("response_agent", response_agent),
    ("llm_error_agent", llm_error_agent),
    ("save_cache", save_cache_agent),
):
    workflow.add_node(node_name, node_fn)

# Intent label -> worker node.
intent_routes = {
    "descriptive": "text_to_sql_agent",
    "analytical": "text_to_sql_agent",
    "drilldown": "text_to_sql_agent",
    "visualization": "text_to_sql_agent",
    "unstructured": "rag_agent",
    "mixed": "mixed_intent_agent",
    "voice": "voice_summarization_agent",
}


def _route_intent(state) -> str:
    """Return the routing key for the detected intent, defaulting to "descriptive".

    A label missing from the route map would otherwise make the conditional
    edge fail.
    """
    intent = state.get("intent")
    return intent if intent in intent_routes else "descriptive"


def _route_on_error(state) -> str:
    """Return "error" when the state carries a truthy error, else "continue"."""
    return "error" if state.get("error") else "continue"


workflow.set_entry_point("cache_agent")
workflow.add_conditional_edges("cache_agent", lambda state: "continue" if not state.get("from_cache") else "end", {"continue": "router", "end": END})
workflow.add_edge("router", "intent_agent")
workflow.add_conditional_edges("intent_agent", _route_intent, intent_routes)
for worker_node in ("rag_agent", "text_to_sql_agent", "mixed_intent_agent", "voice_summarization_agent"):
    workflow.add_conditional_edges(worker_node, _route_on_error, {"continue": "context_engineer_agent", "error": "llm_error_agent"})
workflow.add_edge("context_engineer_agent", "response_agent")
workflow.add_edge("response_agent", "save_cache")
workflow.add_edge("llm_error_agent", "save_cache")
workflow.add_edge("save_cache", END)
langgraph_app = workflow.compile()

## 13. Run the Agent

In [None]:
# Run one query through the compiled graph and print each node's update as
# it is streamed back.
user_query = "YOUR_QUERY_HERE"
session_id = str(uuid.uuid4())
inputs = {"user_query": user_query, "session_id": session_id}
run_config = {"recursion_limit": 10}

for step in langgraph_app.stream(inputs, run_config):
    for node_name, node_update in step.items():
        print(f"Output from node '{node_name}':")
        print("---")
        print(node_update)
    print("\n---\n")