In [0]:
from pyspark.sql.functions import sum, count, countDistinct, col

# ----------------------------------------------------
# LOAD fact_sales, KEEPING ONLY THE COLUMNS USED BELOW
# ----------------------------------------------------
# Only the grouping key (customer_state), the order identifier and the
# revenue amount are consumed by the aggregation further down, so nothing
# else is selected from the source table.
fact_sales_df = spark.table(
    "real_time_projects.ecommerce_historical.fact_sales"
).select("customer_state", "order_id", "revenue")

# -------------------------------------------------
# DELIBERATELY NO REPARTITION BEFORE AGGREGATION
# -------------------------------------------------
# groupBy("customer_state") already plans the shuffle it needs. Putting an
# explicit repartition("customer_state") in front of it would shuffle every
# raw fact row (skewed towards the states with the most sales) before Spark
# has a chance to pre-aggregate inside each input partition, which only
# adds cost. The aggregated result is identical either way, so
# fact_sales_df is passed to the aggregation as read.

# ---------------------------------
# PER-STATE SALES METRICS (GOLD)
# ---------------------------------
# Produces one row per customer_state with:
#   total_revenue - sum of revenue over the state's rows
#   total_orders  - number of distinct order_id values
#   total_items   - number of fact rows, via count("*")
# NOTE: sum and count are the pyspark.sql.functions versions imported at
# the top of the file, not the Python builtins.
state_metrics = [
    sum("revenue").alias("total_revenue"),
    countDistinct("order_id").alias("total_orders"),
    count("*").alias("total_items"),
]

sales_grouped_by_state = fact_sales_df.groupBy("customer_state")
gold_sales_by_state_df = sales_grouped_by_state.agg(*state_metrics)

# --------------------------
# WRITE GOLD DELTA TABLE
# --------------------------
# The result holds a single row per customer_state, so the table is written
# unpartitioned: partitioning by customer_state would create one directory
# containing one tiny file per state, which slows reads and bloats the
# table's file listing without enabling any useful partition pruning.
#
# overwriteSchema stays enabled so that the overwrite is allowed to replace
# the schema (and the partition layout) of a table left by an earlier run.
(
    gold_sales_by_state_df
    .write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("real_time_projects.ecommerce_historical.gold_sales_by_state")
)

# --------------------------------------------
# DELTA OPTIMIZATION: ENABLE AUTO OPTIMIZE
# --------------------------------------------
# Sets the optimizeWrite and autoCompact table properties on the gold table.
# This statement is issued after the write above has already run.
# NOTE(review): presumably the properties therefore only influence later
# writes to the table - confirm this is the intended ordering.
auto_optimize_ddl = """
ALTER TABLE real_time_projects.ecommerce_historical.gold_sales_by_state
SET TBLPROPERTIES (
  delta.autoOptimize.optimizeWrite = true,
  delta.autoOptimize.autoCompact = true
)
"""
spark.sql(auto_optimize_ddl)