# In [None]:
import os
import mysql.connector
import pandas as pd
from tabulate import tabulate
from dotenv import load_dotenv
import re
import sqlparse
from google import genai
from google.genai import types
# Pull variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Gemini API key: read from GEMINI, falling back to GEMINI_API_KEY.
GEMINI_API_KEY = os.environ.get("GEMINI") or os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError(
        "❌ Gemini API key not found. Set GEMINI (or GEMINI_API_KEY) in your "
        ".env file or environment."
    )

# MySQL connection settings. Each one can be overridden by an environment
# variable of the same name; the defaults are a local root user with an
# empty password connecting to the online_store database.
DB_NAME = os.environ.get("DB_NAME", "online_store")
DB_USER = os.environ.get("DB_USER", "root")
DB_PASS = os.environ.get("DB_PASS", "")
DB_HOST = os.environ.get("DB_HOST", "localhost")

def test_db_connection():
    """Open and immediately close a MySQL connection using the module settings.

    Failures are reported with a printed message rather than raised, so the
    surrounding script keeps running even when the database is unreachable.

    Returns:
        bool: True if the connection could be opened and closed, False if
        any exception occurred.
    """
    try:
        conn = mysql.connector.connect(
            host=DB_HOST,
            user=DB_USER,
            password=DB_PASS,
            database=DB_NAME,
        )
        conn.close()
    except Exception as e:
        # Deliberately broad: this is a diagnostic helper whose job is to
        # report any failure, not to propagate it.
        print(f"❌ MySQL connection failed:\n{e}")
        return False
    print("✅ MySQL connection successful!")
    return True

# Test the DB connection at execution time. test_db_connection only prints
# the outcome (it catches every exception), so execution continues even if
# MySQL is unreachable.
test_db_connection()

def run_sql(sql, params=None):
    """Execute *sql* on the configured MySQL database and return its rows.

    A fresh connection is opened for every call and is always closed before
    returning, whether execution succeeded or failed.

    Args:
        sql: The SQL statement to execute.
        params: Optional sequence/mapping of bind parameters for
            placeholders in *sql*. ``None`` is also the driver's own default,
            so ``run_sql(sql)`` behaves exactly like ``cursor.execute(sql)``.

    Returns:
        pandas.DataFrame with one column per result column on success, or a
        string beginning with "❌ SQL Execution Error:" if anything raised.
        A statement that produces no result set (DB-API ``description`` is
        None) therefore comes back as an error string.
    """
    conn = None
    try:
        conn = mysql.connector.connect(
            host=DB_HOST,
            user=DB_USER,
            password=DB_PASS,
            database=DB_NAME,
        )
        cursor = conn.cursor()
        cursor.execute(sql, params)
        columns = [desc[0] for desc in cursor.description]
        rows = cursor.fetchall()
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return f"❌ SQL Execution Error:\n{e}"
    finally:
        # Close on every path so a failing query does not leak a connection.
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass  # best-effort cleanup; the return value is already decided

def is_safe_select(sql):
    """Return True only if *sql* consists solely of read-only SELECT statements.

    Rules:
      * at least one non-blank statement must be present — empty input is
        rejected instead of being treated as vacuously safe;
      * every statement sqlparse finds must have type ``SELECT``;
      * no statement may touch the server filesystem via
        ``INTO OUTFILE`` / ``INTO DUMPFILE`` or ``LOAD_FILE()``, all of which
        are still typed as SELECT by sqlparse.

    Words inside comments and string literals are ignored for the filesystem
    check; an identifier with one of those names is (conservatively) rejected.
    """
    statements = [stmt for stmt in sqlparse.parse(sql) if str(stmt).strip()]
    if not statements:
        return False

    ttypes = sqlparse.tokens
    forbidden = {"OUTFILE", "DUMPFILE", "LOAD_FILE"}
    for stmt in statements:
        if stmt.get_type() != 'SELECT':
            return False
        for tok in stmt.flatten():
            if tok.ttype in ttypes.Comment or tok.ttype in ttypes.String:
                continue
            if tok.value.upper() in forbidden:
                return False
    return True

def add_limit(sql, limit=100):
    """Cap the rows returned by *sql* by appending ``LIMIT <limit>;``.

    The statement is left unchanged only when the OUTER query already ends
    in a LIMIT clause: ``LIMIT n``, ``LIMIT offset, n`` or
    ``LIMIT n OFFSET m``, optionally followed by semicolons/whitespace.
    The word "limit" inside a string literal or an identifier, or a LIMIT
    that belongs to a subquery, does not suppress the outer cap.

    Args:
        sql: A single SQL statement, with or without a trailing semicolon.
        limit: Maximum number of rows; converted with ``int()``.

    Returns:
        *sql* unchanged if it already ends in a LIMIT clause; otherwise the
        statement with all trailing whitespace/semicolons removed and
        `` LIMIT <limit>;`` appended.
    """
    if re.search(
        r"\bLIMIT\s+\d+(?:\s*,\s*\d+|\s+OFFSET\s+\d+)?[\s;]*$",
        sql,
        re.IGNORECASE,
    ):
        return sql
    # Strip every trailing whitespace/semicolon character; rstrip(';') alone
    # would leave "...; " intact and produce "...;  LIMIT 100;".
    body = re.sub(r"[\s;]+$", "", sql)
    return f"{body} LIMIT {int(limit)};"
# Instruction text prepended to every user question in get_sql_from_gemini.
# It pins the output format (one line of raw MySQL SELECT, no markdown or
# prose) and lists the tables/columns the model may reference.
# NOTE(review): this schema listing is hand-written — keep it in sync with
# the actual online_store tables when the database changes.
SYSTEM_PROMPT = """
You are a MySQL query generator for an eCommerce database.

RULES:
1. Return ONLY a valid MySQL SELECT query.
2. Do NOT include backticks, markdown, explanations, or natural language.
3. Do NOT prefix with "SQL:" or wrap with quotes or code blocks.
4. The entire output must be ONE LINE of raw SQL.

SCHEMA:
- users(user_id, name, email, password_hash, phone, created_at, updated_at)
- categories(category_id, name, description)
- products(product_id, category_id, name, description, base_price, brand, image_url, created_at)
- product_variants(variant_id, product_id, sku, color, size, additional_price)
- inventory(variant_id, quantity, last_updated)
- orders(order_id, user_id, order_date, status, total_amount, shipping_address)
- order_items(order_item_id, order_id, variant_id, quantity, price)
- payments(payment_id, order_id, payment_method, payment_status, paid_at)
- shipping(shipping_id, order_id, carrier, tracking_number, status, estimated_delivery_date)
- reviews(review_id, user_id, product_id, rating, comment, created_at)
"""

def get_sql_from_gemini(user_question, model="gemini-2.5-flash-preview-04-17"):
    """Ask Gemini to translate *user_question* into a MySQL SELECT statement.

    The question is appended to SYSTEM_PROMPT and sent as a single user turn.
    The streamed reply is concatenated and returned with surrounding
    whitespace stripped. It is the model's raw text, so callers must still
    clean and validate it before executing it.

    Args:
        user_question: Natural-language question about the store data.
        model: Gemini model name. The default is a preview model; override
            it if that model is not available to your API key.

    Returns:
        str: The model's text reply ("" if it produced no text).
    """
    client = genai.Client(api_key=GEMINI_API_KEY)
    prompt = f"{SYSTEM_PROMPT}\nUser Question: {user_question}\n"
    contents = [
        types.Content(role="user", parts=[types.Part.from_text(text=prompt)]),
    ]
    config = types.GenerateContentConfig(response_mime_type="text/plain")
    stream = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=config,
    )
    # A chunk's .text can be None (e.g. a chunk carrying no text parts);
    # treating it as "" avoids a TypeError during concatenation.
    return "".join(chunk.text or "" for chunk in stream).strip()


def _extract_sql(reply):
    """Return the single SQL statement contained in a raw model *reply*, or "".

    If the reply contains a Markdown code fence, the fenced body is used
    (its lines joined with spaces). Otherwise the first non-empty line is
    used, since the prompt asks for one line of raw SQL.

    Cleanup steps:
      * backticks are removed;
      * a leading "SQL:" / "sql" label is removed;
      * trailing whitespace and semicolons are trimmed.

    Semicolons inside the statement are kept, so a second statement stays
    visible to is_safe_select instead of being merged into the first.
    """
    fenced = re.search(r"```(?:[A-Za-z]+[ \t]*\n)?(.*?)```", reply, re.DOTALL)
    if fenced:
        lines = fenced.group(1).splitlines()
        candidate = " ".join(line.strip() for line in lines if line.strip())
    else:
        candidate = next((line for line in reply.splitlines() if line.strip()), "")
    # The prompt forbids backticks; any that remain are inline-code markup.
    candidate = candidate.replace("`", "").strip()
    # No valid SQL statement starts with the word "sql", so this is safe.
    candidate = re.sub(r"^sql\b:?\s*", "", candidate, flags=re.IGNORECASE)
    return re.sub(r"[\s;]+$", "", candidate).strip()


def chat_with_db_gemini(user_question):
    """Answer *user_question* by having Gemini write SQL and running it on MySQL.

    Pipeline:
      1. ask Gemini for a query and extract/clean it;
      2. reject anything that is not a SELECT;
      3. append a LIMIT when none is present;
      4. execute the query and print the outcome as a table.

    Returns:
        pandas.DataFrame | None: the result rows when the query returned any.
        None when the reply was rejected, the query returned no rows, or
        execution failed; the reason is printed in each case.
    """
    print(f"\n🗨️ You: {user_question}\n")

    sql_reply = get_sql_from_gemini(user_question)
    print(f"🤖 Gemini SQL:\n{sql_reply}\n")

    sql_clean = _extract_sql(sql_reply)
    final_sql = sql_clean + ";"

    # An empty reply would otherwise reach the validator as a bare ";".
    if not sql_clean or not is_safe_select(final_sql):
        print("❌ Gemini did not return a valid SELECT SQL query.")
        print(f"⚠️ Raw output: {sql_reply}")
        return None

    final_sql = add_limit(final_sql)
    print(f"📥 Running SQL:\n{final_sql}\n")
    result = run_sql(final_sql)

    if not isinstance(result, pd.DataFrame):
        # run_sql returns an error-message string when execution fails.
        print(result)
        return None
    if result.empty:
        print("⚠️ No results found.")
        return None
    print("📊 Result:\n")
    print(tabulate(result, headers='keys', tablefmt='fancy_grid', showindex=False))
    return result

# Example invocation. NOTE(review): "hi" is not a question about the data,
# so the model may not return a usable SELECT — substitute a real question
# about the store (e.g. about orders or products) to exercise the pipeline.
chat_with_db_gemini("hi")
