# 04_ai_sql_generator


In [None]:
# 04_ai_sql_generator
import json,re,uuid
from datetime import datetime
# Build the OpenAI client with the API key stored in the Databricks secret
# scope 'llm' under key 'openai_api_key'. If the package cannot be imported or
# the secret lookup fails, the reason is printed and `client` stays None, which
# the per-report loop checks before calling the LLM.
client = None
try:
    from openai import OpenAI
    client = OpenAI(
        api_key=dbutils.secrets.get(scope='llm', key='openai_api_key'),
    )
except Exception as err:
    print('OpenAI client not available or secret missing:', err)

# Path of the report definitions JSON under /dbfs.
metadata_dbfs_path = '/dbfs/FileStore/report_metadata/report_definitions.json'

# Read and parse the whole file up front; the result is iterated one report
# definition at a time further down.
with open(metadata_dbfs_path,'r') as f:
    reports = json.loads(f.read())

# User name reported by the Spark session; recorded as created_by on every
# candidate row.
current_user = (
    spark.sql('SELECT current_user()')
    .first()[0]
)

# Local import so created_at can be a timezone-aware UTC timestamp
# (datetime.utcnow() is naive and deprecated).
from datetime import timezone

# First keyword that can start a query. Note the single backslash: inside a raw
# string, a doubled backslash before 'b' matches a literal backslash followed by
# 'b', so such a pattern never matches real SQL.
_SQL_START = re.compile(r'\b(?:WITH|SELECT)\b', flags=re.IGNORECASE)

# Body of a Markdown code fence (```sql ... ```), which chat models commonly
# wrap their answer in.
_CODE_FENCE = re.compile(r'```(?:[A-Za-z]+[ \t]*\n)?(.*?)```', flags=re.DOTALL)

# Explicit schema for the candidate rows. Passing only column names makes Spark
# infer types from the data, which fails for columns whose values are all None.
# NOTE(review): the types of updated_by / updated_at / notes are assumed here;
# confirm against finance.kyc_gold.ai_sql_candidates.
_CANDIDATE_SCHEMA = (
    'id string, report_name string, prompt string, generated_sql string, '
    'status string, created_by string, created_at timestamp, '
    'updated_by string, updated_at timestamp, notes string'
)


def _fallback_sql(rpt):
    """Build template SQL from a report's dimensions, measures and filters.

    Empty or missing 'dimensions' / 'measures' are tolerated: the SELECT list
    falls back to '*' when both are empty, and GROUP BY is only emitted when
    there is at least one dimension, so the result never contains a dangling
    comma or an empty GROUP BY clause.
    """
    dims = ', '.join(rpt.get('dimensions') or [])
    meas = ', '.join(rpt.get('measures') or [])
    select_list = ', '.join(part for part in (dims, meas) if part) or '*'
    clauses = [f"SELECT {select_list} FROM finance.kyc_gold.customer_enriched"]
    if rpt.get('filters'):
        clauses.append('WHERE ' + rpt['filters'])
    if dims:
        clauses.append(f"GROUP BY {dims}")
    return ' '.join(clauses)


def _extract_sql(text):
    """Return the SQL portion of *text*, dropping surrounding chatter.

    A fenced code block is preferred when present; within the chosen text,
    everything before the first WITH/SELECT keyword is discarded. If no such
    keyword is found, the stripped input is returned unchanged.
    """
    fenced = _CODE_FENCE.search(text)
    candidates = (fenced.group(1), text) if fenced else (text,)
    for candidate in candidates:
        start = _SQL_START.search(candidate)
        if start:
            return candidate[start.start():].strip()
    return text.strip()


# For every report definition: ask the LLM for Spark SQL (when a client is
# available), fall back to template SQL otherwise, and append one PENDING
# candidate row to the Delta table.
for rpt in reports:
    report_name = rpt['report_name']
    prompt = rpt.get('natural_language') or 'Generate report'
    full_prompt = f"You are a Spark SQL generator. Use table finance.kyc_gold.customer_enriched. Request: {prompt}"
    generated = None
    if client:
        try:
            res = client.chat.completions.create(
                model='gpt-4',
                messages=[{'role': 'user', 'content': full_prompt}],
                temperature=0,
            )
            generated = res.choices[0].message.content
        except Exception as e:
            # Best effort: a failed LLM call must not stop the remaining reports.
            print('LLM call failed:', e)
    # No client, a failed call, or an empty/blank reply all use the template.
    if not generated or not generated.strip():
        generated = _fallback_sql(rpt)
    sql = _extract_sql(generated)
    cid = str(uuid.uuid4())
    row = [(
        cid, report_name, prompt, sql, 'PENDING', current_user,
        datetime.now(timezone.utc), None, None, None,
    )]
    (
        spark.createDataFrame(row, schema=_CANDIDATE_SCHEMA)
        .write.format('delta')
        .mode('append')
        .saveAsTable('finance.kyc_gold.ai_sql_candidates')
    )
    print('Wrote candidate', cid)
