In [None]:
import os
import json
import pymysql
import time
from dotenv import load_dotenv
from openai import AzureOpenAI

# Populate the process environment from a .env file before reading settings.
load_dotenv()

# Azure OpenAI connection settings; each name is None when its variable is unset.
api_key, api_version, azure_endpoint, model_deployment_name = (
    os.getenv(var_name)
    for var_name in ("api_key", "api_version", "azure_endpoint", "model_deployment_name")
)

# MySQL connection settings; each name is None when its variable is unset.
mysql_host, mysql_user, mysql_password, mysql_database = (
    os.getenv(var_name)
    for var_name in ("MYSQL_HOST", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE")
)

# Function to get LLM response from Azure OpenAI
def get_llm_response(message_prompt):
    """Send a chat-completion request to Azure OpenAI and return the reply text.

    Args:
        message_prompt: List of chat message dicts (``{"role": ..., "content": ...}``)
            passed unchanged as ``messages`` to ``chat.completions.create``.

    Returns:
        str: Content of the first choice's message, or "" when the model
        returned no text content.
    """
    # Use the client as a context manager so its HTTP resources are released
    # after each call instead of lingering until garbage collection.
    with AzureOpenAI(
        api_key=api_key,
        api_version=api_version,
        azure_endpoint=azure_endpoint,
    ) as client:
        response = client.chat.completions.create(
            model=model_deployment_name,
            messages=message_prompt,
        )

    # Read the typed response directly (no JSON round-trip). message.content is
    # Optional in the SDK, so normalize None to "" for callers that .strip() it.
    return response.choices[0].message.content or ""

# Function to get a connection to the MySQL database
def get_db_connection():
    """Open a new PyMySQL connection from the module-level MySQL settings.

    Returns:
        A ``pymysql`` connection on success, or ``None`` if connecting raised
        a ``pymysql.MySQLError`` (the error is printed).
    """
    connect_kwargs = {
        "host": mysql_host,
        "user": mysql_user,
        "password": mysql_password,
        "database": mysql_database,
    }
    try:
        return pymysql.connect(**connect_kwargs)
    except pymysql.MySQLError as err:
        print(f"Error connecting to MySQL: {err}")
    return None

# Function to retrieve schema information from MySQL database
def get_schema_info(db_connection):
    """Collect table and column names for the connection's current database.

    Args:
        db_connection: An open PyMySQL connection.

    Returns:
        dict: Maps each table name (from ``SHOW TABLES``) to the list of its
        column names (the first field of each ``DESCRIBE`` row).
    """
    schema_info = {}
    # The context manager closes the cursor even if one of the statements raises.
    with db_connection.cursor() as cursor:
        cursor.execute("SHOW TABLES;")
        # Materialize the names first: running DESCRIBE on the same cursor
        # replaces the SHOW TABLES result set.
        table_names = [row[0] for row in cursor.fetchall()]

        for table_name in table_names:
            # Backtick-quote the identifier (doubling any embedded backticks) so
            # names with spaces, dashes or reserved words don't break DESCRIBE.
            quoted_name = "`" + table_name.replace("`", "``") + "`"
            cursor.execute(f"DESCRIBE {quoted_name};")
            schema_info[table_name] = [column[0] for column in cursor.fetchall()]

    return schema_info

# Function to generate SQL query using the LLM based on schema and user query
def generate_sql_query(schema_info, user_query):
    """Ask the LLM to translate a natural-language question into a MySQL query.

    Args:
        schema_info: Mapping of table name -> list of column names, in the shape
            returned by ``get_schema_info``.
        user_query: The user's natural-language question.

    Returns:
        str: The model's reply with surrounding whitespace stripped, or "" if the
        reply was empty/None. The reply may still contain markdown fences, so
        sanitize it before execution.
    """
    # Build the prompt line by line so no source-code indentation leaks into it.
    # Naming the dialect matters: the query is executed against MySQL.
    prompt_lines = [
        "You are a SQL expert. Based on the following MySQL database schema, "
        "generate a valid MySQL query to answer the user's question.",
        "Please do not include any extra text, just the SQL query.",
        "",
        "Database Schema:",
    ]
    # One line per table, e.g. "orders (id, customer_id, total)".
    prompt_lines.extend(
        f"{table} ({', '.join(columns)})" for table, columns in schema_info.items()
    )
    prompt_lines += ["", f"User Query: {user_query}", "", "SQL Query:"]
    prompt = "\n".join(prompt_lines)

    message_prompt = [
        {"role": "system", "content": "You are a SQL expert."},
        {"role": "user", "content": prompt},
    ]

    sql_query = get_llm_response(message_prompt)
    # Guard against a None reply so .strip() cannot raise AttributeError.
    return (sql_query or "").strip()

# Function to sanitize the SQL query
def sanitize_sql_query(query):
    """Extract the executable SQL statement from a raw LLM reply.

    If the reply contains a fenced markdown code block (any or no language tag,
    e.g. ```sql, ```SQL, ```mysql), only the contents of the first such block are
    kept, so explanatory prose around the fence is dropped. Otherwise the whole
    reply is used. Any remaining fence markers are removed and whitespace is
    stripped.

    Args:
        query: Raw text returned by the LLM (``None`` is treated as empty).

    Returns:
        str: The SQL text, or "" if nothing usable remains or the reply looks
        like a clarification request rather than a query.
    """
    import re

    if not query:
        return ""

    # Replies asking the user for clarification contain no runnable SQL.
    if "To better assist you" in query or "Could you please specify" in query:
        return ""

    # Require a newline after the opening fence's tag, so a one-line
    # "```SELECT 1```" is not mistaken for an empty fenced block.
    fenced = re.search(r"```[^\n`]*\n(.*?)```", query, re.DOTALL)
    sanitized_query = fenced.group(1) if fenced else query

    # Remove any stray markers (e.g. an unterminated or inline fence).
    return sanitized_query.replace("```sql", "").replace("```", "").strip()

# Function to run the SQL query on the database and get results
def run_sql_query(db_connection, query):
    """Sanitize an LLM-generated SQL string and execute it.

    Args:
        db_connection: An open PyMySQL connection.
        query: Raw text returned by the LLM (may include markdown fences).

    Returns:
        The rows from ``cursor.fetchall()`` on success; the string
        "Error: The generated SQL query is invalid or unclear." if sanitizing
        leaves nothing to run; ``None`` if MySQL raised an error (it is printed).
    """
    # Sanitize the SQL query before execution
    query = sanitize_sql_query(query)
    if not query:
        return "Error: The generated SQL query is invalid or unclear."

    # NOTE(review): this statement comes from the LLM and is executed as-is;
    # nothing here restricts it to read-only SQL. Consider connecting with a
    # read-only MySQL account.
    try:
        # The context manager closes the cursor even when execute() raises.
        with db_connection.cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall()
    except pymysql.MySQLError as e:
        print(f"Error executing query: {e}")
        return None

# Main function to process user input, generate SQL query, execute, and return results
def process_user_query(user_query):
    """Answer one natural-language question end to end.

    Opens a MySQL connection, reads the schema, asks the LLM for SQL, runs it,
    prints generation/execution timings, and always closes the connection.

    Args:
        user_query: The user's natural-language question.

    Returns:
        The rows from the executed query, or a human-readable message string
        (connection failure, execution failure, invalid query, or no results).
    """
    db_connection = get_db_connection()
    if db_connection is None:
        return "Error: Unable to connect to the MySQL database."

    try:
        schema_info = get_schema_info(db_connection)

        # perf_counter is monotonic and high-resolution, suited to timing intervals.
        start_time_gen = time.perf_counter()
        sql_query = generate_sql_query(schema_info, user_query)
        print(f"Time taken to generate SQL query: {time.perf_counter() - start_time_gen:.4f} seconds")

        start_time_exec = time.perf_counter()
        result = run_sql_query(db_connection, sql_query)
        print(f"Time taken to execute SQL query: {time.perf_counter() - start_time_exec:.4f} seconds")
    finally:
        # A fresh connection is opened per question; close it so repeated
        # questions in the interactive loop don't accumulate open connections.
        db_connection.close()

    if result is None:
        # run_sql_query returns None only when MySQL rejected the statement,
        # which is a failure, not an empty result.
        return "Error: The generated SQL query failed to execute."
    if not result:
        return "No results found."
    return result

# User interaction
def main():
    """Run the interactive prompt loop until the user types 'exit' or sends EOF."""
    while True:
        try:
            user_query = input("Please enter your query (or type 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
            print()
            break

        if user_query.lower() == 'exit':
            break
        if not user_query:
            # Don't spend an LLM call on an empty question.
            continue

        # Process the query and return results
        result = process_user_query(user_query)

        # PyMySQL's fetchall() returns a tuple of rows (e.g. ((42,),)), so accept
        # tuples as well as lists; message strings fall through to the else branch.
        if isinstance(result, (list, tuple)):
            print("Query results:")
            for row in result:
                print(row)
        else:
            print(result)


if __name__ == "__main__":
    main()


Time taken to generate SQL query: 2.3179 seconds
Time taken to execute SQL query: 0.0017 seconds
((42,),)
