# 02_ai_report_generator
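Hybrid report-SQL generator: asks an LLM for Spark SQL (falling back to SQL built from structured report metadata), validates the result, creates a Unity Catalog (UC) view, and tags the view with lineage properties.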

In [None]:
import json, re
from datetime import datetime
from openai import OpenAI

# Unity Catalog coordinates: report views are created under catalog.schema and
# every report reads from base_table.
catalog, schema = "finance", "kyc_ml"
base_table = ".".join((catalog, schema, "customer_enriched"))
# JSON report definitions, read with the local-file API (note the /dbfs prefix).
metadata_dbfs_path = "/dbfs/FileStore/report_metadata/report_definitions.json"

# The LLM client is optional. Any failure while reading the secret or building
# the client leaves client as None, and SQL is then generated from metadata only.
client = None
try:
    client = OpenAI(api_key=dbutils.secrets.get(scope="llm", key="openai_api_key"))
except Exception as init_error:
    print("LLM client not initialized; falling back to structured generation. Error:", init_error)

def extract_sql(text):
    """Pull the SQL query out of a (possibly chatty) LLM response.

    Steps, in order:
      1. If the text contains a Markdown code fence (```sql ... ``` or
         ``` ... ```), only the body of the first fence is considered.
      2. The query starts at the first line beginning with WITH or SELECT; if no
         line does, at the first standalone WITH/SELECT keyword anywhere.
      3. Trailing whitespace, semicolons and a stray closing fence are removed so
         the result can be wrapped in a subquery or used in CREATE VIEW.

    Args:
        text: Raw model output, or hand-built SQL.

    Returns:
        The extracted SQL. When no keyword is found, the (fence-stripped) text
        itself, trimmed the same way.
    """
    fence = re.search(r'```[^\n]*\n(.*?)```', text, flags=re.DOTALL)
    body = fence.group(1) if fence else text
    # Prefer a keyword at the start of a line: prose such as "Here is a query
    # with grouping:" contains the word "with", which must not be taken for a CTE.
    m = (re.search(r'^[ \t]*(?:WITH|SELECT)\b', body, flags=re.IGNORECASE | re.MULTILINE)
         or re.search(r'\b(?:WITH|SELECT)\b', body, flags=re.IGNORECASE))
    sql = body[m.start():] if m else body
    # A trailing ';' breaks "SELECT * FROM (<sql>) tmp"; an unterminated fence leaves ```.
    return re.sub(r'(?:\s|;|```)+$', '', sql).strip()

def add_aliases_for_aggs(sql_text):
    """Add ``AS <alias>`` to bare avg/sum/min/max calls that form a whole select-list item.

    Example: ``SELECT type, sum(amount) FROM t`` becomes
    ``SELECT type, sum(amount) AS sum_amount_ FROM t``. The alias is the call
    text with every non-alphanumeric character replaced by ``_``, lower-cased.

    Only a call that sits between ``SELECT``/``,`` and ``,``/``FROM`` is
    rewritten, so the following are left alone (adding ``AS`` there would be a
    syntax error):

    * aggregates in HAVING / ORDER BY;
    * window calls (``sum(x) OVER (...)``);
    * calls nested in other expressions (``round(avg(x), 2)``);
    * items that already carry an alias, with or without ``AS``.

    Calls whose argument contains parentheses are not rewritten either.

    NOTE: regex heuristic, not a SQL parser. An aggregate passed as a middle
    argument of another function, e.g. ``f(a, sum(b), c)``, can still be
    mis-aliased.

    Args:
        sql_text: SQL query text.

    Returns:
        The query with aliases added where applicable.
    """
    item_agg = re.compile(
        r'(\bSELECT\s+(?:DISTINCT\s+)?|,\s*)'   # start of a select-list item
        r'((?:avg|sum|min|max)\([^()]+\))'     # bare aggregate call
        r'(?=\s*(?:,|FROM\b))',                # ...that also ends the item
        flags=re.IGNORECASE,
    )

    def _with_alias(m):
        prefix, call = m.group(1), m.group(2)
        alias = re.sub(r'[^0-9a-zA-Z]', '_', call).lower()
        return f"{prefix}{call} AS {alias}"

    return item_agg.sub(_with_alias, sql_text)

def validate_sql_and_preview(sql_text, max_rows=5, allowed_prefixes=None):
    """Run static checks on a generated query, plan it with Spark, and preview a few rows.

    Args:
        sql_text: Text containing the query. Everything before the first
            WITH/SELECT keyword is discarded.
        max_rows: Number of rows displayed in the preview.
        allowed_prefixes: Optional list of ``catalog.schema`` (or ``catalog``)
            prefixes. Every three-part ``a.b.c`` name in the query must lie inside
            one of them. One- and two-part names (CTEs, tables relative to the
            current catalog) are not checked.

    Returns:
        ``(clean_sql, errors)``:

        * ``clean_sql`` is the query from the first WITH/SELECT keyword with
          trailing semicolons removed, or ``None`` when no keyword is found.
        * ``errors`` is a list of messages; an empty list means every check
          passed.
    """
    errors = []
    m = re.search(r'(WITH\b|SELECT\b)', sql_text, flags=re.IGNORECASE)
    if not m:
        errors.append('No SELECT or WITH found in generated text.')
        return None, errors
    # A trailing ';' would break the "SELECT * FROM (<sql>) tmp_preview" wrapper
    # below, and CREATE VIEW does not need it.
    clean_sql = re.sub(r'[\s;]+$', '', sql_text[m.start():]).strip()

    if allowed_prefixes:
        # Compare whole dot-separated segments, case-insensitively: "finance.kyc_ml"
        # must not admit "finance.kyc_ml_other.t", while "FINANCE.KYC_ML.t" is fine.
        prefixes = [p.strip().rstrip('.').lower() for p in allowed_prefixes]
        # Drop backtick quoting so `cat`.`sch`.`tbl` cannot slip past the regex.
        unquoted = clean_sql.replace('`', '')
        # NOTE: three-part struct-field references (alias.col.field) are also
        # matched here and may be rejected; that errs on the safe side.
        tables = re.findall(r'\b[A-Za-z0-9_]+\.[A-Za-z0-9_]+\.[A-Za-z0-9_]+\b', unquoted)
        for t in tables:
            name = t.lower()
            if not any(name == p or name.startswith(p + '.') for p in prefixes):
                errors.append(f'Table {t} outside allowed prefixes {allowed_prefixes}')

    # Flag bare aggregates that form a whole select-list item (between SELECT/','
    # and ','/FROM) with no alias. Aggregates in HAVING/ORDER BY, or aliased as
    # part of a larger expression, are legitimate and are not flagged.
    if re.search(r'(?:\bSELECT\s+(?:DISTINCT\s+)?|,\s*)(?:avg|sum|min|max)\([^()]+\)(?=\s*(?:,|FROM\b))',
                 clean_sql, flags=re.IGNORECASE):
        errors.append('Aggregations without AS alias found.')

    try:
        plan_rows = spark.sql(f'EXPLAIN {clean_sql}').collect()
        plan_text = '\n'.join(str(row[0]) for row in plan_rows)
        print(plan_text)
        # Spark's EXPLAIN command reports analysis failures (unknown table or
        # column) as plan text instead of raising; only parse errors raise here.
        if plan_text.startswith('Error occurred during query planning'):
            errors.append(f'EXPLAIN failed: {plan_text}')
    except Exception as e:
        errors.append(f'EXPLAIN failed: {e}')
    try:
        display(spark.sql(f'SELECT * FROM ({clean_sql}) tmp_preview LIMIT {max_rows}'))
    except Exception as e:
        errors.append(f'Preview failed: {e}')
    return clean_sql, errors

# For each report definition: ask the LLM for SQL (when a client exists), fall
# back to SQL built from dimensions/measures metadata, validate it, then create
# a view catalog.schema.<report_name> tagged with lineage properties.
from datetime import timezone  # needed for an explicit-UTC lineage timestamp

# Column description given to the LLM; it describes base_table and is the same for every report.
schema_info = 'Columns: step (int), type (string), amount (double), isFraud (int), date (date), accountType (string)'
allowed_prefixes = [f"{catalog}.{schema}"]


def _structured_prompt(rpt):
    """Describe a metadata-only report (dimensions/measures/filters) in words for the LLM."""
    dims = ', '.join(rpt.get('dimensions', []))
    meas = ', '.join(rpt.get('measures', []))
    filters = rpt.get('filters', '')
    return f"Return {meas} grouped by {dims} from table {base_table} {('where ' + filters) if filters else ''}"


def _structured_sql(rpt):
    """Build deterministic SQL from dimensions/measures/filters; None if the metadata is insufficient."""
    if 'dimensions' not in rpt or 'measures' not in rpt:
        return None
    dims = ', '.join(rpt['dimensions'])
    meas = ', '.join(rpt['measures'])
    select_list = ', '.join(part for part in (dims, meas) if part)
    if not select_list:
        return None
    where = ('WHERE ' + rpt['filters']) if rpt.get('filters') else ''
    # With no dimensions, a "GROUP BY" with an empty list would be a syntax error.
    group_by = f"GROUP BY {dims}" if dims else ''
    return f"SELECT {select_list} FROM {base_table} {where} {group_by}".strip()


def _llm_sql(full_prompt):
    """Ask the LLM for SQL; return its text, or None when there is no client or the call fails."""
    if not client:
        return None
    try:
        res = client.chat.completions.create(
            model='gpt-4',
            messages=[{'role': 'user', 'content': full_prompt}],
            temperature=0,
        )
        return res.choices[0].message.content
    except Exception as e:
        print('LLM call failed:', e)
        return None


def _sql_string(value):
    """Render value as a single-quoted Spark SQL string literal, escaping backslashes and quotes."""
    text = str(value).replace('\\', '\\\\').replace("'", "\\'")
    return f"'{text}'"


with open(metadata_dbfs_path, 'r', encoding='utf-8') as f:
    reports = json.load(f)

for rpt in reports:
    report_name = rpt['report_name']
    prompt = rpt.get('natural_language') or _structured_prompt(rpt)
    full_prompt = f"You are a Spark SQL generator. {schema_info}. Use table {base_table}. Return ONLY a valid Spark SQL query. Request: {prompt}"
    print('\n--- Generating for', report_name, '---')

    # Candidates in order of preference: LLM output, then the structured query.
    # The structured query is also the fallback when the LLM's SQL fails
    # validation; re-validating the same LLM text could never succeed.
    candidates = []
    llm_text = _llm_sql(full_prompt)
    if llm_text:
        candidates.append(('LLM', llm_text))
    structured = _structured_sql(rpt)
    if structured:
        candidates.append(('structured', structured))
    if not candidates:
        print('Insufficient metadata to generate SQL for', report_name)
        continue

    clean_sql, generated_by = None, None
    for source, text in candidates:
        sql_candidate = add_aliases_for_aggs(extract_sql(text))
        candidate_sql, errs = validate_sql_and_preview(sql_candidate, allowed_prefixes=allowed_prefixes)
        if not errs:
            clean_sql, generated_by = candidate_sql, source
            break
        print(f'Validation errors ({source} SQL):', errs)
    if clean_sql is None:
        print('Skipping view creation due to errors for', report_name)
        continue

    view_name = f"{catalog}.{schema}.{report_name}"
    try:
        spark.sql(f"CREATE OR REPLACE VIEW {view_name} AS {clean_sql}")
    except Exception as e:
        print('Failed to create view:', e)
        continue

    props = {
        'lineage.generated_by': generated_by,
        'lineage.prompt': prompt,
        'lineage.created_at': datetime.now(timezone.utc).isoformat(),
    }
    # Set all properties in one statement. Values are escaped because prompts and
    # filters routinely contain quotes (e.g. type = 'TRANSFER').
    assignments = ', '.join(f"{_sql_string(k)} = {_sql_string(v)}" for k, v in props.items())
    try:
        spark.sql(f"ALTER VIEW {view_name} SET TBLPROPERTIES ({assignments})")
        print('Created and tagged view:', view_name)
    except Exception as e:
        print('Created view but failed to tag lineage:', view_name, e)

print('\nAll reports processed.')