# Streamlit Rule Builder – Enhanced
Full CRUD on rules (add/edit/delete), preview impact, build draft, and approve.

In [None]:
import streamlit as st
from pyspark.sql import functions as F
from pyspark.sql import Row

st.set_page_config(page_title="Risk Rules", layout="wide")
st.title("Risk Rules – Builder & Approval (Enhanced)")

# --------- Config ---------
# NOTE(review): the default looks like a template placeholder — confirm it is
# replaced with a real catalog name at deployment time.
DEFAULT_CATALOG = "<CATALOG_NAME>"
# Closed sets that the edit and insert sections validate user input against.
ALLOWED_IMPACT_COLUMNS = {"risk_band", "risk_points"}
ALLOWED_SEGMENTS = {"retail", "corporate", "all"}

# --------- Inputs ---------
catalog = st.text_input(
    "Catalog (Unity Catalog)", value=DEFAULT_CATALOG, help="e.g., reporting_factory"
).strip()
if not catalog:
    st.error("Catalog name is required.")
    st.stop()
# The catalog name is free text from the UI. Quote it as an identifier (doubling
# any embedded backtick) so it cannot inject additional SQL and so names with
# special characters still parse.
quoted_catalog = "`" + catalog.replace("`", "``") + "`"
_ = spark.sql(f"USE CATALOG {quoted_catalog}")
spark.sql("USE SCHEMA control")

# Ensure table exists
spark.sql("""
CREATE TABLE IF NOT EXISTS control.risk_rules
(rule_id STRING, name STRING, segment STRING, condition_sql STRING,
 impact_column STRING, impact_value STRING, priority INT, enabled BOOLEAN,
 effective_from DATE, effective_to DATE, owner STRING, notes STRING)
USING DELTA
""")

st.caption("Editing below writes directly to control.risk_rules in Unity Catalog (auditable).")

# --------- Load rules ---------
# Pull the whole rules table to the driver so it can be shown in an editable grid.
rules_sdf = spark.table("control.risk_rules")
rules_pdf = rules_sdf.orderBy("priority", "rule_id").toPandas()

# Editable columns (safe set): everything except the key, the effective dates
# and the owner.
editable_cols = ["name","segment","condition_sql","impact_column","impact_value","priority","enabled","notes"]

st.subheader("Edit Rules Inline")
st.write("Tip: Use `risk_band` values like Low/Medium/High, and `risk_points` like +10 / -5.")
edited = st.data_editor(
    rules_pdf[["rule_id"] + editable_cols],
    use_container_width=True,
    # Rows cannot be added or removed here: rule_id is read-only, so a row added
    # in the grid would have no key and the save handler would silently drop it.
    # Inserts and deletes go through the dedicated sections further down.
    num_rows="fixed",
    disabled=["rule_id"],  # editing rule_id is not supported in-place
    key="rules_editor"
)

# Detect changes by comparing to original
changed_rows = []  # not read in the visible script; kept in case other code uses it
if st.button("Save Changes"):
    import re

    def _is_missing(value):
        """Return True for None and float NaN (how empty cells arrive from pandas)."""
        return value is None or (isinstance(value, float) and value != value)

    def _native(value):
        """Unwrap numpy scalars (they expose .item()) into plain Python values."""
        return value.item() if hasattr(value, "item") else value

    def _sql_literal(value):
        """Render a Python value as a Spark SQL literal for the UPDATE statement."""
        if _is_missing(value):
            return "NULL"
        if isinstance(value, bool):  # must precede int: bool is a subclass of int
            return "true" if value else "false"
        if isinstance(value, (int, float)):
            return str(value)
        # Spark SQL string literals use backslash as the escape character, so
        # escape backslashes first and then the quote itself.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    # Restrict the baseline to the editable columns. Comparing against the full
    # table row (which also carries effective dates and owner) can never be
    # equal, which would mark every rule as changed on every save.
    before = rules_pdf.set_index("rule_id")[editable_cols]
    after = edited.set_index("rule_id")[editable_cols]

    if before.index.has_duplicates:
        st.error("control.risk_rules contains duplicate rule_id values; resolve them before editing inline.")
        st.stop()

    to_update = []
    ignored = 0

    for rid in after.index:
        if rid not in before.index:
            # Row has no existing key, so there is nothing to UPDATE.
            ignored += 1
            continue
        row_before = before.loc[rid]
        row_after = after.loc[rid]
        if row_before.equals(row_after):
            continue

        data = {c: _native(v) for c, v in row_after.to_dict().items()}

        # basic validations
        seg = data["segment"]
        col = data["impact_column"]
        imp = str(data["impact_value"])

        if seg not in ALLOWED_SEGMENTS:
            st.error(f"Rule {rid}: invalid segment '{seg}'. Allowed: {ALLOWED_SEGMENTS}")
            st.stop()
        if col not in ALLOWED_IMPACT_COLUMNS:
            st.error(f"Rule {rid}: invalid impact_column '{col}'. Allowed: {ALLOWED_IMPACT_COLUMNS}")
            st.stop()
        if col == "risk_points" and not re.fullmatch(r"[+\-]?\d+", imp):
            st.error(f"Rule {rid}: impact_value must be integer-like for risk_points (e.g., +15). Got: {imp}")
            st.stop()

        # priority is an INT column: coerce once and write the coerced value.
        try:
            data["priority"] = int(data["priority"])
        except (TypeError, ValueError):
            st.error(f"Rule {rid}: priority must be an integer. Got: {data['priority']}")
            st.stop()
        if not _is_missing(data["enabled"]):
            data["enabled"] = bool(data["enabled"])

        to_update.append((rid, data))

    if ignored:
        st.warning(f"Ignored {ignored} row(s) without an existing rule_id; use 'Add New Rule' to insert rules.")

    if to_update:
        for rid, data in to_update:
            assignments = ", ".join(f"{c} = {_sql_literal(data[c])}" for c in editable_cols)
            spark.sql(
                f"UPDATE control.risk_rules SET {assignments} WHERE rule_id = {_sql_literal(str(rid))}"
            )
        st.success(f"Updated {len(to_update)} rule(s). Refresh to see changes.")
    else:
        st.info("No changes detected.")

# --------- Add new rule ---------
st.subheader("Add New Rule")
with st.form("add_rule"):
    c1, c2, c3 = st.columns(3)
    with c1:
        new_id = st.text_input("rule_id", placeholder="R5", max_chars=50)
        new_name = st.text_input("name", placeholder="fico_very_low")
        new_segment = st.selectbox("segment", sorted(ALLOWED_SEGMENTS))
        new_priority = st.number_input("priority", value=50, min_value=1, step=1)
        new_enabled = st.checkbox("enabled", value=True)
    with c2:
        new_impact_column = st.selectbox("impact_column", sorted(ALLOWED_IMPACT_COLUMNS))
        new_impact_value = st.text_input("impact_value", placeholder="High or +15")
        new_owner = st.text_input("owner", value="risk_ops")
    with c3:
        new_condition = st.text_area("condition_sql", placeholder="f.fico_score < 580 AND f.dti > 35", height=100)
        new_notes = st.text_input("notes", placeholder="Explain the rule intent")
    submitted = st.form_submit_button("Insert Rule")

if submitted:
    import re

    def esc(s):
        """Escape a value for use inside a single-quoted Spark SQL string literal."""
        # Backslash is Spark SQL's escape character, so it must be escaped
        # before the quote; otherwise a trailing backslash swallows the
        # closing quote of the literal.
        return str(s).replace("\\", "\\\\").replace("'", "\\'")

    # Trim the key so whitespace-only or padded ids are not stored.
    new_id = (new_id or "").strip()

    # validations
    if not new_id or not (new_name or "").strip() or not (new_condition or "").strip():
        st.error("rule_id, name and condition_sql are required.")
        st.stop()
    if new_segment not in ALLOWED_SEGMENTS:
        st.error(f"segment must be one of {ALLOWED_SEGMENTS}")
        st.stop()
    if new_impact_column not in ALLOWED_IMPACT_COLUMNS:
        st.error(f"impact_column must be one of {ALLOWED_IMPACT_COLUMNS}")
        st.stop()
    if not (new_impact_value or "").strip():
        st.error("impact_value is required.")
        st.stop()
    if new_impact_column == "risk_points":
        if not re.fullmatch(r"[+\-]?\d+", new_impact_value or ""):
            st.error("impact_value must be integer-like for risk_points (e.g., +20)")
            st.stop()

    # The table has no key constraint, so reject duplicates here; the inline
    # editor and the delete action both address rules by rule_id.
    already_exists = (
        spark.table("control.risk_rules")
        .where(F.col("rule_id") == new_id)
        .limit(1)
        .count()
        > 0
    )
    if already_exists:
        st.error(f"rule_id '{new_id}' already exists.")
        st.stop()

    spark.sql(f"""
        INSERT INTO control.risk_rules
        (rule_id, name, segment, condition_sql, impact_column, impact_value, priority, enabled,
         effective_from, effective_to, owner, notes)
        VALUES
        ('{esc(new_id)}','{esc(new_name)}','{esc(new_segment)}','{esc(new_condition)}','{esc(new_impact_column)}',
         '{esc(new_impact_value)}',{int(new_priority)},{str(bool(new_enabled)).lower()}, current_date, NULL,
         '{esc(new_owner)}','{esc(new_notes)}')
    """)
    st.success(f"Inserted rule {new_id}.")

# --------- Delete rule ---------
st.subheader("Delete Rule")
rid_del = st.text_input("rule_id to delete", placeholder="R3")
if st.button("Delete Rule"):
    if not rid_del:
        st.warning("Enter a rule_id to delete.")
    else:
        # Build the literal outside the f-string: reusing the enclosing quote
        # character inside a replacement field is a SyntaxError before
        # Python 3.12. Backslash is escaped first because it is Spark SQL's
        # escape character inside string literals.
        rid_literal = rid_del.replace("\\", "\\\\").replace("'", "\\'")
        spark.sql(f"DELETE FROM control.risk_rules WHERE rule_id = '{rid_literal}'")
        st.success(f"Deleted rule {rid_del} (if it existed).")

st.markdown("---")
st.subheader("Preview Impact (1% sample)")
if st.button("Run Preview on Sample"):
    from pyspark.sql.window import Window

    spark.sql("USE SCHEMA gold")
    # Fixed seed keeps the 1% sample stable between runs of the preview.
    f = spark.table("gold.features").sample(0.01, seed=42).alias("f")
    # Only rules that are enabled and whose effective window covers today.
    r = spark.table("control.risk_rules").where(
        "enabled = true AND current_date BETWEEN effective_from AND coalesce(effective_to, date'2999-12-31')"
    )

    # One (loan_id, rule) row per rule whose condition matches a sampled loan.
    matches = None
    for row in r.collect():
        # Filter first so the predicate can see every feature column, then
        # attach the rule fields as typed literals instead of splicing them
        # into SQL text (which breaks on quotes and on NULL values).
        part = f.where(row["condition_sql"]).select(
            F.col("loan_id"),
            F.lit(row["rule_id"]).cast("string").alias("rule_id"),
            F.lit(row["impact_column"]).cast("string").alias("impact_column"),
            F.lit(row["impact_value"]).cast("string").alias("impact_value"),
            F.lit(row["priority"]).cast("int").alias("priority"),
        )
        matches = part if matches is None else matches.unionByName(part)

    if matches is None:
        # No active rules: every loan gets the defaults.
        preview = f.withColumn("matched_rules", F.array())\
                   .withColumn("risk_band", F.lit("Medium"))\
                   .withColumn("risk_points", F.lit(0))
    else:
        # Keep a single rule per loan and impact column: the one with the
        # largest priority value.
        w = Window.partitionBy("loan_id", "impact_column").orderBy(F.desc("priority"))
        top = matches.withColumn("rn", F.row_number().over(w)).where("rn=1")

        is_band = F.col("impact_column") == "risk_band"
        is_points = F.col("impact_column") == "risk_points"
        points_value = F.regexp_extract(F.col("impact_value"), r"[-+]?\d+", 0).cast("int")

        preview = (
            f.join(top, on="loan_id", how="left")
            .groupBy(f.columns)
            .agg(
                F.collect_list("rule_id").alias("matched_rules"),
                F.max(F.when(is_band, F.col("impact_value"))).alias("risk_band"),
                F.sum(F.when(is_points, points_value).otherwise(0)).alias("risk_points"),
            )
            # collect_list yields an empty array rather than NULL, and fillna
            # only accepts scalar replacement values, so matched_rules is not
            # listed here.
            .fillna({"risk_band": "Medium", "risk_points": 0})
        )

    c1, c2, c3 = st.columns(3)
    c1.metric("Loans (sample)", int(preview.count()))
    c2.metric("Avg DTI", float(preview.select(F.avg("dti")).first()[0] or 0))
    c3.metric("Avg FICO", float(preview.select(F.avg("fico_score")).first()[0] or 0))
    st.write("High risk count (sample):", preview.select(F.sum((F.col("risk_band")=="High").cast("int")).alias("high_risk_count")).toPandas())
    st.success("Preview complete.")

st.markdown("---")
st.subheader("Generate Draft Report")
if st.button("Build Draft Now"):
    from datetime import datetime, timezone

    spark.sql("USE SCHEMA gold")
    # Run id is a UTC timestamp at one-second resolution. datetime.utcnow() is
    # deprecated, so use the timezone-aware equivalent; the formatted value is
    # the same.
    run_id = f"RR_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}"
    # Register the run as a DRAFT. Values are positional, so they rely on the
    # column order of gold.report_runs (not visible here).
    spark.sql(f"INSERT INTO gold.report_runs VALUES ('{run_id}','Risk Report','DRAFT','rules@current',current_timestamp(),NULL,NULL,NULL)")
    # KPIs are computed from whatever gold.risk_eval currently holds; this
    # section does not refresh that table.
    spark.sql("""
        CREATE OR REPLACE TEMP VIEW _kpi AS
        SELECT 'loans_total' as metric, 'all' as dimension, count(*)*1.0 as value FROM gold.risk_eval
        UNION ALL
        SELECT 'high_risk_count','all', sum(CASE WHEN risk_band='High' OR risk_points>=20 THEN 1 ELSE 0 END) FROM gold.risk_eval
        UNION ALL
        SELECT 'avg_dti','all', avg(dti) FROM gold.risk_eval
        UNION ALL
        SELECT 'avg_fico','all', avg(fico_score) FROM gold.risk_eval
    """)
    # Clear any facts already stored under this run id (possible if two builds
    # land in the same second) before writing the new ones.
    spark.sql(f"DELETE FROM gold.report_facts WHERE report_run_id = '{run_id}'")
    spark.sql(f"INSERT INTO gold.report_facts SELECT '{run_id}', metric, dimension, value FROM _kpi")
    st.success(f"Draft report created: {run_id}")

st.subheader("Approve Latest Draft")
if st.button("Approve Most Recent Draft"):
    # Newest run that is still in DRAFT status, if any.
    draft_rows = spark.sql("SELECT report_run_id FROM gold.report_runs WHERE status='DRAFT' ORDER BY started_at DESC LIMIT 1").collect()
    if not draft_rows:
        st.warning("No draft runs found.")
    else:
        approved_run_id = draft_rows[0]["report_run_id"]
        # The status guard in the WHERE clause keeps the update from touching a
        # run that is no longer a draft.
        approve_sql = f"""
            UPDATE gold.report_runs
            SET status='APPROVED', approved_by=current_user(), approved_at=current_timestamp()
            WHERE report_run_id='{approved_run_id}' AND status='DRAFT'
        """
        spark.sql(approve_sql)
        # (Re)define the view that exposes the facts of the most recently
        # approved run.
        latest_view_sql = """
            CREATE OR REPLACE VIEW gold.report_facts_approved_latest AS
            SELECT rf.*
            FROM gold.report_facts rf
            JOIN (
              SELECT report_run_id FROM gold.report_runs WHERE status='APPROVED' ORDER BY approved_at DESC LIMIT 1
            ) latest USING (report_run_id);
        """
        spark.sql(latest_view_sql)
        st.success(f"Approved run: {approved_run_id}")
