In [0]:
# %pip install databricks-vectorsearch
# %pip install --upgrade databricks-langchain langchain-community langchain databricks-sql-connector

# %pip install -U mlflow

# %pip install databricks-sql-connector pandas
# dbutils.library.restartPython()
import mlflow
import os
from openai import OpenAI
from databricks.vector_search.client import VectorSearchClient

In [0]:
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()

# One-time setup — uncomment on the first run to create the scope:
# w.secrets.create_scope(
#     scope="pizza-secrets"
# )

# Copy this notebook session's API token into the "pizza-secrets" scope.
# NOTE(review): the notebook-context token is presumably tied to the current
# session and may expire; a PAT or service-principal token may be more durable — confirm.
notebook_api_token = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
)
w.secrets.put_secret(
    scope="pizza-secrets",
    key="DATABRICKS_TOKEN",
    string_value=notebook_api_token,
)

In [0]:
import os
ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()

# Export the workspace URL and API token; later cells (the OpenAI client
# setup and run_sql) read both from os.environ.
os.environ.update({
    "DATABRICKS_HOST": ctx.apiUrl().get(),
    "DATABRICKS_TOKEN": ctx.apiToken().get(),
})

In [0]:
# =========================
# ENTITY DEFINITIONS
# =========================
import yaml
# Semantic registry consumed by SQLAgent: entities (grain column, source
# tables, time dimensions), base/derived metrics, join keys, time grains and
# SQL templates for ranking, window and trend patterns. Parsed with
# yaml.safe_load into the REGISTRY dict.
# NOTE(review): ranking_patterns.top_n / bottom_n use the placeholder `N` for
# `limit`, which parses as the string "N"; callers must supply a real integer.
REGISTRY_YAML = """entities:
  pizza:
    grain: pizza_name_id
    tables:
      monthly: retail_analytics.pizza.gold_pizza_metrics_monthly
      daily: retail_analytics.pizza.gold_pizza_metrics_daily
    time_dimensions: [day, month, year]

  ingredient:
    grain: ingredient
    tables:
      usage: retail_analytics.pizza.gold_ingredient_usage
    time_dimensions: [month, year]

# =========================
# BASE METRICS
# =========================
metrics:
  total_pizzas_sold:
    sql: SUM(total_pizzas_sold)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  total_revenue:
    sql: SUM(total_revenue)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  daily_revenue:
    sql: SUM(revenue_perday)
    table: retail_analytics.pizza.gold_pizza_metrics_daily
    entity: pizza

  ingredient_usage_grams:
    sql: SUM(total_ingredient_grams)
    table: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

  ingredient_items_qty_grams:
    sql: SUM(items_qty_in_grams)
    table: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

# =========================
# DERIVED METRICS
# =========================
derived_metrics:
  average_daily_revenue:
    sql: SUM(revenue_perday) / COUNT(DISTINCT day)
    table: retail_analytics.pizza.gold_pizza_metrics_daily
    entity: pizza

  revenue_per_pizza:
    sql: SUM(total_revenue) / SUM(total_pizzas_sold)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  revenue_per_ingredient_gram:
    sql: SUM(total_revenue) / SUM(total_ingredient_grams)
    tables:
      revenue: retail_analytics.pizza.gold_pizza_metrics_monthly
      ingredient: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

# =========================
# JOIN DEFINITIONS
# =========================
joins:
  pizza_to_ingredient:
    left_table: retail_analytics.pizza.gold_pizza_metrics_monthly
    right_table: retail_analytics.pizza.gold_ingredient_usage
    on:
      - pizza_name_id
      - month
      - year

  pizza_daily_to_monthly:
    left_table: retail_analytics.pizza.gold_pizza_metrics_daily
    right_table: retail_analytics.pizza.gold_pizza_metrics_monthly
    on:
      - pizza_name_id
      - month
      - year

# =========================
# TIME AGGREGATIONS
# =========================
time_grains:
  daily:
    columns: [day, month, year]

  monthly:
    columns: [month, year]

  yearly:
    columns: [year]

# =========================
# RANKING & WINDOW PATTERNS
# =========================
ranking_patterns:
  highest:
    order: DESC
    limit: 1

  lowest:
    order: ASC
    limit: 1

  top_n:
    order: DESC
    limit: N

  bottom_n:
    order: ASC
    limit: N

window_patterns:
  rank_within_time:
    sql: RANK() OVER (PARTITION BY {time} ORDER BY {metric} DESC)

  dense_rank_within_time:
    sql: DENSE_RANK() OVER (PARTITION BY {time} ORDER BY {metric} DESC)

# =========================
# TREND PATTERNS
# =========================
trend_patterns:
  month_over_month_change:
    sql: "{metric} - LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month)"

  month_over_month_growth_pct:
    sql: "({metric} - LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month))
          / NULLIF(LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month), 0)"

  consecutive_decline:
    rule: "negative change for N consecutive months"

# =========================
# BUSINESS TERM DEFINITIONS
# =========================
business_terms:
  peak_month:
    definition: "month with highest total_revenue"

  bottleneck:
    definition: "ingredient with highest total_ingredient_grams during peak_month"

  least_used:
    definition: "minimum total_ingredient_grams"

  top_selling:
    definition: "highest total_pizzas_sold"

  volatile:
    definition: "highest standard deviation of monthly revenue"

  declining_sales:
    definition: "negative month_over_month_change for 3 consecutive months"

# =========================
# VALIDATION RULES
# =========================
validation:
  forbid_columns:
    - peak
    - best
    - highest
  require_time_for_trends: true
  require_join_for_derived_metrics: true"""
REGISTRY = yaml.safe_load(REGISTRY_YAML)


In [0]:
class SQLAgent:
    """Registry-driven SQL generator (first iteration of the agent).

    ``generate_sql`` turns a plan dict into a Spark SQL string made of a
    ``base`` aggregation CTE plus optional ``ranked`` (window) and ``trended``
    (LAG-based trend) CTEs. Plan keys: ``entity`` and ``metric`` (required),
    and ``time_grain``, ``window``, ``trend``, ``ranking``, ``joins``,
    ``filters``, ``limit`` (optional). ``limit`` is only read for rankings
    whose registry limit is the placeholder ``N``.
    """

    # Registry time_grains are keyed "daily"/"monthly"/"yearly"; also accept
    # the shorter "day"/"month"/"year" spelling used by intents.
    _GRAIN_ALIASES = {"day": "daily", "month": "monthly", "year": "yearly"}

    def __init__(self, registry):
        self.registry = registry

    def plan(self, intent):
        """Identity planner: the intent is already an executable plan."""
        return intent

    def _metric_config(self, metric):
        """Return the registry entry for ``metric`` (base metrics first, then derived)."""
        for section in ("metrics", "derived_metrics"):
            cfg = (self.registry.get(section) or {}).get(metric)
            if cfg is not None:
                return cfg
        raise ValueError(f"Unknown metric: {metric!r}")

    def _time_columns(self, time_grain):
        """Return the time columns to group by for ``time_grain`` ([] when unset)."""
        if not time_grain:
            return []
        key = self._GRAIN_ALIASES.get(time_grain, time_grain)
        grains = self.registry.get("time_grains") or {}
        if key not in grains:
            raise ValueError(f"Unknown time grain: {time_grain!r}")
        return list(grains[key]["columns"])

    def generate_sql(self, plan):
        """Build the SQL string for ``plan``.

        ``filters`` are raw SQL predicate strings joined with AND.
        NOTE(review): they are inserted verbatim — never pass untrusted text.

        Raises:
            KeyError: unknown entity / join / window / trend / ranking name.
            ValueError: unknown metric or time grain; a window without a time
                grain; a trend without month+year columns or without an SQL
                template; a ranking without a positive integer limit.
        """
        entity = plan["entity"]
        metric = plan["metric"]
        window = plan.get("window")
        trend = plan.get("trend")
        ranking = plan.get("ranking")
        # An LLM may send null for list fields.
        joins = plan.get("joins") or []
        filters = plan.get("filters") or []

        entity_cfg = self.registry["entities"][entity]
        # Group by the entity's grain column (e.g. pizza_name_id), not its name.
        grain_col = entity_cfg["grain"]
        metric_cfg = self._metric_config(metric)
        metric_sql = metric_cfg["sql"]
        # The metric decides its source table; fall back to the entity's first table.
        base_table = metric_cfg.get("table") or next(iter(entity_cfg["tables"].values()))
        time_cols = self._time_columns(plan.get("time_grain"))
        group_cols = [grain_col] + time_cols

        # ---- JOIN LOGIC (top-level registry join definitions) ----
        join_sql = []
        for j in joins:
            cfg = self.registry["joins"][j]
            join_sql.append(f"JOIN {cfg['right_table']} USING ({', '.join(cfg['on'])})")

        # ---- BASE CTE ----
        sql = f"""
WITH base AS (
    SELECT
        {', '.join(group_cols)},
        {metric_sql} AS metric_value
    FROM {base_table}
    {' '.join(join_sql)}
"""

        if filters:
            sql += f"""
    WHERE {' AND '.join(filters)}
"""

        sql += f"""
    GROUP BY {', '.join(group_cols)}
)
"""
        # Each optional CTE reads from the previous one so window + trend compose.
        source = "base"

        # ---- WINDOW FUNCTION ----
        if window:
            if not time_cols:
                raise ValueError(f"Window pattern {window!r} requires a time_grain")
            window_sql = self.registry["window_patterns"][window]["sql"].format(
                time=", ".join(time_cols),
                metric="metric_value",
            )
            sql += f"""
, ranked AS (
    SELECT *,
           {window_sql} AS rank
    FROM {source}
)
"""
            source = "ranked"

        # ---- TREND LOGIC ----
        if trend:
            pattern = self.registry["trend_patterns"][trend]
            if "sql" not in pattern:
                raise ValueError(f"Trend pattern {trend!r} has no SQL template")
            # Trend templates ORDER BY year, month, so both must be grouped on.
            if not {"year", "month"}.issubset(time_cols):
                raise ValueError(f"Trend pattern {trend!r} requires month and year columns")
            trend_sql = pattern["sql"].format(
                entity=grain_col,
                time=", ".join(time_cols),
                metric="metric_value",
            )
            sql += f"""
, trended AS (
    SELECT *,
           {trend_sql} AS trend_value
    FROM {source}
)
"""
            source = "trended"

        sql += f"""
SELECT *
FROM {source}
"""

        if ranking:
            r = self.registry["ranking_patterns"][ranking]
            limit = r["limit"]
            if not isinstance(limit, int):
                # top_n / bottom_n carry the placeholder "N"; the plan supplies it.
                limit = plan.get("limit")
            if isinstance(limit, bool) or not isinstance(limit, int) or limit < 1:
                raise ValueError(f"Ranking {ranking!r} needs a positive integer limit")
            sql += f"""
ORDER BY metric_value {r['order']}
LIMIT {limit}
"""

        return sql


In [0]:
import json

def resolve_intent(llm, question):
    """Ask ``llm`` to map a natural-language question onto a SQLAgent plan.

    The JSON template lists the registry's own keys (metric names, window,
    trend and ranking pattern names), so the parsed dict can be passed
    straight to ``SQLAgent.generate_sql``.

    Args:
        llm: object exposing ``invoke(prompt)``; the return value may be a
            plain string or a message object with a ``content`` attribute.
        question: the user's question.

    Returns:
        The decoded JSON intent.

    Raises:
        json.JSONDecodeError: if the reply is not strict JSON.
    """
    prompt = f"""
    Output STRICT JSON only.

    {{
      "entity": "pizza|ingredient",
      "metric": "total_revenue|total_pizzas_sold|ingredient_usage_grams",
      "time_grain": "day|month|year|null",
      "window": "rank_within_time|null",
      "trend": "month_over_month_growth_pct|null",
      "ranking": "highest|lowest|null",
      "joins": [],
      "filters": []
    }}

    Question: {question}
    """
    reply = llm.invoke(prompt)
    # Chat-model wrappers return a message object; plain LLM wrappers return str.
    text = getattr(reply, "content", reply)
    return json.loads(text)


In [0]:
def execute_with_repair(spark, sql, retries=2):
    """Run ``sql`` with ``spark.sql``, making up to ``retries`` attempts in total.

    No automatic repair is implemented yet; failed attempts re-submit the same
    statement.

    Raises:
        ValueError: if ``retries`` is less than 1 (the original silently
            returned None in that case).
        RuntimeError: when every attempt fails; chained to the last error.
    """
    if retries < 1:
        raise ValueError("retries must be >= 1")
    last_error = None
    for _ in range(retries):
        try:
            return spark.sql(sql)
        except Exception as e:
            last_error = e
    raise RuntimeError(f"SQL failed:\n{sql}\n\n{last_error}") from last_error


In [0]:
def interpret_result(llm, question, df):
    """Have ``llm`` narrate a query result, grounded in at most 20 rows.

    Rows come from ``df.limit(20).toPandas()`` and are embedded as a list of
    record dicts; whatever ``llm.invoke`` returns is handed straight back.
    """
    records = df.limit(20).toPandas().to_dict(orient="records")
    return llm.invoke(f"""
    Explain the result using ONLY the data below.
    Do not invent metrics.

    Question: {question}
    Data: {records}
    """)


In [0]:
import mlflow.pyfunc
from pyspark.sql import SparkSession

class PizzaAnalyticsAgent(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model (first iteration): question -> intent -> SQL -> answer.

    ``predict`` returns a dict with the generated SQL, the LLM's answer and a
    10-row pandas preview of the Spark result.
    """

    def load_context(self, context):
        """Build the SQL agent, the LLM handle and the Spark session."""
        self.registry = REGISTRY
        self.sql_agent = SQLAgent(self.registry)
        # NOTE(review): per MLflow docs, context.artifacts maps names to local
        # file paths (strings), so this is presumably not a callable LLM client;
        # it likely needs to be rebuilt from that path or from config — confirm.
        self.llm = context.artifacts["llm_client"]
        self.spark = SparkSession.builder.getOrCreate()

    @staticmethod
    def _question_text(model_input):
        """Return the question as a string from a str, dict or DataFrame input."""
        if isinstance(model_input, str):
            return model_input
        value = model_input["question"]
        # A DataFrame column is a Series; answer its first row.
        if hasattr(value, "iloc"):
            value = value.iloc[0]
        return str(value)

    def predict(self, context, model_input):
        """Answer the question in ``model_input`` (only the first row of a DataFrame)."""
        question = self._question_text(model_input)

        intent = resolve_intent(self.llm, question)
        sql = self.sql_agent.generate_sql(intent)
        df = execute_with_repair(self.spark, sql)
        answer = interpret_result(self.llm, question, df)

        return {
            "sql": sql,
            "answer": answer,
            "preview": df.limit(10).toPandas()
        }


In [0]:
from openai import OpenAI
import os

# OpenAI-compatible client pointed at this workspace's model-serving endpoints;
# host and token come from the environment variables exported earlier.
serving_base_url = f"{os.environ['DATABRICKS_HOST']}/serving-endpoints"
llm_client = OpenAI(base_url=serving_base_url, api_key=os.environ["DATABRICKS_TOKEN"])


In [0]:
!pip install openai

In [0]:
import mlflow

# Log the first-iteration PizzaAnalyticsAgent as an MLflow pyfunc model.
# NOTE(review): MLflow documents `artifacts` as a mapping of names to artifact
# URIs/paths; passing a live OpenAI client object here is likely to fail at
# log time — confirm, and consider building the client inside load_context.
mlflow.pyfunc.log_model(
    artifact_path="pizza_analytics_agent",
    python_model=PizzaAnalyticsAgent(),
    artifacts={
        "llm_client": llm_client   # your Groq / OpenAI / Databricks LLM
    }
)


In [0]:
import mlflow
import yaml
import json
# from pyspark.sql import SparkSession
from openai import OpenAI
import pandas as pd
from mlflow.models import ModelSignature
from mlflow.types import Schema, ColSpec

# -------------------------
# 1️⃣ Registry YAML (embedded)
# -------------------------
# Semantic registry read by SQLAgent, re-embedded so this cell runs on its own.
# It appears to duplicate the registry from the earlier cell (entities, metrics,
# joins, time grains, ranking/window/trend templates).
# NOTE(review): ranking_patterns.top_n / bottom_n use the placeholder `N` for
# `limit`, which parses as the string "N"; callers must supply a real integer.
REGISTRY_YAML = """
entities:
  pizza:
    grain: pizza_name_id
    tables:
      monthly: retail_analytics.pizza.gold_pizza_metrics_monthly
      daily: retail_analytics.pizza.gold_pizza_metrics_daily
    time_dimensions: [day, month, year]

  ingredient:
    grain: ingredient
    tables:
      usage: retail_analytics.pizza.gold_ingredient_usage
    time_dimensions: [month, year]

# =========================
# BASE METRICS
# =========================
metrics:
  total_pizzas_sold:
    sql: SUM(total_pizzas_sold)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  total_revenue:
    sql: SUM(total_revenue)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  daily_revenue:
    sql: SUM(revenue_perday)
    table: retail_analytics.pizza.gold_pizza_metrics_daily
    entity: pizza

  ingredient_usage_grams:
    sql: SUM(total_ingredient_grams)
    table: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

  ingredient_items_qty_grams:
    sql: SUM(items_qty_in_grams)
    table: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

# =========================
# DERIVED METRICS
# =========================
derived_metrics:
  average_daily_revenue:
    sql: SUM(revenue_perday) / COUNT(DISTINCT day)
    table: retail_analytics.pizza.gold_pizza_metrics_daily
    entity: pizza

  revenue_per_pizza:
    sql: SUM(total_revenue) / SUM(total_pizzas_sold)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  revenue_per_ingredient_gram:
    sql: SUM(total_revenue) / SUM(total_ingredient_grams)
    tables:
      revenue: retail_analytics.pizza.gold_pizza_metrics_monthly
      ingredient: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

# =========================
# JOIN DEFINITIONS
# =========================
joins:
  pizza_to_ingredient:
    left_table: retail_analytics.pizza.gold_pizza_metrics_monthly
    right_table: retail_analytics.pizza.gold_ingredient_usage
    on:
      - pizza_name_id
      - month
      - year

  pizza_daily_to_monthly:
    left_table: retail_analytics.pizza.gold_pizza_metrics_daily
    right_table: retail_analytics.pizza.gold_pizza_metrics_monthly
    on:
      - pizza_name_id
      - month
      - year

# =========================
# TIME AGGREGATIONS
# =========================
time_grains:
  daily:
    columns: [day, month, year]

  monthly:
    columns: [month, year]

  yearly:
    columns: [year]

# =========================
# RANKING & WINDOW PATTERNS
# =========================
ranking_patterns:
  highest:
    order: DESC
    limit: 1

  lowest:
    order: ASC
    limit: 1

  top_n:
    order: DESC
    limit: N

  bottom_n:
    order: ASC
    limit: N

window_patterns:
  rank_within_time:
    sql: RANK() OVER (PARTITION BY {time} ORDER BY {metric} DESC)

  dense_rank_within_time:
    sql: DENSE_RANK() OVER (PARTITION BY {time} ORDER BY {metric} DESC)

# =========================
# TREND PATTERNS
# =========================
trend_patterns:
  month_over_month_change:
    sql: "{metric} - LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month)"

  month_over_month_growth_pct:
    sql: "({metric} - LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month))
          / NULLIF(LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month), 0)"

  consecutive_decline:
    rule: "negative change for N consecutive months"

# =========================
# BUSINESS TERM DEFINITIONS
# =========================
business_terms:
  peak_month:
    definition: "month with highest total_revenue"

  bottleneck:
    definition: "ingredient with highest total_ingredient_grams during peak_month"

  least_used:
    definition: "minimum total_ingredient_grams"

  top_selling:
    definition: "highest total_pizzas_sold"

  volatile:
    definition: "highest standard deviation of monthly revenue"

  declining_sales:
    definition: "negative month_over_month_change for 3 consecutive months"

# =========================
# VALIDATION RULES
# =========================
validation:
  forbid_columns:
    - peak
    - best
    - highest
  require_time_for_trends: true
  require_join_for_derived_metrics: true"""
REGISTRY = yaml.safe_load(REGISTRY_YAML)

# -------------------------
# 2️⃣ SQLAgent
# -------------------------
class SQLAgent:
    """Registry-driven SQL generator for the intents produced by ``resolve_intent``.

    ``generate_sql`` builds a ``base`` aggregation CTE (grain column + time
    columns + metric). It can then chain a ``ranked`` CTE (window function) and
    a ``trended`` CTE (LAG-based trend), and finally applies an ORDER BY/LIMIT
    ranking.
    """

    # GROUP BY columns added for each supported time grain.
    _TIME_COLUMNS = {
        "day": ["day", "month", "year"],
        "month": ["month", "year"],
        "year": ["year"],
    }
    # Comparison operators accepted in structured filters ("==" is mapped to "=").
    _ALLOWED_OPS = frozenset({"=", "!=", "<>", "<", "<=", ">", ">=", "LIKE", "NOT LIKE"})

    def __init__(self, registry):
        self.registry = registry

    def plan(self, intent):
        """Identity planner: the intent is already an executable plan."""
        return intent

    @staticmethod
    def _clean(value):
        """Map LLM placeholders such as "null" / "none" / "" to None."""
        if isinstance(value, str) and value.strip().lower() in ("", "null", "none"):
            return None
        return value

    def _metric_config(self, metric):
        """Return the registry entry for ``metric`` (base metrics first, then derived)."""
        for section in ("metrics", "derived_metrics"):
            cfg = (self.registry.get(section) or {}).get(metric)
            if cfg is not None:
                return cfg
        raise ValueError(f"Unknown metric: {metric!r}")

    @staticmethod
    def _base_table(entity_cfg, metric_cfg, time_grain):
        """Pick the FROM table.

        Order of preference: the entity's daily table for day-level questions,
        then the metric's own table, then the entity's first table.
        """
        tables = entity_cfg.get("tables") or {}
        if time_grain == "day" and "daily" in tables:
            return tables["daily"]
        if metric_cfg.get("table"):
            return metric_cfg["table"]
        if tables:
            return next(iter(tables.values()))
        raise ValueError("No source table configured for this entity/metric")

    @classmethod
    def _filter_clause(cls, flt):
        """Render one {"field", "operator", "value"} filter as a SQL predicate.

        Filters originate from LLM output. Field names must therefore be plain
        identifiers and operators must come from an allow-list. In string
        values, backslashes and single quotes are backslash-escaped.
        """
        import re

        if not isinstance(flt, dict):
            raise ValueError(f"Filter must be a dict with field/operator/value, got {flt!r}")
        field = flt.get("field")
        op = str(flt.get("operator") or "=").strip().upper()
        if op == "==":
            op = "="
        value = flt.get("value")

        if not isinstance(field, str) or not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_.]*", field):
            raise ValueError(f"Invalid filter field: {field!r}")
        if op not in cls._ALLOWED_OPS:
            raise ValueError(f"Unsupported filter operator: {op!r}")
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            value = f"'{escaped}'"
        return f"{field} {op} {value}"

    def generate_sql(self, intent):
        """Build the SQL string for ``intent``.

        Raises:
            KeyError: unknown entity / join / window / trend / ranking name.
            ValueError: one of the following:
                - unknown metric or time grain;
                - a grain the entity does not list in ``time_dimensions``;
                - an invalid filter;
                - a window without a time grain;
                - a trend without month+year columns or without an SQL template;
                - a ranking whose limit is not a positive integer.
        """
        entity = intent["entity"]
        metric = intent["metric"]
        time_grain = self._clean(intent.get("time_grain"))  # day / month / year / None
        window = self._clean(intent.get("window"))
        trend = self._clean(intent.get("trend"))
        ranking = self._clean(intent.get("ranking"))
        # The LLM may send null for list fields.
        joins = intent.get("joins") or []
        filters = intent.get("filters") or []

        entity_cfg = self.registry["entities"][entity]
        metric_cfg = self._metric_config(metric)

        # -----------------------------
        # 1. Time grain -> grouping columns (validated against the entity)
        # -----------------------------
        if time_grain is not None:
            if time_grain not in self._TIME_COLUMNS:
                raise ValueError(f"Unsupported time_grain: {time_grain!r}")
            supported = entity_cfg.get("time_dimensions")
            if supported is not None and time_grain not in supported:
                raise ValueError(f"Entity {entity!r} has no {time_grain!r} dimension")
        time_cols = list(self._TIME_COLUMNS.get(time_grain, []))

        # -----------------------------
        # 2. Source table and metric expression
        # -----------------------------
        base_table = self._base_table(entity_cfg, metric_cfg, time_grain)
        if time_grain == "day" and metric in ("total_revenue", "daily_revenue"):
            # On the daily table revenue is aggregated from revenue_perday.
            metric_sql = "SUM(revenue_perday)"
        else:
            metric_sql = metric_cfg["sql"]

        # -----------------------------
        # 3. SELECT / GROUP BY: entity grain column (e.g. pizza_name_id) + time columns
        # -----------------------------
        grain_col = entity_cfg.get("grain", entity)
        group_cols = [grain_col] + time_cols

        # -----------------------------
        # 4. JOINs from the registry's join definitions
        # -----------------------------
        join_sql = [
            f"JOIN {cfg['right_table']} USING ({', '.join(cfg['on'])})"
            for cfg in (self.registry["joins"][j] for j in joins)
        ]

        # -----------------------------
        # 5. WHERE clause from structured filters
        # -----------------------------
        filter_clauses = [self._filter_clause(f) for f in filters]
        where_sql = f"WHERE {' AND '.join(filter_clauses)}" if filter_clauses else ""

        # -----------------------------
        # 6. Base CTE
        # -----------------------------
        sql = f"""
WITH base AS (
    SELECT
        {', '.join(group_cols)},
        {metric_sql} AS metric_value
    FROM {base_table}
    {' '.join(join_sql)}
    {where_sql}
    GROUP BY {', '.join(group_cols)}
)
"""
        # Each optional CTE reads from the previous one so window + trend compose.
        source = "base"

        # -----------------------------
        # 7. Window function: rank entities within each time period
        # -----------------------------
        if window:
            if not time_cols:
                raise ValueError(f"Window pattern {window!r} requires a time_grain")
            window_sql = self.registry["window_patterns"][window]["sql"].format(
                time=", ".join(time_cols),
                metric="metric_value",
            )
            sql += f"""
, ranked AS (
    SELECT *,
           {window_sql} AS rank
    FROM {source}
)
"""
            source = "ranked"

        # -----------------------------
        # 8. Trend function (templates ORDER BY year, month)
        # -----------------------------
        if trend:
            pattern = self.registry["trend_patterns"][trend]
            if "sql" not in pattern:
                raise ValueError(f"Trend pattern {trend!r} has no SQL template")
            if not {"year", "month"}.issubset(time_cols):
                raise ValueError(f"Trend pattern {trend!r} requires month and year columns")
            trend_sql = pattern["sql"].format(
                entity=grain_col,
                metric="metric_value",
                time=", ".join(time_cols),
            )
            sql += f"""
, trended AS (
    SELECT *,
           {trend_sql} AS trend_value
    FROM {source}
)
"""
            source = "trended"

        # -----------------------------
        # 9. Final SELECT + ranking
        # -----------------------------
        sql += f"SELECT * FROM {source}\n"

        if ranking:
            r = self.registry["ranking_patterns"][ranking]
            limit = r["limit"]
            if not isinstance(limit, int):
                # top_n / bottom_n carry the placeholder "N"; assumes the intent supplies "limit".
                limit = intent.get("limit")
            if isinstance(limit, bool) or not isinstance(limit, int) or limit < 1:
                raise ValueError(f"Ranking {ranking!r} needs a positive integer limit")
            sql += f"ORDER BY metric_value {r['order']} LIMIT {limit}\n"

        return sql


# -------------------------
# 3️⃣ LLM client helper (works in the notebook and in model serving)
# -------------------------
def get_llm():
    """Build an OpenAI-compatible client for this workspace's serving endpoints.

    Reads DATABRICKS_HOST and DATABRICKS_TOKEN from the environment; a missing
    variable raises KeyError.
    """
    host = os.environ["DATABRICKS_HOST"]
    token = os.environ["DATABRICKS_TOKEN"]
    return OpenAI(base_url=f"{host}/serving-endpoints", api_key=token)
import json
import ast

import json
import re
import json
import re

def extract_json_from_llm(raw):
    """
    Extract the last JSON object from LLM output.

    Handles:
    - list outputs from Databricks LLM (text parts are concatenated)
    - extra reasoning text before or around the JSON

    Strictly valid JSON objects are preferred, so apostrophes inside string
    values survive. Only when no object decodes does it fall back to a lenient
    repair pass: single quotes become double quotes and trailing commas are
    removed.

    Raises:
        ValueError: no JSON object found, or the lenient repair still fails.
    """
    # If list output, join all texts
    if isinstance(raw, list):
        raw = "".join([item.get("text", str(item)) if isinstance(item, dict) else str(item) for item in raw])

    # Pass 1: decode strict JSON objects left to right. After each success,
    # skip past the whole object so nested dicts are never mistaken for
    # the "last" object.
    decoder = json.JSONDecoder()
    last_obj = None
    idx = raw.find("{")
    while idx != -1:
        try:
            obj, end = decoder.raw_decode(raw, idx)
        except json.JSONDecodeError:
            idx = raw.find("{", idx + 1)
            continue
        last_obj = obj
        idx = raw.find("{", end)
    if last_obj is not None:
        return last_obj

    # Pass 2: lenient fallback for almost-JSON output (e.g. Python-style quotes).
    matches = re.findall(r'(\{.*\})', raw, flags=re.DOTALL)
    if not matches:
        raise ValueError(f"No JSON object found in LLM output:\n{raw}")

    json_str = matches[-1]

    # Replace single quotes with double quotes (naive fix)
    json_str = json_str.replace("'", '"')

    # Remove trailing commas
    json_str = re.sub(r",\s*}", "}", json_str)
    json_str = re.sub(r",\s*]", "]", json_str)

    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to parse extracted JSON:\n{json_str}\nError: {e}")


def extract_intent_json(raw):
    """
    Extract valid intent JSON from raw LLM output.

    Handles:
    - list outputs from Databricks LLM (``text`` parts are concatenated)
    - extra reasoning text before or after the JSON
    - a JSON array of intents or a single intent object, including nested
      values such as filter dicts

    Returns the last top-level JSON value that is an intent, i.e. either a dict
    with an ``entity`` key or a non-empty list of such dicts.

    Raises:
        ValueError: no JSON bracket at all, or no decodable intent.
    """
    # --- Step 1: Convert list outputs to string ---
    if isinstance(raw, list):
        raw = "".join(
            item["text"] if isinstance(item, dict) and "text" in item else str(item)
            for item in raw
        )

    opener = re.compile(r"[\[{]")
    if not opener.search(raw):
        raise ValueError(f"No JSON found in LLM output:\n{raw}")

    # --- Step 2: Decode top-level JSON values with a real JSON parser ---
    # A regex cannot balance nested braces; raw_decode can. After a successful
    # decode, jump past the value so its nested parts are not re-collected.
    decoder = json.JSONDecoder()
    candidates = []
    pos = 0
    while True:
        match = opener.search(raw, pos)
        if match is None:
            break
        start = match.start()
        try:
            value, end = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            pos = start + 1
            continue
        candidates.append(value)
        pos = end

    # --- Step 3: Pick the last candidate that looks like an intent ---
    for candidate in reversed(candidates):
        if isinstance(candidate, dict) and "entity" in candidate:
            return candidate
        # An empty list would vacuously satisfy all(); require at least one intent.
        if (
            isinstance(candidate, list)
            and candidate
            and all(isinstance(c, dict) and "entity" in c for c in candidate)
        ):
            return candidate

    raise ValueError(f"No valid intent JSON found in LLM output:\n{raw}")





def resolve_intent(llm, question: str) -> "dict | list[dict]":
    """Turn a natural-language question into SQLAgent intent(s) via the serving LLM.

    Args:
        llm: OpenAI-compatible client (as built by ``get_llm``) exposing
            ``chat.completions.create``.
        question: the user's question.

    Returns:
        One intent dict, or a list of them when the model answers with a JSON
        array. Every intent carries all DEFAULT_INTENT keys. ``joins`` and
        ``filters`` are always fresh lists; a JSON null becomes ``[]`` and a
        single non-list value is wrapped in a list.

    Raises:
        ValueError: from ``extract_intent_json`` when the reply holds no intent.
    """
    DEFAULT_INTENT = {
        "entity": "pizza",
        "metric": "total_revenue",
        "time_grain": None,
        "window": None,
        "trend": None,
        "ranking": None,
        "joins": [],
        "filters": []
    }

    # The filter shape below matches what SQLAgent.generate_sql renders
    # (field / operator / value dicts).
    prompt = f"""
Return STRICT JSON only.

Valid values:
- entity: pizza | ingredient
- metric: total_revenue | total_pizzas_sold | ingredient_usage_grams
- time_grain: day | month | year | null
- window: rank_within_time | null
- trend: month_over_month_growth_pct | null
- ranking: highest | lowest | null
- joins: []
- filters: [] or a list of objects like {{"field": "year", "operator": "=", "value": 2015}}
  (operator is one of =, !=, <, <=, >, >=)

Question: {question}
"""

    response = llm.chat.completions.create(
        model="databricks-gpt-oss-120b",
        messages=[
            {"role": "system", "content": "You output ONLY valid JSON."},
            {"role": "user", "content": prompt}
        ],
        temperature=0
    )

    raw = response.choices[0].message.content
    parsed = extract_intent_json(raw)

    def _normalise(candidate):
        # Overlay the model's keys on the defaults, then sanitise list fields.
        intent = {**DEFAULT_INTENT, **candidate}
        for key in ("joins", "filters"):
            value = intent.get(key)
            if value is None:
                value = []
            elif not isinstance(value, list):
                value = [value]
            intent[key] = list(value)
        return intent

    if isinstance(parsed, list):
        return [_normalise(p) for p in parsed]
    return _normalise(parsed)

# -------------------------
# 4️⃣ Result interpreter
# -------------------------
def interpret_result(llm, question, df):
    """Ask the serving LLM to explain a query result in plain language.

    Args:
        llm: OpenAI-compatible client exposing ``chat.completions.create``.
        question: the user's original question.
        df: a Spark DataFrame (has ``toPandas``) or a pandas DataFrame, which is
            what ``execute_sql``/``run_sql`` return. Only the first 20 rows are
            sent.

    Returns:
        The model's reply text.
    """
    if hasattr(df, "toPandas"):
        preview = df.limit(20).toPandas()
    else:
        preview = df.head(20)
    data = preview.to_dict(orient="records")

    # default=str: SQL results commonly hold Decimal / Timestamp values that the
    # json module cannot encode natively.
    prompt = f"""
Explain the SQL result using only the data below.
Do not invent values.

Question:
{question}

Data:
{json.dumps(data, indent=2, default=str)}
"""

    response = llm.chat.completions.create(
        model="databricks-gpt-oss-120b",
        messages=[
            {"role": "system", "content": "You are a SQL and analytics assistant."},
            {"role": "user", "content": prompt}
        ]
    )

    return response.choices[0].message.content
def run_sql(
    query_or_self,
    query: "str | None" = None,
    *,
    server_hostname: str = "dbc-741c7540-6155.cloud.databricks.com",
    http_path: str = "/sql/1.0/warehouses/e2eb5bcf69add002",
) -> pd.DataFrame:
    """Run ``query`` on a Databricks SQL warehouse and return a pandas DataFrame.

    Call it as ``run_sql(query)``, which is how ``execute_sql`` invokes it. The
    legacy two-argument form ``run_sql(self, query)`` (the function was
    originally written as a method) is still accepted.

    The access token is read from the DATABRICKS_TOKEN environment variable.
    The connection is closed when the ``with`` block exits.
    """
    # Local import: this cell never imports the SQL connector at module level.
    from databricks import sql as dbsql

    if query is None:
        query = query_or_self
    with dbsql.connect(
        server_hostname=server_hostname,
        http_path=http_path,
        access_token=os.environ["DATABRICKS_TOKEN"],
    ) as conn:
        return pd.read_sql(query, conn)
# -------------------------
# 5️⃣ SQL execution (auto-repair not implemented yet)
# -------------------------
def execute_sql(sql_query: str) -> pd.DataFrame:
    """Run ``sql_query`` through ``run_sql`` and return its pandas DataFrame.

    Any failure is re-raised as RuntimeError. The message embeds the SQL text
    and the underlying error. No automatic repair is attempted yet.
    """
    try:
        result = run_sql(sql_query)
    except Exception as exc:
        raise RuntimeError(f"SQL execution failed:\n{sql_query}\n\n{exc}")
    return result



# -------------------------
# 6️⃣ PyFunc model
# -------------------------
import mlflow.pyfunc

class PizzaRAGPyFuncModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model: question -> intent(s) -> SQL -> result -> LLM answer.

    ``predict`` returns one row per resolved intent, for every input query,
    with the columns declared in the logged signature.
    """

    # Output column order, matching the ModelSignature outputs.
    _OUTPUT_COLUMNS = ["answer", "intent", "sql", "num_rows", "num_documents"]

    def load_context(self, context):
        """Build the registry-backed SQL agent and the serving LLM client."""
        self.registry = REGISTRY
        self.agent = SQLAgent(self.registry)
        # get_llm reads DATABRICKS_HOST / DATABRICKS_TOKEN from the environment.
        self.llm = get_llm()

    @staticmethod
    def _query_texts(model_input):
        """Return every question in ``model_input`` as a list of strings.

        Accepts a bare string, a dict whose ``query`` value is a string or a
        list, or a DataFrame with a ``query`` column.
        """
        if isinstance(model_input, str):
            return [model_input]
        queries = model_input["query"]
        if isinstance(queries, str):
            return [queries]
        return [str(q) for q in queries]

    def predict(self, context, model_input):
        """Answer each query; an empty input yields an empty, correctly-shaped frame."""
        rows = []
        for query_text in self._query_texts(model_input):
            # Resolve intent(s) — can be single dict or list of dicts
            resolved = resolve_intent(self.llm, query_text)
            intents = resolved if isinstance(resolved, list) else [resolved]

            for intent in intents:
                sql_query = self.agent.generate_sql(intent)
                df = execute_sql(sql_query)
                answer_text = interpret_result(self.llm, query_text, df)

                rows.append({
                    "answer": answer_text,
                    "intent": json.dumps(intent),
                    "sql": sql_query,
                    "num_rows": len(df),
                    # Despite the name, this counts the joins requested by the intent.
                    "num_documents": len(intent.get("joins") or []),
                })

        # Combine into single DataFrame for MLflow
        return pd.DataFrame(rows, columns=self._OUTPUT_COLUMNS)


# -------------------------
# 7️⃣ MLflow model signature & log
# -------------------------
# Output schema: one column per field emitted by PizzaRAGPyFuncModel.predict.
output_schema = Schema([
    ColSpec("string", "answer"),
    ColSpec("string", "intent"),
    ColSpec("string", "sql"),
    ColSpec("long", "num_rows"),
    ColSpec("long", "num_documents"),
])
signature = ModelSignature(
    inputs=Schema([ColSpec("string", "query")]),
    outputs=output_schema,
)

# Log the model in a new MLflow run and register it as
# retail_analytics.pizza.pizza_rag_models.
with mlflow.start_run():
    mlflow.pyfunc.log_model(
        name="pizza_rag_model",
        python_model=PizzaRAGPyFuncModel(),
        signature=signature,
        input_example={"query": "how much revenue generated in 25 december and year 2015"},
        registered_model_name="retail_analytics.pizza.pizza_rag_models",
    )


In [0]:
import mlflow
import yaml
import json
import os
from openai import OpenAI
import pandas as pd
from mlflow.models import ModelSignature
from mlflow.types import Schema, ColSpec
from databricks import sql
import ast
import re


# -------------------------
# 1️⃣ Registry YAML (embedded)
# -------------------------
# Semantic registry, third iteration, parsed with yaml.safe_load into REGISTRY.
# Differs from the earlier copies: metrics.total_pizzas_sold is defined per
# grain (day / month / year, each with its own table + sql) under `grains`
# instead of a top-level `sql`/`table`, so consumers must resolve it by grain.
# NOTE(review): ranking_patterns.top_n / bottom_n use the placeholder `N` for
# `limit`, which parses as the string "N"; callers must supply a real integer.
REGISTRY_YAML = """
entities:
  pizza:
    grain: pizza_name_id
    tables:
      monthly: retail_analytics.pizza.gold_pizza_metrics_monthly
      daily: retail_analytics.pizza.gold_pizza_metrics_daily
    time_dimensions: [day, month, year]

  ingredient:
    grain: ingredient
    tables:
      usage: retail_analytics.pizza.gold_ingredient_usage
    time_dimensions: [month, year]

# =========================
# BASE METRICS
# =========================
metrics:
  total_pizzas_sold:
    entity: pizza
    grains:
      day:
        table: retail_analytics.pizza.gold_pizza_metrics_daily
        sql: SUM(total_pizzas_sold)
      month:
        table: retail_analytics.pizza.gold_pizza_metrics_monthly
        sql: SUM(total_pizzas_sold)
      year:
        table: retail_analytics.pizza.gold_pizza_metrics_monthly
        sql: SUM(total_pizzas_sold)

  total_revenue:
    sql: SUM(total_revenue)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  daily_revenue:
    sql: SUM(revenue_perday)
    table: retail_analytics.pizza.gold_pizza_metrics_daily
    entity: pizza

  ingredient_usage_grams:
    sql: SUM(total_ingredient_grams)
    table: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

  ingredient_items_qty_grams:
    sql: SUM(items_qty_in_grams)
    table: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

# =========================
# DERIVED METRICS
# =========================
derived_metrics:
  average_daily_revenue:
    sql: SUM(revenue_perday) / COUNT(DISTINCT day)
    table: retail_analytics.pizza.gold_pizza_metrics_daily
    entity: pizza

  revenue_per_pizza:
    sql: SUM(total_revenue) / SUM(total_pizzas_sold)
    table: retail_analytics.pizza.gold_pizza_metrics_monthly
    entity: pizza

  revenue_per_ingredient_gram:
    sql: SUM(total_revenue) / SUM(total_ingredient_grams)
    tables:
      revenue: retail_analytics.pizza.gold_pizza_metrics_monthly
      ingredient: retail_analytics.pizza.gold_ingredient_usage
    entity: ingredient

# =========================
# JOIN DEFINITIONS
# =========================
joins:
  pizza_to_ingredient:
    left_table: retail_analytics.pizza.gold_pizza_metrics_monthly
    right_table: retail_analytics.pizza.gold_ingredient_usage
    on:
      - pizza_name_id
      - month
      - year

  pizza_daily_to_monthly:
    left_table: retail_analytics.pizza.gold_pizza_metrics_daily
    right_table: retail_analytics.pizza.gold_pizza_metrics_monthly
    on:
      - pizza_name_id
      - month
      - year

# =========================
# TIME AGGREGATIONS
# =========================
time_grains:
  daily:
    columns: [day, month, year]

  monthly:
    columns: [month, year]

  yearly:
    columns: [year]

# =========================
# RANKING & WINDOW PATTERNS
# =========================
ranking_patterns:
  highest:
    order: DESC
    limit: 1

  lowest:
    order: ASC
    limit: 1

  top_n:
    order: DESC
    limit: N

  bottom_n:
    order: ASC
    limit: N

window_patterns:
  rank_within_time:
    sql: RANK() OVER (PARTITION BY {time} ORDER BY {metric} DESC)

  dense_rank_within_time:
    sql: DENSE_RANK() OVER (PARTITION BY {time} ORDER BY {metric} DESC)

# =========================
# TREND PATTERNS
# =========================
trend_patterns:
  month_over_month_change:
    sql: "{metric} - LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month)"

  month_over_month_growth_pct:
    sql: "({metric} - LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month))
          / NULLIF(LAG({metric}) OVER (PARTITION BY {entity} ORDER BY year, month), 0)"

  consecutive_decline:
    rule: "negative change for N consecutive months"

# =========================
# BUSINESS TERM DEFINITIONS
# =========================
business_terms:
  peak_month:
    definition: "month with highest total_revenue"

  bottleneck:
    definition: "ingredient with highest total_ingredient_grams during peak_month"

  least_used:
    definition: "minimum total_ingredient_grams"

  top_selling:
    definition: "highest total_pizzas_sold"

  volatile:
    definition: "highest standard deviation of monthly revenue"

  declining_sales:
    definition: "negative month_over_month_change for 3 consecutive months"

# =========================
# VALIDATION RULES
# =========================
validation:
  forbid_columns:
    - peak
    - best
    - highest
  require_time_for_trends: true
  require_join_for_derived_metrics: true"""
REGISTRY = yaml.safe_load(REGISTRY_YAML)

# -------------------------
# 2️⃣ SQLAgent
# -------------------------
class SQLAgent:
    """Compile a structured analytics intent into a Databricks SQL query.

    Everything schema-specific is looked up in ``registry`` (the parsed
    semantic-layer YAML): ``metrics`` (SQL expression + table, optionally per
    grain under ``grains``), ``joins``, ``window_patterns``, ``trend_patterns``
    and ``ranking_patterns``.

    Intent values come from an LLM (i.e. indirectly from user text), so filter
    field names and operators are validated and string values are escaped
    before being placed in the SQL.
    """

    # Month names accepted in "month" filters (3-letter prefixes are accepted too).
    MONTH_MAP = {
        "january": 1, "february": 2, "march": 3,
        "april": 4, "may": 5, "june": 6,
        "july": 7, "august": 8, "september": 9,
        "october": 10, "november": 11, "december": 12,
    }
    # Calendar columns of the gold tables; any other dimension is an entity column.
    TIME_COLUMNS = ("day", "month", "year")
    # Dimension columns produced by each supported breakdown.
    BREAKDOWN_COLUMNS = {
        "month": ["year", "month"],
        "day": ["day", "month", "year"],
        "year": ["year"],
        "pizza": ["pizza_name_id"],
        "ingredient": ["ingredient"],
    }
    # Dimension columns produced by the time grain when there is no breakdown.
    TIME_GRAIN_COLUMNS = {
        "day": ["day", "month", "year"],
        "month": ["month", "year"],
        "year": ["year"],
    }
    # Filter operators accepted from the intent, plus textual aliases for them.
    ALLOWED_OPERATORS = {"=", "==", "!=", "<>", "<", "<=", ">", ">=", "LIKE", "NOT LIKE"}
    OPERATOR_ALIASES = {
        "eq": "=", "ne": "!=", "neq": "!=",
        "lt": "<", "lte": "<=", "le": "<=",
        "gt": ">", "gte": ">=", "ge": ">=",
    }
    # Placeholder used to strip "PARTITION BY" from templates when there is nothing to partition on.
    _NO_PARTITION = "__no_partition__"

    def __init__(self, registry):
        self.registry = registry

    def plan(self, intent):
        """Planning hook; currently returns ``intent`` unchanged."""
        return intent

    def generate_sql(self, intent):
        """Build the SQL query text for ``intent``.

        Shape of the generated query::

            WITH base AS (SELECT [<dims>,] <metric_sql> AS metric_value
                          FROM <table> [JOIN ...] [WHERE ...] [GROUP BY <dims>])
            [, ranked AS (SELECT *, <window> AS rank FROM base)]
            [, trended AS (SELECT *, <trend> AS trend_value FROM ranked|base)]
            SELECT * FROM <last cte> [ORDER BY metric_value ...] [LIMIT n]

        Recognised intent keys: ``metric`` (required), ``breakdown``,
        ``time_grain``, ``window``, ``trend``, ``ranking``, ``joins``,
        ``filters`` and ``time_range``.

        Side effects: sets ``intent["__exclude_grain__"]`` for time breakdowns
        and for "no breakdown"; ``apply_time_range`` may update the intent too.

        Raises:
            KeyError: unknown metric, window/trend pattern or named join.
            ValueError: unsupported grain, invalid filter, trend pattern without
                SQL, or unknown ranking type.
        """
        metric = intent["metric"]
        window = intent.get("window")
        trend = intent.get("trend")
        joins = intent.get("joins") or []
        filters = intent.get("filters") or []

        base_table, metric_sql = self._resolve_metric_source(metric, intent)
        dims = self._dimension_columns(intent)

        where_clauses = self._filter_clauses(filters, intent.get("time_grain"))
        if intent.get("time_range"):
            where_clauses = apply_time_range(intent, where_clauses)

        sql = self._base_cte(base_table, metric_sql, dims, self._join_clauses(joins), where_clauses)
        final_cte = "base"

        time_dims = [c for c in dims if c in self.TIME_COLUMNS]
        entity_dims = [c for c in dims if c not in self.TIME_COLUMNS]

        if window:
            # "Rank within time": partition by the time columns only, so entities
            # are ranked against each other inside each period.
            window_sql = self._render_over(self.registry["window_patterns"][window]["sql"], time_dims)
            sql += f"""
, ranked AS (
    SELECT *,
           {window_sql} AS rank
    FROM {final_cte}
)
"""
            final_cte = "ranked"

        if trend:
            pattern = self.registry["trend_patterns"][trend]
            if "sql" not in pattern:
                raise ValueError(f"Trend pattern '{trend}' has no SQL template")
            # Trends run along time within each entity (if the result has one).
            trend_sql = self._render_over(pattern["sql"], entity_dims)
            sql += f"""
, trended AS (
    SELECT *,
           {trend_sql} AS trend_value
    FROM {final_cte}
)
"""
            final_cte = "trended"

        sql += f"SELECT * FROM {final_cte}\n"
        sql += self._ranking_clause(intent.get("ranking"))
        return sql

    def _resolve_metric_source(self, metric, intent):
        """Return ``(table, metric_sql)`` for ``metric`` at the intent's effective grain."""
        metric_cfg = self.registry["metrics"][metric]

        # A time breakdown wins over time_grain; default to monthly data.
        breakdown = intent.get("breakdown")
        effective_grain = breakdown if breakdown in self.TIME_COLUMNS else intent.get("time_grain")
        if effective_grain is None:
            effective_grain = "month"

        if "grains" in metric_cfg:
            grains_cfg = metric_cfg["grains"]
            if effective_grain not in grains_cfg:
                raise ValueError(
                    f"Metric '{metric}' does not support grain '{effective_grain}'"
                )
            grain_cfg = grains_cfg[effective_grain]
            return grain_cfg["table"], grain_cfg["sql"]

        # Legacy flat metric (backward compatible): one table for every grain.
        return metric_cfg["table"], metric_cfg["sql"]

    def _dimension_columns(self, intent):
        """Return the deduplicated SELECT/GROUP BY columns implied by breakdown or time grain."""
        breakdown = intent.get("breakdown") or "none"
        cols = self.BREAKDOWN_COLUMNS.get(breakdown)
        if cols is None:
            # No (recognised) breakdown: pure aggregate, optionally split by the time grain.
            intent["__exclude_grain__"] = True
            cols = self.TIME_GRAIN_COLUMNS.get(intent.get("time_grain"), [])
        elif breakdown in self.TIME_COLUMNS:
            intent["__exclude_grain__"] = True
        return list(dict.fromkeys(cols))

    def _join_clauses(self, joins):
        """Render JOIN ... USING clauses from registry join names or inline join dicts."""
        clauses = []
        for join in joins:
            if isinstance(join, str):
                cfg = self.registry["joins"][join]
                right_table, on_cols = cfg["right_table"], cfg["on"]
            elif isinstance(join, dict):
                right_table = join.get("right") or join.get("right_table")
                on_cols = join.get("on", [])
            else:
                continue
            if isinstance(on_cols, str):
                on_cols = [on_cols]  # a single column given as a bare string
            if right_table and on_cols:
                clauses.append(f"JOIN {right_table} USING ({', '.join(on_cols)})")
        return clauses

    def _filter_clauses(self, filters, time_grain):
        """Translate intent filters into SQL predicates (without the WHERE keyword)."""
        clauses = []
        for f in filters:
            field = f.get("field")
            value = f.get("value")
            op = self._normalize_operator(f.get("operator"))

            if field == "month" and isinstance(value, str):
                value = self._month_number(value)

            if field == "date":
                clauses.extend(self._date_clauses(value, time_grain))
            elif field == "year":
                if isinstance(value, str) and value.strip().isdigit():
                    value = int(value)
                clauses.append(f"year {op} {self._sql_literal(value)}")
            else:
                self._check_identifier(field)
                clauses.append(f"{field} {op} {self._sql_literal(value)}")
        return clauses

    def _normalize_operator(self, op):
        """Map aliases like 'eq' to SQL and reject anything outside ALLOWED_OPERATORS."""
        op = str(op or "=").strip()
        op = self.OPERATOR_ALIASES.get(op.lower(), op)
        if op.upper() not in self.ALLOWED_OPERATORS:
            raise ValueError(f"Unsupported filter operator: {op!r}")
        return op.upper()

    def _month_number(self, value):
        """Convert a month name, 3-letter abbreviation or '1'..'12' string to an int."""
        key = value.strip().lower()
        if key.isdigit() and 1 <= int(key) <= 12:
            return int(key)
        for name, number in self.MONTH_MAP.items():
            if key in (name, name[:3]):
                return number
        raise ValueError(f"Invalid month value: {value}")

    @staticmethod
    def _date_clauses(value, time_grain):
        """Split a 'YYYY-MM-DD' (or 'YYYY-MM' / 'YYYY') date into per-column predicates.

        Only the columns of ``time_grain`` are filtered; with no time grain the
        filter is dropped, as before.
        """
        wanted = {
            "day": ("day", "month", "year"),
            "month": ("month", "year"),
            "year": ("year",),
        }.get(time_grain, ())
        if not wanted:
            # NOTE(review): a date filter without a time grain is silently ignored.
            return []
        parts = [int(p) for p in str(value).split("-")]
        position = {"year": 0, "month": 1, "day": 2}
        clauses = []
        for col in wanted:
            idx = position[col]
            if idx >= len(parts):
                raise ValueError(f"Date filter {value!r} is too coarse for '{time_grain}' grain")
            clauses.append(f"{col} = {parts[idx]}")
        return clauses

    @staticmethod
    def _check_identifier(name):
        """Reject field names that are not plain (optionally dotted) SQL identifiers."""
        import re

        if not isinstance(name, str) or not re.fullmatch(
            r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*", name
        ):
            raise ValueError(f"Invalid filter field: {name!r}")

    @staticmethod
    def _sql_literal(value):
        """Render a filter value: strings are quoted with backslash escaping, others via str()."""
        if isinstance(value, str):
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            return f"'{escaped}'"
        return str(value)

    @staticmethod
    def _base_cte(base_table, metric_sql, dims, join_sql, where_clauses):
        """Render the ``base`` CTE; SELECT list and GROUP BY are built only from non-empty parts."""
        select_list = ", ".join(dims + [f"{metric_sql} AS metric_value"])
        lines = [
            "",
            "WITH base AS (",
            "    SELECT",
            f"        {select_list}",
            f"    FROM {base_table}",
        ]
        if join_sql:
            lines.append("    " + " ".join(join_sql))
        if where_clauses:
            # Parenthesise each predicate so an OR inside one cannot escape the AND chain.
            lines.append("    WHERE " + " AND ".join(f"({c})" for c in where_clauses))
        if dims:
            lines.append("    GROUP BY " + ", ".join(dims))
        lines.append(")")
        return "\n".join(lines) + "\n"

    def _render_over(self, template, partition_cols):
        """Fill a window/trend template over ``metric_value``.

        ``{time}`` and ``{entity}`` both receive ``partition_cols``. When that list
        is empty, the ``PARTITION BY`` clause is removed instead of emitting
        invalid SQL.
        """
        partition = ", ".join(partition_cols) or self._NO_PARTITION
        rendered = template.format(time=partition, entity=partition, metric="metric_value")
        return (
            rendered
            .replace(f"PARTITION BY {self._NO_PARTITION} ", "")
            .replace(f"PARTITION BY {self._NO_PARTITION}", "")
        )

    def _ranking_clause(self, ranking):
        """Render ORDER BY / LIMIT for a ranking given as a type string or a {type, limit} dict."""
        if isinstance(ranking, dict):
            ranking_type = ranking.get("type")
            ranking_limit = ranking.get("limit")
        elif isinstance(ranking, str):
            ranking_type, ranking_limit = ranking, None
        else:
            ranking_type = ranking_limit = None

        clause = ""
        if ranking_type:
            pattern = self.registry["ranking_patterns"].get(ranking_type)
            if not pattern:
                raise ValueError(f"Unknown ranking type: {ranking_type}")
            clause += f"ORDER BY metric_value {pattern['order']}\n"
            if ranking_limit is None:
                default_limit = pattern.get("limit")
                if isinstance(default_limit, int):
                    ranking_limit = default_limit
                elif ranking_type in ("highest", "lowest"):
                    ranking_limit = 1

        if ranking_limit:
            clause += f"LIMIT {int(ranking_limit)}\n"
        return clause


from datetime import datetime

def apply_time_range(intent, filter_clauses, now=None):
    """
    Append WHERE predicates for the intent's business time range, if it has one.

    Args:
        intent: Intent dict. ``intent["time_range"]`` is either falsy (no-op) or a
            dict with ``type`` in {"quarter", "half", "range", "ytd", "ltm"}; the
            first three also need ``start_month`` and ``end_month``.
        filter_clauses: List of SQL predicate strings. It is appended to IN PLACE
            and also returned.
        now: Reference time for "ytd" / "ltm" (defaults to ``datetime.now()``).
            Passing it makes the rolling windows reproducible.

    Returns:
        ``filter_clauses`` (the same list object).

    Side effects:
        When a time range is present, sets ``intent["__exclude_grain__"] = True``
        and ``intent["time_grain"] = None``.

    Raises:
        ValueError: if a quarter / half / range lacks its start or end month.
    """
    tr = intent.get("time_range")
    if not tr:
        return filter_clauses

    now = now or datetime.now()
    current_month = now.month
    current_year = now.year

    ttype = tr.get("type")

    if ttype in ("quarter", "half", "range"):
        start, end = tr.get("start_month"), tr.get("end_month")
        if start is None or end is None:
            raise ValueError(
                f"time_range of type '{ttype}' needs start_month and end_month: {tr}"
            )
        filter_clauses.append(f"month BETWEEN {int(start)} AND {int(end)}")

    elif ttype == "ytd":
        filter_clauses.append(f"month BETWEEN 1 AND {current_month}")
        filter_clauses.append(f"year = {current_year}")

    elif ttype == "ltm":
        # Rolling 12 months. The whole disjunction is parenthesised because callers
        # AND it with other predicates, and AND binds tighter than OR.
        filter_clauses.append(
            f"((year = {current_year} AND month <= {current_month}) OR "
            f"(year = {current_year - 1} AND month > {current_month}))"
        )
    # NOTE(review): an unrecognised type adds no predicate, so the query spans all data.

    # IMPORTANT: drop grouping
    intent["__exclude_grain__"] = True
    intent["time_grain"] = None

    return filter_clauses

# -------------------------
# 3️⃣ LLM client helper (works in both the notebook and model serving)
# -------------------------
def get_llm():
    """Return an OpenAI-compatible client for the workspace's model serving endpoints.

    Reads ``DATABRICKS_HOST`` and ``DATABRICKS_TOKEN`` from the environment
    (``KeyError`` if either is missing). A trailing slash on the host is tolerated,
    so the base URL never contains ``//serving-endpoints``.
    """
    host = os.environ["DATABRICKS_HOST"].rstrip("/")
    return OpenAI(
        base_url=f"{host}/serving-endpoints",
        api_key=os.environ["DATABRICKS_TOKEN"],
    )

def extract_intent_json(raw):
    """
    Extract valid intent JSON from raw LLM output.

    Handles:
    - list outputs from Databricks LLMs (content blocks); ``reasoning`` blocks are
      skipped so drafts inside the chain-of-thought are not picked up
    - extra text before/after the JSON
    - a JSON array of intents or a single JSON object
    - nested objects/arrays inside the intent (``ranking``, ``filters``, ...)

    Returns:
        The last top-level JSON object that has an ``"entity"`` key, or the last
        non-empty top-level JSON array whose items are all such objects.

    Raises:
        ValueError: if no JSON, or no intent-shaped JSON, is found.
    """
    import json

    if isinstance(raw, dict) and "entity" in raw:
        return raw

    # --- Step 1: Convert list (content-block) outputs to a string ---
    if isinstance(raw, list):
        texts = []
        for item in raw:
            if isinstance(item, dict):
                if item.get("type") == "reasoning":
                    continue
                texts.append(str(item["text"]) if "text" in item else str(item))
            else:
                texts.append(str(item))
        raw = "".join(texts)

    if not isinstance(raw, str):
        raise ValueError(f"No JSON found in LLM output:\n{raw}")

    # --- Step 2: Decode every top-level JSON value embedded in the text ---
    # A regex cannot match balanced braces, so use the real JSON decoder at each
    # position where an object/array could start.
    decoder = json.JSONDecoder()
    candidates = []
    pos = 0
    while True:
        starts = [i for i in (raw.find("{", pos), raw.find("[", pos)) if i != -1]
        if not starts:
            break
        start = min(starts)
        try:
            value, consumed = decoder.raw_decode(raw[start:])
        except json.JSONDecodeError:
            pos = start + 1
            continue
        candidates.append(value)
        pos = start + consumed  # skip this value's nested content

    if not candidates:
        raise ValueError(f"No JSON found in LLM output:\n{raw}")

    # --- Step 3: Pick the last candidate shaped like an intent ---
    for candidate in reversed(candidates):
        if isinstance(candidate, dict) and "entity" in candidate:
            return candidate
        # Empty lists are rejected: all() over [] would otherwise be True.
        if (
            isinstance(candidate, list)
            and candidate
            and all(isinstance(c, dict) and "entity" in c for c in candidate)
        ):
            return candidate

    raise ValueError(f"No valid intent JSON found in LLM output:\n{raw}")





# -------------------------
# 3️⃣ Resolve intent safely (single intent)
# -------------------------
def resolve_intent(llm, question: str) -> dict:
    """
    Turn a natural-language question into a single intent dict for SQLAgent.

    The LLM is prompted to emit intent JSON. The first intent it returns is merged
    over ``DEFAULT_INTENT`` and then adjusted by deterministic rules:
    - best/worst-style business questions get a monthly time grain and a default
      ranking;
    - "<month name> <day>" phrases add month/day filters at day grain.

    Args:
        llm: OpenAI-compatible client exposing ``chat.completions.create``.
        question: The user's question.

    Returns:
        The intent dict; ``filters`` and ``joins`` are always lists.

    Raises:
        ValueError: if the LLM output cannot be parsed into an intent object.
    """
    import json
    import re

    DEFAULT_INTENT = {
    "entity": "pizza",
    "metric": "total_revenue",
    "time_grain": None,
    "window": None,
    "trend": None,
    "ranking": None,
    "joins": [],
    "filters": [],
    "time_range": None
    }

    prompt = f"""
You are an analytics intent parser.

You MUST return ONLY valid JSON.
No markdown.
No explanations.
No extra text.

====================
OUTPUT SCHEMA
====================
{{
  "entity": "pizza | ingredient",
  "metric": "total_revenue | total_pizzas_sold | ingredient_usage_grams",
  "time_grain": "day | month | year | null",
  "window": "rank_within_time | null",
  "trend": "month_over_month_growth_pct | null",
  "breakdown": "none | month | year | pizza | ingredient",
  "ranking": null | {{
    "type": "highest | lowest | top_n | bottom_n",
    "limit": number
  }},
  "joins": [],
  "filters": [],
  "time_range": null | {{
    "type": "quarter | half | range | ytd | ltm",
    "start_month": number | null,
    "end_month": number | null
  }}
}}

====================
DECISION ORDER (MANDATORY)
====================
1. Identify entity and metric
2. Detect business time periods (time_range)
3. IF time_range exists:
   - DO NOT create month filters
   - Set time_grain = null
4. ELSE:
   - Extract explicit month / year / date filters
5. Then detect ranking, trend, window

====================
TIME RANGE DETECTION (HIGHEST PRIORITY)
====================
Convert business periods to months:

- Q1 / first quarter → 1–3
- Q2 → 4–6
- Q3 → 7–9
- Q4 → 10–12
- first half / H1 / first 6 months → 1–6
- second half / H2 → 7–12
- YTD / year to date → 1 to current month
- LTM / last 12 months / trailing 12 months → rolling 12 months

Rules:
- If any time_range is detected:
  - Populate time_range
  - filters MUST be empty
  - time_grain MUST be null

====================
FILTER RULES (ONLY IF NO TIME_RANGE)
====================
- Month names → month filter (string)
- Year numbers → year filter (number)
- Exact dates → field = "date" (YYYY-MM-DD)
- filters MUST always be a list

Filter format:
{{
  "field": "month | year | date | pizza_name_id | ingredient",
  "operator": "=",
  "value": string | number
}}
====================
BREAKDOWN RULES (MANDATORY)
====================
You MUST decide the breakdown dimension based on the user's question.

Examples:
- "monthly data", "per month", "month-wise", "trend across months"
  → breakdown = "month"

- "by pizza", "top pizzas", "which pizza"
  → breakdown = "pizza"

- "ingredient usage by ingredient"
  → breakdown = "ingredient"

- If user asks only for a total (no breakdown wording)
  → breakdown = "none"

If breakdown is not "none":
- DO NOT include entity grain unless breakdown explicitly requests it
====================
RANKING RULES
====================
- most / highest / best → {{ "type": "highest" }}
- least / lowest → {{ "type": "lowest" }}
- top N → {{ "type": "top_n", "limit": N }}
- bottom N → {{ "type": "bottom_n", "limit": N }}

====================
EXAMPLES
====================
Question: how many pizzas sold in first 6 months
Output:
{{
  "entity": "pizza",
  "metric": "total_pizzas_sold",
  "time_grain": null,
  "window": null,
  "trend": null,
  "ranking": null,
  "joins": [],
  "filters": [],
  "time_range": {{
    "type": "half",
    "start_month": 1,
    "end_month": 6
  }}
}}

Question: top 3 pizzas by revenue in December 2015
Output:
{{
  "entity": "pizza",
  "metric": "total_revenue",
  "time_grain": "month",
  "window": null,
  "trend": null,
  "ranking": {{ "type": "top_n", "limit": 3 }},
  "joins": [],
  "filters": [
    {{ "field": "month", "operator": "=", "value": "december" }},
    {{ "field": "year", "operator": "=", "value": 2015 }}
  ],
  "time_range": null
}}

====================
QUESTION
====================
{question}
"""

    response = llm.chat.completions.create(
        model="databricks-gpt-oss-120b",
        messages=[
            {"role": "system", "content": "You output ONLY valid JSON."},
            {"role": "user", "content": prompt}
        ],
        temperature=0
    )

    raw = response.choices[0].message.content

    # --- Parse LLM output safely ---
    if isinstance(raw, dict):
        intents = [raw]
    elif isinstance(raw, (str, list)):
        if isinstance(raw, list):
            # Databricks OSS content-block format; reasoning blocks have no "text".
            raw = "".join(x["text"] for x in raw if isinstance(x, dict) and "text" in x)
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            # JSON wrapped in extra text or fences: use the tolerant extractor.
            parsed = extract_intent_json(raw)
        # Ensure it's a list of intents
        intents = parsed if isinstance(parsed, list) else [parsed]
    else:
        raise ValueError(f"Unsupported LLM output type: {type(raw)}")

    if not intents or not isinstance(intents[0], dict):
        raise ValueError(f"LLM did not return an intent object: {intents!r}")

    # --- Take only the first intent to avoid list issues ---
    intent = DEFAULT_INTENT.copy()
    intent.update(intents[0])

    # The LLM may emit null (or a single object) where a list is expected.
    for key in ("filters", "joins"):
        value = intent.get(key)
        if isinstance(value, dict):
            intent[key] = [value]
        elif not isinstance(value, list):
            intent[key] = []

    q = question.lower()

    BUSINESS_TIME_PHRASES = [
        "worst month",
        "best month",
        "worst year",
        "best year",
        "business performance",
        "overall performance"
    ]

    if any(p in q for p in BUSINESS_TIME_PHRASES):
        # Force business-level aggregation
        intent["__exclude_grain__"] = True

        # These questions are always time-based
        if intent.get("time_grain") is None:
            intent["time_grain"] = "month"

        # Ranking default for worst/best
        if "worst" in q and not intent.get("ranking"):
            intent["ranking"] = {"type": "lowest"}
        elif "best" in q and not intent.get("ranking"):
            intent["ranking"] = {"type": "highest"}

    # "<month name> <day>" (e.g. "march 5") means a specific day.
    date_match = re.search(r'\b(january|february|march|april|may|june|july|august|september|october|november|december)\s+\d{1,2}\b', q)

    if date_match:
        intent["time_grain"] = "day"
        intent["breakdown"] = "day"
        month_name, day = date_match.group().split()
        month_num = datetime.strptime(month_name, "%B").month

        intent["filters"].append({
            "field": "month",
            "operator": "=",
            "value": month_num
        })

        intent["filters"].append({
            "field": "day",
            "operator": "=",
            "value": int(day)
        })
    return intent


# -------------------------
# 4️⃣ Result interpreter
# -------------------------


def interpret_result(llm, question: str, df: pd.DataFrame, max_rows: int = 50) -> str:
    """
    Ask the LLM to answer ``question`` in plain English from the query result ``df``.

    ALWAYS returns a plain string. Databricks OSS models may return message content
    as a string, a dict, or a list of content blocks. In a list, ``reasoning``
    blocks are dropped so chain-of-thought does not leak into the answer; they are
    used only if nothing else came back.

    Args:
        llm: OpenAI-compatible client exposing ``chat.completions.create``.
        question: The user's original question.
        df: Query result to summarise.
        max_rows: Maximum number of rows sent to the LLM (default 50). When the
            result is larger, the prompt says so, so the model does not present
            sums over the visible rows as totals.

    Returns:
        The answer text, stripped of surrounding whitespace.
    """
    data_preview = df.head(max_rows).to_dict(orient="records")

    total_rows = len(df)
    truncation_note = ""
    if total_rows > max_rows:
        truncation_note = (
            f"\nNOTE: only the first {max_rows} of {total_rows} rows are shown. "
            "Do not present sums or counts over these rows as totals for the full result.\n"
        )

    prompt = f"""
You are a data analyst.

Rules:
- Answer the question using ONLY the data provided.
- Do NOT invent values.
- Do NOT show reasoning or JSON.
- Output ONLY a single plain English paragraph.

Question:
{question}
{truncation_note}
Data:
{data_preview}
"""

    response = llm.chat.completions.create(
        model="databricks-gpt-oss-120b",
        messages=[
            {"role": "system", "content": "You output ONLY plain text."},
            {"role": "user", "content": prompt}
        ],
        temperature=0
    )

    # Extract text from all possible OSS formats
    raw = response.choices[0].message.content

    # List of content blocks (Databricks OSS default)
    if isinstance(raw, list):
        answer_parts = []
        reasoning_parts = []
        for item in raw:
            if isinstance(item, dict) and item.get("type") == "reasoning":
                reasoning_parts.extend(
                    s.get("text", "") for s in (item.get("summary") or []) if isinstance(s, dict)
                )
            elif isinstance(item, dict) and "text" in item:
                answer_parts.append(item["text"])
            else:
                answer_parts.append(str(item))
        # Only reasoning came back: fall back to its summary rather than nothing.
        raw = " ".join(answer_parts or reasoning_parts)

    # If it's a dict, try common fields
    elif isinstance(raw, dict):
        if "text" in raw:
            raw = raw["text"]
        elif "summary" in raw and isinstance(raw["summary"], list):
            raw = " ".join([s.get("text", "") for s in raw["summary"]])
        else:
            raw = str(raw)

    # Force string type as a last resort
    if not isinstance(raw, str):
        raw = str(raw)

    return raw.strip()

def run_sql(query: str) -> pd.DataFrame:
    """Run ``query`` on the Databricks SQL warehouse and return the result as a DataFrame.

    The connection target can be overridden through the ``DATABRICKS_SERVER_HOSTNAME``
    and ``DATABRICKS_HTTP_PATH`` environment variables. The previously hard-coded
    workspace and warehouse remain the defaults. The token is always read from
    ``DATABRICKS_TOKEN`` (``KeyError`` if unset).
    """
    with sql.connect(
        server_hostname=os.environ.get(
            "DATABRICKS_SERVER_HOSTNAME", "dbc-741c7540-6155.cloud.databricks.com"
        ),
        http_path=os.environ.get(
            "DATABRICKS_HTTP_PATH", "/sql/1.0/warehouses/e2eb5bcf69add002"
        ),
        access_token=os.environ["DATABRICKS_TOKEN"],
    ) as conn:
        # NOTE(review): pandas may warn that non-SQLAlchemy DB-API connections are
        # untested; the query still runs through the connection's cursor.
        return pd.read_sql(query, conn)
# -------------------------
# 5️⃣ SQL execution (auto-repair not implemented yet)
# -------------------------
def execute_sql(sql_query: str) -> pd.DataFrame:
    """
    Execute SQL using the Databricks SQL connector and return a pandas DataFrame.

    The query is printed first so it appears in notebook / serving logs.
    Auto-repair logic can be added here later.

    Raises:
        RuntimeError: wrapping any error raised while running the query. The
            message includes the SQL, and the original exception is chained as
            ``__cause__``.
    """
    try:
        print(sql_query)
        return run_sql(sql_query)
    except Exception as e:
        raise RuntimeError(f"SQL execution failed:\n{sql_query}\n\n{e}") from e



# -------------------------
# 6️⃣ PyFunc model
# -------------------------
import mlflow.pyfunc

class PizzaRAGPyFuncModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper: question -> intent -> SQL -> plain-English answer.

    ``predict`` returns a one-row DataFrame of simple scalar columns
    (answer, intent, sql, num_rows, num_documents).
    """

    def load_context(self, context):
        """Build the SQL agent and the LLM client when the model is loaded."""
        self.registry = REGISTRY
        self.agent = SQLAgent(self.registry)
        self.llm = get_llm()

    @staticmethod
    def _extract_query(model_input):
        """Pull the question text out of the supported pyfunc input shapes.

        Accepts a DataFrame with a ``query`` column (first row is used), a dict
        with a ``query`` value (optionally wrapped in a list, as serving payloads
        often are), or a bare string.

        Raises:
            ValueError: if the input type is unsupported or no query is present.
        """
        if isinstance(model_input, pd.DataFrame):
            if model_input.empty or "query" not in model_input.columns:
                raise ValueError("model_input DataFrame must have a non-empty 'query' column")
            return str(model_input.iloc[0]["query"])
        if isinstance(model_input, dict):
            query = model_input.get("query")
            if isinstance(query, (list, tuple)):
                query = query[0] if query else None
            if query is None:
                raise ValueError("model_input dict must contain a 'query' value")
            return str(query)
        if isinstance(model_input, str):
            return model_input
        raise ValueError(f"Unsupported model_input type: {type(model_input)}")

    def predict(self, context, model_input):
        """Answer the question in ``model_input``; see ``_extract_query`` for input shapes."""
        import json

        query_text = self._extract_query(model_input)

        # -------------------------
        # Resolve intent
        # -------------------------
        intent = resolve_intent(self.llm, query_text)
        metric_cfg = self.registry["metrics"].get(intent.get("metric"))
        if metric_cfg is None:
            raise ValueError(f"Resolved intent has an unknown metric: {intent.get('metric')!r}")
        # The registry, not the LLM, decides which entity a metric belongs to.
        intent["entity"] = metric_cfg.get("entity", intent.get("entity"))

        q = query_text.lower()
        if "which month" in q or "by month" in q:
            intent["time_grain"] = "month"
            intent["entity"] = "pizza"
            intent["__exclude_grain__"] = True

        # -------------------------
        # Generate & execute SQL
        # -------------------------
        sql_query = self.agent.generate_sql(intent)
        df = execute_sql(sql_query)

        # -------------------------
        # Interpret result; force a plain-string answer
        # -------------------------
        raw_answer = interpret_result(self.llm, query_text, df)
        if isinstance(raw_answer, dict):
            # Prefer readable text if present
            if "summary" in raw_answer:
                answer_text = " ".join(
                    s.get("text", "") for s in raw_answer.get("summary", [])
                )
            else:
                answer_text = json.dumps(raw_answer)
        else:
            answer_text = str(raw_answer)

        # Hard cap on answer length to keep the output payload bounded.
        answer_text = answer_text[:4000]

        # -------------------------
        # Return dataframe (simple types ONLY, matching the logged signature)
        # -------------------------
        return pd.DataFrame([{
            "answer": answer_text,
            "intent": json.dumps(intent),
            "sql": sql_query,
            "num_rows": int(len(df)),
            "num_documents": int(len(intent.get("joins", [])))
        }])


# -------------------------
# 7️⃣ MLflow model signature & log
# -------------------------
# Input: one string column "query". Output: the scalar columns of the one-row
# DataFrame returned by PizzaRAGPyFuncModel.predict.
signature = ModelSignature(
    inputs=Schema([ColSpec("string", "query")]),
    outputs=Schema([
        ColSpec(col_type, col_name)
        for col_type, col_name in (
            ("string", "answer"),
            ("string", "intent"),
            ("string", "sql"),
            ("long", "num_rows"),
            ("long", "num_documents"),
        )
    ]),
)

# NOTE(review): no pip_requirements are given, so MLflow infers dependencies;
# confirm openai and databricks-sql-connector end up in the serving environment.
with mlflow.start_run():
    mlflow.pyfunc.log_model(
        registered_model_name="retail_analytics.pizza.analytics_pizza_rag_models",
        name="analytics_pizza_rag_model",
        python_model=PizzaRAGPyFuncModel(),
        input_example={"query": "in which month pizza sale was high in first quarter"},
        signature=signature,
    )


In [0]:
from databricks import sql

In [0]:
import json

# Sample model output (stringified dict inside 'answer')
sample_prediction = {
    "answer": "{'type': 'reasoning', 'summary': [{'type': 'summary_text', 'text': 'We need to answer: \"hoe many pizza sold in first 6 months\". Data is a list of pizza entries with metric_value. Likely each entry is a month? But no month info. Possibly first 6 entries correspond to first 6 months. So sum metric_value of first six items. Let\\'s compute: first six entries:\\n\\n1. ckn_alfredo_l: 95\\n2. spinach_fet_s: 245\\n3. soppressata_l: 222\\n4. ckn_pesto_s: 148\\n5. ital_supr_l: 368\\n6. green_garden_l: 53\\n\\nSum: 95+245=340; +222=562; +148=710; +368=1078; +53=1131.\\n\\nThus answer: 1,131 pizzas sold in first six months.'}]} A total of 1,131 pizzas were sold in the first six months.\""
}

def extract_final_answer(answer_raw: str) -> str:
    """Return the user-facing answer from a model ``answer`` string.

    If a stringified reasoning block (ending in '}') precedes the answer, the text
    after the last '}' is returned, with trailing double quotes (serialization
    artifacts) removed. Otherwise the last sentence is returned. Non-string input
    is returned as ``str(answer_raw)``.
    """
    import re

    if not isinstance(answer_raw, str):
        return str(answer_raw)

    text = answer_raw.strip()

    # If reasoning block exists, take everything AFTER last closing brace
    if "}" in text:
        after_reasoning = text.rsplit("}", 1)[-1].strip().rstrip('"').strip()
        if after_reasoning:
            return after_reasoning

    # Fallback: return last sentence
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return sentences[-1].strip()


# Test: pass the answer *string*; passing the whole dict would just return str(dict).
clean_answer = extract_final_answer(sample_prediction["answer"])
print("Extracted answer:\n")
print(clean_answer)


In [0]:
import json
import re
sample_prediction = {
  "predictions": [
    {
      "answer": "{'type': 'reasoning', 'summary': [{'type': 'summary_text', 'text': \"We need total revenue in May. Sum all metric_value entries. Let's sum.\\n\\nI'll add sequentially.\\n\\nList values:\\n\\n318.75\\n708.0\\n182.25\\n336.0\\n731.25\\n567.0\\n311.25\\n1514.75\\n2427.75\\n828.75\\n788.5\\n448.0\\n477.0\\n408.0\\n287.5\\n684.75\\n800.0\\n1203.5\\n1680.0\\n487.5\\n225.0\\n324.0\\n544.5\\n577.5\\n528.0\\n107.85000000000001\\n1113.25\\n159.25\\n528.0\\n156.0\\n720.0\\n1402.5\\n368.0\\n1377.0\\n452.25\\n1846.75\\n594.0\\n1306.5\\n518.75\\n630.0\\n1053.0\\n324.0\\n1086.5\\n462.0\\n951.75\\n1275.75\\n429.0\\n140.25\\n242.25\\n770.5\\n\\nNow sum. I'll do cumulative.\\n\\nStart 0\\n+318.75 = 318.75\\n+708.0 = 1026.75\\n+182.25 = 1209.0\\n+336.0 = 1545.0\\n+731.25 = 2276.25\\n+567.0 = 2843.25\\n+311.25 = 3154.5\\n+1514.75 = 4669.25\\n+2427.75 = 7097.0\\n+828.75 = 7925.75\\n+788.5 = 8714.25\\n+448.0 = 9162.25\\n+477.0 = 9639.25\\n+408.0 = 10047.25\\n+287.5 = 10334.75\\n+684.75 = 11019.5\\n+800.0 = 11819.5\\n+1203.5 = 13023.0\\n+1680.0 = 14703.0\\n+487.5 = 15190.5\\n+225.0 = 15415.5\\n+324.0 = 15739.5\\n+544.5 = 16284.0\\n+577.5 = 16861.5\\n+528.0 = 17389.5\\n+107.85000000000001 = 17497.350000000002 (approx 17497.35)\\n+1113.25 = 18610.600000000002\\n+159.25 = 18769.850000000002\\n+528.0 = 193... wait compute: 18769.85+528 = 19297.85? Actually 18769.85+528 = 19297.85. Yes.\\n+156.0 = 19453.85\\n+720.0 = 20173.85\\n+1402.5 = 21576.35\\n+368.0 = 21944.35\\n+1377.0 = 23321.35\\n+452.25 = 23773.6\\n+1846.75 = 25620.35\\n+594.0 = 26214.35\\n+1306.5 = 27520.85\\n+518.75 = 28039.6\\n+630.0 = 28669.6\\n+1053.0 = 29722.6\\n+324.0 = 30046.6\\n+1086.5 = 31133.1\\n+462.0 = 31595.1\\n+951.75 = 32546.85\\n+1275.75 = 33822.6\\n+429.0 = 34251.6\\n+140.25 = 34391.85\\n+242.25 = 34634.1\\n+770.5 = 35404.6\\n\\nTotal revenue May = 35,404.6 (assuming units). Provide answer.\"}]} The total revenue for May is‚ÄØ35,404.6.",
      "intent": "{\"entity\": \"pizza\", \"metric\": \"total_revenue\", \"time_grain\": \"month\", \"window\": null, \"trend\": null, \"ranking\": null, \"joins\": [], \"filters\": [{\"field\": \"month\", \"operator\": \"=\", \"value\": \"may\"}], \"time_range\": null}",
      "sql": "\nWITH base AS (\n    SELECT\n        pizza_name_id, month, year,\n        SUM(total_revenue) AS metric_value\n    FROM retail_analytics.pizza.gold_pizza_metrics_monthly\n    \n    WHERE month = 5\n    GROUP BY pizza_name_id, month, year\n)\nSELECT * FROM base\n",
      "num_rows": 91,
      "num_documents": 0
    }
  ]
}
def get_dynamic_response(data):
    """Build a one-sentence summary from a model-serving response.

    Args:
        data: Payload shaped like ``{"predictions": [{"answer": str, "intent": json-str, ...}]}``.

    Returns:
        For ranking intents: "The most|least used <entity>[ for <filters>] is <value>".
        Otherwise: "<Metric>[ for <filters>] is <value>".
        Returns "Error parsing: ..." if the payload cannot be interpreted; this is
        deliberately best-effort because the helper is for display only.
    """
    try:
        # 1. Access main components
        prediction = data['predictions'][0]
        full_text = prediction['answer']
        # Ignore any leaked reasoning block: the final answer follows the last '}'.
        answer_text = full_text.rsplit('}', 1)[-1].strip().rstrip('"').strip() or full_text

        # 2. Parse the intent metadata
        intent = json.loads(prediction.get('intent', '{}'))
        metric = intent.get('metric', 'value').replace('_', ' ')
        entity = intent.get('entity', '')
        ranking = intent.get('ranking')

        # Extract filters (e.g., "may")
        filters = intent.get('filters') or []
        filter_context = " for " + " ".join(str(f['value']) for f in filters) if filters else ""

        # 3. Dynamic value extraction
        bold = re.search(r'\*\*(.*?)\*\*', answer_text)
        if ranking:
            # Strategy A: ranking answers end with "... is <Name>."
            match = re.search(r"is ([A-Z][a-z]+|[\w\s]+)(?:\.|\")?$", answer_text.strip())
            if match:
                value = match.group(1).strip()
            else:
                value = answer_text.split("is")[-1].strip().rstrip('.')
        elif bold:
            # Strategy B: bolded values (revenue/counts)
            value = bold.group(1)
        else:
            # Strategy C: last number in the answer (must start with a digit, so a
            # lone comma is never matched)
            numbers = re.findall(r'\d[\d,]*(?:\.\d+)?', answer_text)
            value = numbers[-1] if numbers else "not found"

        # 4. Construct sentence (filter_context already carries its leading " for ")
        if ranking:
            ranking_type = ranking.get('type') if isinstance(ranking, dict) else ranking
            qualifier = "least" if ranking_type in ("lowest", "bottom_n") else "most"
            return f"The {qualifier} used {entity}{filter_context} is {value}"
        return f"{metric.capitalize()}{filter_context} is {value}"

    except Exception as e:
        return f"Error parsing: {e}"

# --- TEST CASE ---
# Print the one-line summary built from the sample serving payload above.

print("Extracted:", get_dynamic_response(sample_prediction))

In [0]:
def extract_clean_answer(data):
    """Return the text that follows the reasoning block in a serving response.

    Reads ``data['predictions'][0]['answer']``. If it contains '}', the text after
    the last '}' is returned with surrounding whitespace and trailing double quotes
    removed. Otherwise the answer is returned unchanged. Malformed payloads
    (KeyError / IndexError / TypeError) yield "Invalid JSON structure".
    """
    try:
        answer_field = data['predictions'][0]['answer']
        _, brace, tail = answer_field.rpartition('}')
        if not brace:
            return answer_field  # no reasoning block to strip
        return tail.strip().rstrip('"').strip()
    except (KeyError, IndexError, TypeError):
        return "Invalid JSON structure"

print(extract_clean_answer(data=sample_prediction))
