# In [None]:
# Databricks RAG Notebook (Python Script Export)
# Use # COMMAND ---------- as cell boundaries in Databricks.

# COMMAND ----------
# Catalog and schema that prefix every table and index name below.
CATALOG = "development"
SCHEMA = "materials_planning_analytics"

# Fully qualified source table and the table this notebook writes.
RAW_TABLE = ".".join((CATALOG, SCHEMA, "materials_transformation_ap_v"))
RAG_TABLE = ".".join((CATALOG, SCHEMA, "apparel_materials_rag"))

# Vector search endpoint name and fully qualified index name.
VECTOR_ENDPOINT = "materials_vs_endpoint"
VECTOR_INDEX = ".".join((CATALOG, SCHEMA, "materials_rag_index"))

# Model name passed to the embedding SQL call, and the serving endpoint
# queried for the chat completion.
EMBED_MODEL = "databricks-bge-large-en"
GPT_ENDPOINT = "gpt-5.1"

from pyspark.sql import functions as F
from pyspark.sql.types import StringType, NumericType

# COMMAND ----------
# 1. Clean and Normalize All Columns
# Read the raw table and classify its columns once, from the source schema.
df_raw = spark.table(RAW_TABLE)

string_cols = [
    field.name
    for field in df_raw.schema.fields
    if isinstance(field.dataType, StringType)
]
numeric_cols = [
    field.name
    for field in df_raw.schema.fields
    if isinstance(field.dataType, NumericType)
]

# Build every cleaned column in one projection. Calling withColumn once per
# column adds a projection each time, which the PySpark docs warn can produce
# very large plans; a single select() yields the same columns, same order.
cleaned_cols = []
for field in df_raw.schema.fields:
    column = F.col(field.name)
    if isinstance(field.dataType, StringType):
        column = F.trim(column)            # strip surrounding whitespace
    elif isinstance(field.dataType, NumericType):
        column = column.cast("double")     # one numeric type for all measures
    cleaned_cols.append(column.alias(field.name))

df = df_raw.select(*cleaned_cols)

# Replace missing values: empty string for text, 0 for numerics.
df = df.fillna("", subset=string_cols)
df = df.fillna(0, subset=numeric_cols)

# overwriteSchema lets a re-run reset the table to the cleaned source schema
# even after later cells have added derived columns to it.
(
    df.write.mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(RAG_TABLE)
)

# COMMAND ----------
# 2. material_json + summary_text
df = spark.table(RAG_TABLE)

# Columns this notebook derives itself. They are left out of the JSON payload
# so that re-running this cell does not nest a previous material_json,
# summary_text or embedding inside the new material_json.
derived_cols = {"material_json", "summary_text", "embedding"}
source_cols = [c for c in df.columns if c not in derived_cols]

# One JSON document per row holding every source column.
material_json_col = F.to_json(F.struct(*[F.col(c) for c in source_cols]))
df = df.withColumn("material_json", material_json_col)

# Columns that make up the short, pipe-separated text summary.
summary_cols = [
    "PRODUCT_CD","STYLE_CD","BOM_SEASON_CODE","FISCAL_YEAR","FACTORY_CD","GEO_CD",
    "PPG","FOP","DIMENSION","MATERIAL_FAMILY_NM","MATERIAL_INTENT_DESCRIPTION",
    "MATERIAL_CONTENT","WEIGHT_GRAMS_PER_SQUARE_METER","MATL_GAUGE_INCH",
    "CONSDRTN_AND_RISK_NM","LATEST_PRICE_PER_UOM","LATEST_PRICE_UOM"
]

# Only summarize columns that are actually present in the table.
existing = [c for c in summary_cols if c in df.columns]
summary_expr = F.concat_ws(" | ", *[F.col(c).cast("string") for c in existing])
df = df.withColumn("summary_text", summary_expr)

# The DataFrame now has columns the stored table may not have; without
# overwriteSchema a Delta overwrite rejects the write as a schema mismatch.
(
    df.write.mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(RAG_TABLE)
)

# COMMAND ----------
# 3. Embeddings
# Make sure the table has an embedding column before it is rewritten below.
add_embedding_column_sql = f"""
ALTER TABLE {RAG_TABLE}
ADD COLUMN IF NOT EXISTS embedding ARRAY<DOUBLE>
"""
spark.sql(add_embedding_column_sql)

# Embed the JSON payload of every row and write the table back in place.
# NOTE(review): `ai_embeddings` is not a built-in Databricks AI function that
# I can confirm (`ai_query` is the documented way to call an embedding
# endpoint from SQL) - verify it is available in this workspace.
df = spark.table(RAG_TABLE)
embedding_expr = F.expr(f"ai_embeddings('{EMBED_MODEL}', material_json)")
df_emb = df.withColumn("embedding", embedding_expr)
df_emb.write.mode("overwrite").saveAsTable(RAG_TABLE)

# COMMAND ----------
# 4. Vector Endpoint + Index
# Create the endpoint and the index if they do not exist yet. Both statements
# take the endpoint name from VECTOR_ENDPOINT so they cannot drift apart.
# NOTE(review): vector search endpoints and indexes are normally created via
# VectorSearchClient (create_endpoint / create_delta_sync_index) - confirm
# that this SQL DDL is supported on the target runtime.
spark.sql(f"""
CREATE VECTOR SEARCH ENDPOINT IF NOT EXISTS {VECTOR_ENDPOINT}
TYPE "STANDARD";
""")

spark.sql(f"""
CREATE VECTOR SEARCH INDEX IF NOT EXISTS {VECTOR_INDEX}
ON TABLE {RAG_TABLE}
COLUMN embedding
OPTIONS (
  endpoint_name = "{VECTOR_ENDPOINT}",
  metric_type = "COSINE"
)
""")

# COMMAND ----------
# 5. Retrieval + GPT Reasoning
from databricks.vector_search.client import VectorSearchClient
from databricks.sdk import WorkspaceClient
import pandas as pd, textwrap

# Client used below to look up and query the vector search index.
vsc = VectorSearchClient()
# Workspace client used below to query the chat serving endpoint.
w = WorkspaceClient()

def embed_query(text: str):
    """Embed a single query string using the ``EMBED_MODEL`` SQL call.

    The text is placed in a one-row DataFrame, so it reaches the SQL
    expression as column data rather than being spliced into the SQL string.

    Args:
        text: The query text to embed.

    Returns:
        The value of the ``emb`` column for that single row.
    """
    embed_sql = f"ai_embeddings('{EMBED_MODEL}', text) AS emb"
    single_row = spark.createDataFrame([(text,)], ["text"])
    rows = single_row.selectExpr(embed_sql).collect()
    return rows[0]["emb"]

def _search_response_to_frame(res, requested_columns):
    """Convert a similarity-search response dict into a pandas DataFrame.

    Column names come from ``res["manifest"]["columns"]`` when present. The
    vector search client is documented to append a trailing ``score`` column
    to each row of ``res["result"]["data_array"]``; it is exposed here as
    ``similarity``. NOTE(review): confirm the response shape against the
    installed client version.
    """
    rows = (res.get("result") or {}).get("data_array") or []
    manifest_columns = (res.get("manifest") or {}).get("columns") or []
    names = [col["name"] for col in manifest_columns]
    if not names:
        # No manifest: assume the requested columns followed by the score.
        names = [*requested_columns, "score"]

    if rows and len(rows[0]) != len(names):
        # Unexpected row width: keep positional columns rather than fail.
        frame = pd.DataFrame(rows)
    else:
        frame = pd.DataFrame(rows, columns=names)

    if "score" in frame.columns:
        frame = frame.rename(columns={"score": "similarity"})
    else:
        frame["similarity"] = None
    return frame


def retrieve_candidates(text, k=10):
    """Return the ``k`` nearest materials for a free-text description.

    Args:
        text: Description of the target material; embedded via ``embed_query``.
        k: Maximum number of candidates to request from the index.

    Returns:
        A pandas DataFrame with the requested columns plus ``similarity``.
        It is empty (but still has those columns) when nothing matches.
    """
    wanted_columns = ["PRODUCT_CD","STYLE_CD","summary_text","material_json"]
    qvec = embed_query(text)
    index = vsc.get_index(endpoint_name=VECTOR_ENDPOINT, index_name=VECTOR_INDEX)

    # The client names the result-count parameter `num_results`.
    res = index.similarity_search(
        query_vector=qvec,
        columns=wanted_columns,
        num_results=k,
    )
    return _search_response_to_frame(res, wanted_columns)

def gpt_reason(query, df):
    """Ask the chat endpoint to rank and critique the retrieved candidates.

    Args:
        query: Free-text description of the target material.
        df: pandas DataFrame of candidates; it is rendered as a markdown
            table inside the user message.

    Returns:
        The text content of the first choice in the endpoint's response.
    """
    # Imported here so the function carries its own dependency; the package
    # is the same one that provides WorkspaceClient.
    from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

    table_md = df.to_markdown(index=False)
    system_prompt = '''
You are a senior apparel material developer.
Provide similarity scoring, substitution recommendations,
differences, and risk flags. Output top 3 recommended materials.
'''
    # Escaped newlines keep the f-string on one physical line; a plain
    # double-quoted string cannot span several lines.
    user_prompt = f"User request: {query}\n\nCandidates:\n{table_md}"

    # Chat endpoints take their conversation through `messages`; `inputs` is
    # the tabular payload used by custom model endpoints.
    resp = w.serving_endpoints.query(
        name=GPT_ENDPOINT,
        messages=[
            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
            ChatMessage(role=ChatMessageRole.USER, content=user_prompt),
        ],
    )
    # The chat completion text lives on the first choice's message.
    return resp.choices[0].message.content

# COMMAND ----------
# 6. Streamlit UI
import streamlit as st

# Page layout: a title, a free-text description box and a candidate-count slider.
st.title("Material Similarity & Substitution Assistant")

query = st.text_area("Describe your target material:")
k = st.slider("Candidates", 3, 20, 8)

if st.button("Search"):
    if not query.strip():
        # Nothing to embed: do not spend an embedding and a chat call on a
        # blank request.
        st.warning("Please describe the target material before searching.")
    else:
        with st.spinner("Retrieving candidates..."):
            cands = retrieve_candidates(query, k)
        st.dataframe(cands)

        if cands.empty:
            # No candidates to compare, so there is nothing for the model to
            # rank.
            st.info("No candidate materials were found for this description.")
        else:
            with st.spinner("GPT-5.1 reasoning..."):
                answer = gpt_reason(query, cands)
            st.write(answer)
