# In [None]:
# =========================================================
# Notebook: 01_load_bronze_from_sqldb
# Purpose : FULL or INCREMENTAL load from Fabric SQL DB -> Lakehouse Bronze Delta
# Auth    : Microsoft Entra ID token (no username/password)
# Watermark: ops_watermark.last_modified_ts using source modified_ts
# =========================================================

from pyspark.sql import functions as F
from delta.tables import DeltaTable
import uuid
from datetime import datetime

# -----------------------------
# Notebook parameters (supplied by the orchestrating pipeline)
# -----------------------------
# The second argument of each call is presumably the fallback used when the
# pipeline does not pass that parameter.
# NOTE(review): `mssparkutils.env.getJobParameter` is not defined in this
# notebook — confirm it exists in the target runtime.
load_mode = mssparkutils.env.getJobParameter("load_mode", "FULL")  # FULL | INCR
pipeline_name = mssparkutils.env.getJobParameter(
    "pipeline_name", "rp_orchestrator_dev"
)
run_date = mssparkutils.env.getJobParameter(
    "run_date", datetime.utcnow().strftime("%Y-%m-%d")
)

# -----------------------------
# Source SQL database
# -----------------------------
SQL_SERVER = "g74r6ummslpuxca2gs6zhaj3aa-tux4sdtf5tvu3n6griblkgmll4.database.fabric.microsoft.com"
SQL_DATABASE = "rp_sqldb_dev-7d103162-1cbc-4e63-baef-4072912d791a"
SQL_SCHEMA = "dbo"

# Primary-key columns per source table (must match the actual schemas).
# These are the join keys used when merging into the bronze tables.
KEYS = {
    "customers": ["customer_id"],
    "orders": ["order_id"],
    "order_items": ["order_item_id"],
    "payments": ["payment_id"],
    "inventory": ["product_id"],
    "returns": ["return_id"],
}

# Tables are loaded in the insertion order of KEYS, so every table is
# guaranteed to have a key definition.
TABLES = list(KEYS)

# Bronze table name = BRONZE_PREFIX + source table name.
BRONZE_PREFIX = "bronze_"

# Per-run identifiers stamped onto every ingested row.
run_ts = datetime.utcnow()
run_id = str(uuid.uuid4())

# Run banner, one "label: value" line per setting.
print(
    "\n".join(
        [
            f"Mode: {load_mode}",
            f"RunId: {run_id}",
            f"Pipeline: {pipeline_name}",
            f"RunDate: {run_date}",
            f"RunTS(UTC): {run_ts}",
        ]
    )
)

# -----------------------------
# JDBC + Entra token
# -----------------------------
# Each fragment is one semicolon-terminated connection option; they are
# concatenated into a single SQL Server JDBC URL.
jdbc_url = "".join(
    [
        f"jdbc:sqlserver://{SQL_SERVER}:1433;",
        f"database={SQL_DATABASE};",
        "encrypt=true;",
        "trustServerCertificate=false;",
        "loginTimeout=30;",
    ]
)

# The token is fetched once and reused for every JDBC read in this run.
token = mssparkutils.credentials.getToken("https://database.windows.net/")

# Authentication is done with the access token; no user/password is supplied.
jdbc_props = dict(
    driver="com.microsoft.sqlserver.jdbc.SQLServerDriver",
    accessToken=token,
)

# -----------------------------
# ops_watermark table (must exist from 00_ops_watermark_init)
# -----------------------------
# DataFrame handle used by get_watermark() for per-table lookups.
wm = spark.table("ops_watermark")

def get_watermark(table_name: str):
    """Return the stored watermark for ``table_name``, or ``None`` if absent.

    Looks up ``last_modified_ts`` in the module-level ``wm`` DataFrame
    (``ops_watermark``) for the row whose ``table_name`` column matches.
    If more than one row matches, the first row Spark returns is used.
    """
    matching = wm.where(F.col("table_name") == table_name).select("last_modified_ts")
    rows = matching.take(1)
    if not rows:
        return None
    return rows[0]["last_modified_ts"]

def upsert_watermark(table_name: str, new_ts):
    """Insert or update the ``ops_watermark`` row for ``table_name``.

    Args:
        table_name: Key of the watermark row. It is interpolated directly
            into the SQL text, so it must come from a trusted source.
        new_ts: New ``last_modified_ts`` value; rendered with ``str()`` into
            a ``TIMESTAMP('...')`` literal. ``None`` makes the call a no-op.

    ``updated_at`` is set to ``current_timestamp()`` on both insert and update.
    """
    if new_ts is None:
        # No timestamp to record; leave the existing watermark untouched.
        return

    merge_sql = f"""
      MERGE INTO ops_watermark t
      USING (SELECT '{table_name}' AS table_name,
                    TIMESTAMP('{new_ts}') AS last_modified_ts,
                    current_timestamp() AS updated_at) s
      ON t.table_name = s.table_name
      WHEN MATCHED THEN UPDATE SET
        t.last_modified_ts = s.last_modified_ts,
        t.updated_at = s.updated_at
      WHEN NOT MATCHED THEN INSERT (table_name,last_modified_ts,updated_at)
      VALUES (s.table_name,s.last_modified_ts,s.updated_at)
    """
    spark.sql(merge_sql)

# -----------------------------
# Read helpers
# -----------------------------
def read_full(table_name: str):
    """Return a DataFrame over the whole of ``SQL_SCHEMA.table_name`` via JDBC.

    Uses the module-level ``jdbc_url`` and ``jdbc_props`` for the connection.
    """
    return spark.read.jdbc(
        url=jdbc_url,
        table=f"{SQL_SCHEMA}.{table_name}",
        properties=jdbc_props,
    )

def read_incremental(table_name: str, last_ts):
    """Return rows of ``SQL_SCHEMA.table_name`` modified after ``last_ts``.

    Args:
        table_name: Unqualified source table name.
        last_ts: Watermark as a ``datetime``. A falsy value (e.g. ``None``)
            means no watermark yet; 1900-01-01 is then used as the lower
            bound.

    Returns:
        A JDBC-backed DataFrame of the rows whose ``modified_ts`` is strictly
        greater than the bound. The filter is wrapped in a subquery so it is
        evaluated by the SQL engine rather than by Spark.
    """
    # Keep the fractional seconds. Truncating the watermark to whole seconds
    # makes `modified_ts > bound` match the rows that produced the watermark,
    # so they would be re-read and re-merged on every run.
    last_ts_str = (
        last_ts.strftime("%Y-%m-%dT%H:%M:%S.%f")
        if last_ts
        else "1900-01-01T00:00:00.000000"
    )
    # ISO 8601 text with an explicit CAST avoids relying on implicit,
    # session-dependent string-to-date conversion on the server, and
    # datetime2(6) accepts the six fractional digits produced above.
    q = (
        f"(SELECT * FROM {SQL_SCHEMA}.{table_name} "
        f"WHERE modified_ts > CAST('{last_ts_str}' AS datetime2(6))) AS src"
    )
    return spark.read.jdbc(url=jdbc_url, table=q, properties=jdbc_props)

# -----------------------------
# Merge helper
# -----------------------------
# Session-level Delta setting for automatic schema merging; intended to let
# the MERGE below cope with source columns the bronze table does not have yet.
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")

def merge_into_bronze(bronze_table: str, df, key_cols):
    """Upsert ``df`` into the Delta table ``bronze_table``.

    Args:
        bronze_table: Name of the target table in the catalog.
        df: Source DataFrame to write.
        key_cols: Column names that identify a row; they are AND-ed together
            as equality predicates to form the merge condition.

    Returns:
        ``"CREATED"`` when the table did not exist and was written from
        ``df``, otherwise ``"MERGED"`` after an update-all / insert-all merge.
    """
    if spark.catalog.tableExists(bronze_table):
        match_on = " AND ".join(f"t.{k} = s.{k}" for k in key_cols)
        target = DeltaTable.forName(spark, bronze_table).alias("t")
        merge_builder = (
            target.merge(df.alias("s"), match_on)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
        )
        merge_builder.execute()
        return "MERGED"

    # First load for this table: there is nothing to merge into yet, so the
    # batch itself becomes the initial table contents.
    writer = df.write.format("delta").mode("overwrite").option("overwriteSchema", "true")
    writer.saveAsTable(bronze_table)
    return "CREATED"

# -----------------------------
# Load loop
# -----------------------------
# Resolve the mode once; any value other than FULL is treated as incremental.
is_full_load = load_mode.upper() == "FULL"
mode_label = "full" if is_full_load else "incremental"

for t in TABLES:
    bronze_table = f"{BRONZE_PREFIX}{t}"
    key_cols = KEYS[t]

    print(f"\n🚀 Loading {t} -> {bronze_table}")

    if is_full_load:
        src = read_full(t)
    else:
        last_ts = get_watermark(t)
        src = read_incremental(t, last_ts)

    # Add bronze metadata and pin the batch. The JDBC DataFrame is lazy, so
    # without persist() every action below would query the source again and
    # could see different rows each time.
    out = (
        src
        .withColumn("bronze_ingest_ts", F.lit(run_ts).cast("timestamp"))
        .withColumn("bronze_run_id", F.lit(run_id))
        .withColumn("bronze_source", F.lit(f"{SQL_DATABASE}.{SQL_SCHEMA}.{t}"))
    ).persist()

    try:
        # One pass gives both the row count and the new watermark. They are
        # computed BEFORE the write so that, even if cached partitions are
        # evicted and recomputed, the watermark can only lag behind what was
        # written (rows get re-read next run) and never run ahead of it
        # (rows would be skipped for good).
        stats = out.agg(
            F.count(F.lit(1)).alias("cnt"),
            F.max("modified_ts").alias("mx"),
        ).collect()[0]
        cnt = stats["cnt"]
        new_wm = stats["mx"]

        # Nothing returned by the source: leave bronze and watermark as-is.
        if cnt == 0:
            print(f"⏭️ No new rows for {t} ({mode_label}).")
            continue

        # Merge or overwrite
        if is_full_load:
            (out.write.format("delta")
               .mode("overwrite")
               .option("overwriteSchema", "true")
               .saveAsTable(bronze_table))
            action = "OVERWRITTEN"
        else:
            action = merge_into_bronze(bronze_table, out, key_cols)

        # Advance the watermark only after the write has succeeded.
        upsert_watermark(t, new_wm)
    finally:
        # Release the cached batch whether the load succeeded, was skipped,
        # or failed.
        out.unpersist()

    print(f"✅ {bronze_table} {action} | rows processed = {cnt} | new watermark = {new_wm}")

print("\n✅ Bronze load completed.")
display(spark.table("ops_watermark").orderBy("table_name"))
