In [0]:
import dbruntime.databricks_repl_context
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import serving
import re

# 1. Define the System Prompt
# Instruction text sent to the model ahead of every user question
# (ask_data_agent builds its prompt as SYSTEM_PROMPT + "Question: ...").
# It names the one table the model may query, lists that table's columns,
# and asks for a bare Spark SQL statement as the only output.
# NOTE: the text contains single-quote characters ("user's",
# 'I cannot answer that.'), so it must not be spliced unescaped into a
# single-quoted SQL string literal.
SYSTEM_PROMPT = """
You are a SQL expert for an e-commerce company. 
You have access to a table: ecomm_data_project.gold.ecom_one_big_table.
The columns are: User_ID, Country, User_Gender, User_ProductsSold, User_ProductsWished, User_AccountAge, Country_TotalSellers, Country_TotalBuyers, Country_BuyerRatio.

Given a user's question, write a valid Spark SQL query. 
Return ONLY the SQL code, nothing else.

Rules:
1. Return ONLY the raw SQL query. No markdown formatting (no ```sql).
2. Use standard Spark SQL syntax.
3. If you cannot answer based on the schema, say 'I cannot answer that.'

"""

# 2. SQL Guardrail Function (MNC Security Practice)
# Statement keywords that must never appear in generated SQL. MERGE is listed
# because Spark's SQL grammar allows a CTE prefix in front of DML
# (WITH ... MERGE INTO ...), which would otherwise satisfy the SELECT/WITH
# prefix check in is_safe_query.
_FORBIDDEN_KEYWORDS = (
    "DROP",
    "DELETE",
    "TRUNCATE",
    "INSERT",
    "UPDATE",
    "GRANT",
    "REVOKE",
    "MERGE",
)


def is_safe_query(sql_query):
    """Decide whether a generated SQL string is a read-only query.

    The check is a conservative keyword filter, not a SQL parser: any
    whole-word occurrence of a forbidden keyword rejects the query, even
    inside a string literal or comment.

    Args:
        sql_query: The SQL text to validate.

    Returns:
        A ``(is_safe, message)`` tuple. ``message`` is ``"Safe"`` on success
        and a human-readable rejection reason otherwise.
    """
    # Fail closed on None / non-string input instead of letting re.search
    # raise a TypeError out of the guardrail.
    if not isinstance(sql_query, str):
        return False, "Query must be a string."

    # Prevent DDL/DML injection (Delete, Drop, etc.). \b keeps identifiers
    # such as "updated_at" from matching the UPDATE keyword.
    for word in _FORBIDDEN_KEYWORDS:
        if re.search(rf"\b{word}\b", sql_query, re.IGNORECASE):
            return False, f"Safety violation: {word} is not allowed."

    # Only plain queries (optionally introduced by a CTE) are allowed.
    if not sql_query.strip().upper().startswith(("SELECT", "WITH")):
        return False, "Query must start with SELECT or WITH."
    return True, "Safe"

# 3. The Core Agent Logic
def ask_data_agent(user_question):
    """Turn a natural-language question into SQL, validate it, and run it.

    The question is appended to SYSTEM_PROMPT and sent to the model through
    the ``ai_query`` SQL function. The model's reply is checked with
    ``is_safe_query`` before it is executed.

    Args:
        user_question: The question to answer, as free text.

    Returns:
        The result of ``spark.sql`` for the generated query.

    Raises:
        ValueError: If the model returns no text, or the generated SQL is
            rejected by ``is_safe_query`` (the message is the rejection
            reason).
    """
    prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {user_question}"

    # Bind the prompt as a named parameter rather than splicing it into a
    # quoted SQL literal: SYSTEM_PROMPT itself contains single quotes, which
    # would terminate the literal early, and the user's text could otherwise
    # inject arbitrary SQL into this statement.
    sql_gen = "SELECT ai_query('databricks-meta-llama-3-1-70b-instruct', :prompt)"
    generated_sql = spark.sql(sql_gen, args={"prompt": prompt}).collect()[0][0]

    if not isinstance(generated_sql, str):
        raise ValueError("Model returned no SQL text.")

    # Models sometimes wrap the answer in a markdown code fence despite the
    # prompt's instructions; drop the fence so validation sees the bare SQL.
    generated_sql = re.sub(
        r"^```(?:sql)?\s*|\s*```$",
        "",
        generated_sql.strip(),
        flags=re.IGNORECASE,
    )

    # Validation & Execution
    safe, message = is_safe_query(generated_sql)
    if not safe:
        raise ValueError(message)

    print(f"Executing: {generated_sql}")
    return spark.sql(generated_sql)

# Example usage:
# df = ask_data_agent("Which top 3 countries have the highest buyer ratio?")
# display(df)