In [1]:
from autogen import GroupChat, GroupChatManager, register_function
from typing import Annotated

In [2]:
import os

# The API key is read from the environment so that no secret is ever stored in
# the notebook (or in its saved outputs). Export OPENAI_API_KEY before starting
# the kernel; an empty string is used as a placeholder when it is not set.
api_key = os.environ.get("OPENAI_API_KEY", "")

# LLM endpoint configuration shared by every agent defined below.
config_list = [
    {
        "model": "gpt-4-turbo",
        "api_key": api_key,
        "base_url": "https://api.openai.com/v1"
    }
]

In [3]:
import sqlite3
from typing import Annotated

def fetch_schema(DB_PATH: Annotated[str, "Path to the Database file"]) -> Annotated[str, "Resulting Schema from the DB"]:
    """Describe every table of a SQLite database as human-readable text.

    For each table listed in ``sqlite_master`` the column names, declared
    types and up to three sample values per column are collected.

    Args:
        DB_PATH: Path to the SQLite database file.

    Returns:
        One section per table, joined by newlines, of the form::

            Table: <table>
            Columns:
              - <column> (<type>) | Samples: <v1>, <v2>, <v3>

        Columns of an empty table show ``No data`` as their sample. A database
        without tables yields an empty string.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()

        # Get all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        schema = {}
        for table_name in table_names:
            # Identifiers cannot be bound as parameters, so quote the name
            # (doubling embedded quotes) to support names containing spaces,
            # punctuation or reserved words.
            quoted_name = '"' + table_name.replace('"', '""') + '"'

            # Get column details; each row is (cid, name, type, ...)
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            columns = cursor.fetchall()

            # Get sample data (first 3 rows) from the table
            cursor.execute(f"SELECT * FROM {quoted_name} LIMIT 3;")
            sample_data = cursor.fetchall()

            schema[table_name] = [
                {
                    "name": col[1],
                    "type": col[2],
                    # The position in the PRAGMA listing matches the column
                    # position in the `SELECT *` rows.
                    "samples": (
                        [row[col_index] for row in sample_data]
                        if sample_data
                        else ["No data"]
                    ),
                }
                for col_index, col in enumerate(columns)
            ]
    finally:
        # Release the connection even when one of the queries above fails.
        conn.close()

    schema_str = "\n".join(
        "Table: {}\nColumns:\n{}".format(
            table,
            "\n".join(
                "  - {} ({}) | Samples: {}".format(
                    col["name"], col["type"], ", ".join(map(str, col["samples"]))
                )
                for col in cols
            ),
        )
        for table, cols in schema.items()
    )

    return schema_str

def validate_sql_query(sql_query: str, db_path: str) -> str:
    """
    Helper function to validate an SQL query without running it.

    The query is prefixed with ``EXPLAIN QUERY PLAN`` and executed against the
    given database, so it is checked by SQLite but the statement itself is not
    carried out.

    Args:
        sql_query (str): The SQL query to validate.
        db_path (str): Path to the database file.

    Returns:
        str: 'success' if the query is valid, otherwise 'failure: <error_message>'.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)  # Connect to the provided database
        cursor = conn.cursor()
        cursor.execute("EXPLAIN QUERY PLAN " + sql_query)  # Lightweight validation
        return "success"
    except Exception as e:
        # Deliberately broad: any problem is reported back as text so the
        # calling agent can react to it instead of crashing.
        return f"failure: {str(e)}"
    finally:
        # Close the connection on both the success and the failure path.
        if conn is not None:
            conn.close()


import sqlite3  # Or use the relevant database connector

def validate_and_execute_sql_query(sql_query: Annotated[str, "Result of SQL query from SQLAgent"], db_path: Annotated[str, "Path to the Database"]) -> Annotated[dict, "Result of running the query with keys being either success or failure"]:
    """
    Validates and executes an SQL query on the specified database.

    The query is first checked with ``EXPLAIN QUERY PLAN``; only if that
    succeeds is the query itself executed.

    Args:
        sql_query (str): The SQL query to validate and execute.
        db_path (str): Path to the database file.

    Returns:
        list: Rows (list of tuples) if the statement produces a result set
            (e.g., SELECT, including queries that start with a WITH clause).
        str: 'success' if the statement ran but produces no result set
            (e.g., INSERT, UPDATE, DELETE); the change is committed.
        str: 'failure: <error_message>' if validation or execution fails.

    NOTE(review): the return annotation still says ``dict`` although the
    function returns a list or a string; it is left untouched so the signature
    seen by the tool registration does not change - confirm and align.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Validate query using EXPLAIN QUERY PLAN
        cursor.execute("EXPLAIN QUERY PLAN " + sql_query)

        # If validation passes, execute the actual query
        cursor.execute(sql_query)

        # `description` is set only when the statement yields a result set.
        # This is more reliable than checking for a leading "select", which
        # misses e.g. "WITH ... SELECT" queries.
        if cursor.description is not None:
            results = cursor.fetchall()
            conn.commit()  # Nothing to commit for a plain read; persists writes that also return rows
            return results  # Returns list of tuples

        conn.commit()  # Commit changes for INSERT, UPDATE, DELETE
        return "success"

    except Exception as e:
        # Deliberately broad: the error text is handed back to the agents so
        # that the query can be corrected and retried.
        return f"failure: {str(e)}"
    finally:
        # Always release the connection, including on the failure path.
        if conn is not None:
            conn.close()

In [4]:
import autogen

In [5]:
# Question the agents are asked to answer from the database.
# Alternative used to exercise the "no relevant tables" path:
#   "what was the result of india's last cricket match"
user_query = "which is the least preferred product among customers ?"

# SQLite database file the agents inspect and query.
DB_PATH = "blinkit_sales.db"

In [6]:
def _admin_should_stop(msg):
    """Return True when a message contains "thank you" but is not just "thank you."."""
    content = msg.get("content")
    if content is None:
        return False
    if "thank you" not in content.lower():
        return False
    # A bare "thank you." on its own does not end the chat.
    return content.strip().lower() != "thank you."


# Entry-point agent: never asks a human for input and executes no code.
user_proxy = autogen.ConversableAgent(
    name="Admin",
    system_message=f"Analyze the user query :{user_query}, and pass it to SchemaAgent for further process.",
    llm_config={"config_list": config_list},
    human_input_mode="NEVER",
    code_execution_config=False,
    is_termination_msg=_admin_should_stop,
)

    # lambda msg: isinstance(msg, str) and msg.strip().lower() == "terminate"
# lambda msg: msg.get("content") is not None and "thank you" in msg["content"]
# lambda msg: msg.get("content") is not None and "thank you" in msg["content"].lower()


# Agent that inspects the database schema and picks the tables relevant to the
# user query. Each prompt fragment ends with a space so the implicitly
# concatenated sentences do not run into each other.
schema_agent = autogen.ConversableAgent(
    name="SchemaAgent",
    llm_config={"config_list": config_list},
    system_message=(
        f"As a database expert, your task is to inspect the database and generate a structured schema for the Database File : {DB_PATH}, by calling fetch_schema({DB_PATH}) function. "
        "The schema should include table names, columns, and their data types. "
        f"Once the structured schema is ready, analyze the user query: {user_query} and determine which tables would be most relevant in answering it. "
        "If no relevant tables can be identified, alert the user with the message: 'The query cannot be answered as no relevant tables were found in the database.' Provide a reason for this and terminate the chat by saying thank you. "
        "Once you have completed this analysis, terminate the process by saying 'terminate' and return the gathered information to SQLAgent."
    )
)


# Agent that turns the schema plus the user question into a SQL query. Each
# prompt fragment ends with a space so the implicitly concatenated sentences
# do not run into each other.
sql_agent = autogen.ConversableAgent(
    name="SQLAgent",
    llm_config={"config_list": config_list},
    system_message=(
        "As a SQL expert, your task is to generate an accurate SQL query based on the provided database schema and user query. "
        "You will receive the structured schema and the user query as input from SchemaAgent. "
        "Analyze the schema to determine the relevant tables and construct an optimized SQL query that accurately answers the user's request. "
        "Once you have generated the SQL query, terminate the process by saying 'terminate' and return the final SQL query. "
        "If no relevant tables can be identified, alert the user with the message: 'The query cannot be answered as no relevant tables were found in the database.' Provide a reason for this and terminate the chat by saying thank you"
    )
)

# Agent that checks the generated SQL against the database. Each prompt
# fragment ends with a space so the implicitly concatenated sentences do not
# run into each other.
validation_agent = autogen.ConversableAgent(
    name="ValidationAgent",
    llm_config={"config_list": config_list},
    system_message=(
        "As a SQL validation expert, your task is to validate the SQL query generated by SQLAgent. "
        f"You will receive an SQL query and the database path : {DB_PATH} as input. "
        f"Use the provided helper function `validate_and_execute_sql_query(sql_query, {DB_PATH})` to check whether the query is correct. "
        "If the query is valid, respond with 'SQL query is valid' and return the result to ResultFormattingAgent. "
        "If the query is invalid, respond with 'SQL query is invalid' along with the error message to SQLErrorCorrectionAgent. "
        "Once validation is complete, terminate the process by saying 'terminate'."
    )
)

# Prompt for the agent that repairs queries rejected during validation. The
# sentences are joined with single spaces.
_error_correction_prompt = " ".join([
    "You are an expert in SQL query error correction.",
    "Your task is to correct SQL query errors based on the given failure message and database schema.",
    "You will receive a failed SQL query along with an error message explaining why it failed.",
    "Ensure the corrected query follows proper SQL syntax and is aligned with the database schema.",
    "Once you have corrected the query, return the fixed version to the ValidationAgent for validation.",
])

sql_error_correction_agent = autogen.ConversableAgent(
    name="SQLErrorCorrectionAgent",
    system_message=_error_correction_prompt,
    llm_config={"config_list": config_list},
)

# Agent that renders raw query results for the user and closes the
# conversation. Each prompt fragment ends with a space so the implicitly
# concatenated sentences do not run into each other.
result_formatting_agent = autogen.ConversableAgent(
    name="ResultFormattingAgent",
    llm_config={"config_list": config_list},
    system_message=(
        "You are responsible for formatting database query results into a structured, human-readable format. "
        "When provided with raw query results, process them into a clear tabular format, numbered list, or summary, depending on the result type. "
        "For large datasets, summarize key trends instead of listing every record. "
        "Return the formatted output in a way that is easy to understand. "
        "Once you have the final result, output the result and finally close the conversation by saying thank you"
    )
)

In [7]:
from autogen import register_function

# Tool registrations as (function, caller, executor, tool name, description).
# The caller is the agent that may suggest the tool call; the executor is the
# agent that runs it.
# NOTE(review): "Validationg" looks like a typo in the descriptions; it is
# kept as-is here because the text is passed to register_function at runtime.
_TOOL_REGISTRATIONS = (
    (fetch_schema, schema_agent, schema_agent, "FetchingSchemaSchemaAgent", "Get Fetch Schema"),
    (fetch_schema, user_proxy, user_proxy, "FetchingSchemaUserProxy", "Get Fetch Schema"),
    (validate_and_execute_sql_query, sql_agent, validation_agent, "ValidationFromSQLAgent", "Validationg the query"),
    (validate_and_execute_sql_query, validation_agent, validation_agent, "ValidationToAgent", "Validationg the query"),
)

for _tool_fn, _tool_caller, _tool_executor, _tool_name, _tool_description in _TOOL_REGISTRATIONS:
    register_function(
        _tool_fn,
        caller=_tool_caller,
        executor=_tool_executor,
        name=_tool_name,
        description=_tool_description,
    )


In [8]:
# Which agents may speak directly after which agent (used with the "allowed"
# transition type below).
_allowed_next_speakers = {
    user_proxy: [schema_agent, sql_agent, validation_agent, sql_error_correction_agent, result_formatting_agent],
    schema_agent: [user_proxy, sql_agent],
    sql_agent: [user_proxy, validation_agent],
    validation_agent: [sql_error_correction_agent, result_formatting_agent],
    sql_error_correction_agent: [user_proxy, validation_agent],
    result_formatting_agent: [user_proxy],
}

_pipeline_agents = [
    user_proxy,
    schema_agent,
    sql_agent,
    validation_agent,
    sql_error_correction_agent,
    result_formatting_agent,
]

# Group chat over all six agents, starting with an empty history and capped at
# five rounds.
groupchat = autogen.GroupChat(
    agents=_pipeline_agents,
    messages=[],
    max_round=5,
    allowed_or_disallowed_speaker_transitions=_allowed_next_speakers,
    speaker_transitions_type="allowed",
)


# groupchat = autogen.GroupChat(
#     agents=[user_proxy, schema_agent, sql_agent, validation_agent, sql_error_correction_agent, result_formatting_agent],
#     messages=[],
#     max_round=5
# )


def _formatter_has_answered(msg):
    """Return True for a dict message from ResultFormattingAgent with non-empty content."""
    if not isinstance(msg, dict):
        return False
    if msg.get("name") != "ResultFormattingAgent":
        return False
    # Ensure content is not empty
    return bool(msg.get("content"))


# Manager for the group chat; uses the predicate above as its termination check.
manager = autogen.GroupChatManager(
    groupchat=groupchat,
    llm_config={"config_list": config_list},
    is_termination_msg=_formatter_has_answered,
)

In [23]:
# Start the conversation: Admin sends the user question to the group chat
# manager, limited to three turns.
_opening_message = f"Analyze the user query : {user_query}"
chat_result = user_proxy.initiate_chat(manager, message=_opening_message, max_turns=3)


[33mAdmin[0m (to chat_manager):

Analyze the user query : what was the result of india's last cricket match

--------------------------------------------------------------------------------
[32m
Next speaker: SchemaAgent
[0m
[31m
>>>>>>>> USING AUTO REPLY...[0m
[33mSchemaAgent[0m (to chat_manager):

[32m***** Suggested tool call (call_YbMcFhG41cXw6MCAgEe71o9z): FetchingSchemaSchemaAgent *****[0m
Arguments: 
{"DB_PATH":"blinkit_sales.db"}
[32m******************************************************************************************[0m

--------------------------------------------------------------------------------
[32m
Next speaker: SchemaAgent
[0m
[31m
>>>>>>>> USING AUTO REPLY...[0m
[35m
>>>>>>>> EXECUTING FUNCTION FetchingSchemaSchemaAgent...
Call ID: call_YbMcFhG41cXw6MCAgEe71o9z
Input arguments: {'DB_PATH': 'blinkit_sales.db'}[0m
[33mSchemaAgent[0m (to chat_manager):

[32m***** Response from calling tool (call_YbMcFhG41cXw6MCAgEe71o9z) *****[0m
Table: blinkit

In [9]:
# Start the conversation: Admin sends the user question to the group chat
# manager, limited to three turns.
_opening_message = f"Analyze the user query : {user_query}"
chat_result = user_proxy.initiate_chat(manager, message=_opening_message, max_turns=3)

[33mAdmin[0m (to chat_manager):

Analyze the user query : which is the least preferred product among customers ?

--------------------------------------------------------------------------------
[32m
Next speaker: SchemaAgent
[0m
[31m
>>>>>>>> USING AUTO REPLY...[0m
[33mSchemaAgent[0m (to chat_manager):

[32m***** Suggested tool call (call_ZFoWfciXhvRDctH1zsPQUP8d): FetchingSchemaSchemaAgent *****[0m
Arguments: 
{"DB_PATH":"blinkit_sales.db"}
[32m******************************************************************************************[0m

--------------------------------------------------------------------------------
[32m
Next speaker: SchemaAgent
[0m
[31m
>>>>>>>> USING AUTO REPLY...[0m
[35m
>>>>>>>> EXECUTING FUNCTION FetchingSchemaSchemaAgent...
Call ID: call_ZFoWfciXhvRDctH1zsPQUP8d
Input arguments: {'DB_PATH': 'blinkit_sales.db'}[0m
[33mSchemaAgent[0m (to chat_manager):

[32m***** Response from calling tool (call_ZFoWfciXhvRDctH1zsPQUP8d) *****[0m
Table: bl