In [0]:
# ============================================================
# 03_gold_marts.py (REVISED: category/seller revenue = items_subtotal)
#
# PURPOSE
# -------
# Build GOLD marts (pre-aggregated tables) for multiple dashboard tiles.
#
# IMPORTANT DESIGN DECISION
# -------------------------
# For category-level and seller-level "revenue", we use:
#   items_subtotal = SUM(order_items.price)
#
# Why:
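# - It is the natural metric at the item grain (one row per order line)
# - Each line belongs to exactly one category and one seller, so orders that span
#   several categories/sellers are not double counted (order-level payment totals
#   cannot be split this way and would be counted once per category/seller)
# - It matches how category/seller sales are usually measured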
#
# Sources (Silver tables this script expects)
# ------------------------------------------
# - olist.silver.orders            columns used: order_id, customer_id, order_purchase_ts, order_purchase_date,
#                                 order_delivered_customer_ts, order_estimated_delivery_ts
# - olist.silver.order_items       columns used: order_id, order_item_id, product_id, seller_id, price, freight_value
# - olist.silver.order_payments    columns used: order_id, payment_type, payment_installments, payment_value
# - olist.silver.customers         columns used: customer_id, customer_state
# - olist.silver.products          columns used: product_id, product_category_name
# - olist.silver.category_translation columns used: product_category_name, product_category_name_english
# - olist.silver.sellers           columns used: seller_id, seller_city, seller_state
#
# Gold tables created
# -------------------
# 1) olist.gold.orders_daily
# 2) olist.gold.revenue_by_category
# 3) olist.gold.revenue_by_state
# 4) olist.gold.top_sellers
# 5) olist.gold.delivery_monthly
# 6) olist.gold.payment_type_mix
# ============================================================


# ----------------------------
# 0) Imports
# ----------------------------

from pyspark.sql.functions import col  # Reference dataframe columns by name
from pyspark.sql.functions import sum as fsum  # Aggregate sum (alias to avoid shadowing Python sum)
from pyspark.sql.functions import avg as favg  # Aggregate average
from pyspark.sql.functions import count as fcount  # Aggregate count
from pyspark.sql.functions import countDistinct  # Aggregate distinct counts
from pyspark.sql.functions import when  # Conditional expressions
from pyspark.sql.functions import date_trunc  # Truncate timestamps to month/day grain
from pyspark.sql.functions import datediff  # Difference in days between two dates/timestamps
from pyspark.sql.functions import coalesce  # First non-null expression
from pyspark.sql.functions import lower  # Normalize strings to lowercase
from pyspark.sql.functions import trim  # Trim whitespace from strings


# ----------------------------
# 1) Configuration
# ----------------------------

# Unity Catalog locations of the source (Silver) and target (Gold) layers.
SILVER_SCHEMA = "olist.silver"
GOLD_SCHEMA = "olist.gold"

# Marts are fully rebuilt on every run (simple and stable for dashboard consumers).
GOLD_WRITE_MODE = "overwrite"

# IF NOT EXISTS makes this safe to re-run.
spark.sql("CREATE SCHEMA IF NOT EXISTS " + GOLD_SCHEMA)


# ----------------------------
# 2) Helper: Write UC managed gold table
# ----------------------------

def write_gold_table(df, table_name: str):
    """
    Persist ``df`` as a Delta table named ``<GOLD_SCHEMA>.<table_name>``.

    The table is written with ``saveAsTable`` (catalog-registered, managed),
    using ``GOLD_WRITE_MODE`` as the save mode. ``overwriteSchema`` is enabled
    so column changes between runs are accepted when overwriting.

    Args:
        df: Spark DataFrame containing the mart rows.
        table_name: Unqualified table name inside the gold schema.
    """
    target = f"{GOLD_SCHEMA}.{table_name}"

    writer = df.write.format("delta").mode(GOLD_WRITE_MODE)
    writer = writer.option("overwriteSchema", "true")
    writer.saveAsTable(target)


# ============================================================
# 3) Load Silver tables (these must exist)
# ============================================================

# Cleaned Silver inputs; all of these must already exist in SILVER_SCHEMA.
(
    orders,
    order_items,
    order_payments,
    customers,
    products,
    sellers,
    category_translation,
) = (
    spark.table(f"{SILVER_SCHEMA}.{silver_name}")
    for silver_name in (
        "orders",
        "order_items",
        "order_payments",
        "customers",
        "products",
        "sellers",
        "category_translation",
    )
)


# ============================================================
# 4) Build reusable aggregates (from Silver) for Gold marts
# ============================================================

# ----------------------------
# 4A) Payments aggregated at the order level
# ----------------------------
# We use payments for:
# - orders_daily revenue
# - payment_type_mix
# We do NOT use payment totals for category/seller revenue to avoid double counting.

# One row per order_id that has payment rows:
#   order_revenue              -> sum of payment_value over all of the order's payments
#   avg_installments_per_order -> mean payment_installments across those payment rows
payments_by_order = order_payments.groupBy("order_id").agg(
    fsum("payment_value").alias("order_revenue"),
    favg("payment_installments").alias("avg_installments_per_order"),
)

# ----------------------------
# 4B) Items aggregated at the order level
# ----------------------------
# We use items_subtotal to represent line-item sales.
# This is clean for category/seller attribution.

# One row per order_id that has item rows:
#   items_subtotal -> sum of line prices (no freight)
#   freight_total  -> sum of freight_value across the order's lines
#   items_sold     -> number of non-null order_item_id values (line items)
items_by_order = order_items.groupBy("order_id").agg(
    fsum("price").alias("items_subtotal"),
    fsum("freight_value").alias("freight_total"),
    fcount("order_item_id").alias("items_sold"),
)

# ----------------------------
# 4C) Orders fact table (order grain)
# ----------------------------
# This provides a convenient order-grain dataset for several marts.

orders_fact = (
    orders
    .select(  # Only the order fields the marts need
        col("order_id"),
        col("customer_id"),
        col("order_purchase_ts"),
        col("order_purchase_date"),
        col("order_delivered_customer_ts"),
        col("order_estimated_delivery_ts"),
    )
    # Left joins keep orders that have no payment rows, no item rows, or no customer match.
    .join(payments_by_order, on="order_id", how="left")  # order_revenue, avg_installments_per_order
    .join(items_by_order, on="order_id", how="left")  # items_subtotal, freight_total, items_sold
    # NOTE(review): stays one row per order only if customer_id is unique in Silver customers.
    .join(customers.select(col("customer_id"), col("customer_state")), on="customer_id", how="left")
)

# Null handling for order-level metrics.
# 1) order_revenue: when an order has no payment rows, estimate it from its items.
#    Payment totals normally cover item price plus freight (confirm for your data),
#    so the fallback is items_subtotal + freight_total, then items_subtotal alone if
#    freight is NULL. Orders with neither payments nor items keep NULL revenue, so
#    they are skipped by avg() instead of counted as 0-value orders.
# 2) Item metrics: orders without item rows get 0 so sums return 0 rather than NULL.
#    This runs after step 1 so the fallback can still tell "no items" apart from "0".
orders_fact = (
    orders_fact
    .withColumn(
        "order_revenue",
        coalesce(
            col("order_revenue"),
            col("items_subtotal") + col("freight_total"),
            col("items_subtotal"),
        ),
    )
    .fillna(0, subset=["items_subtotal", "freight_total", "items_sold"])
)


# ============================================================
# GOLD 1) orders_daily (time series)
# - Revenue uses payments (order_revenue)
# - Item sales uses items_subtotal and items_sold
# ============================================================

# Per-day metrics over the order-grain fact table.
daily_metrics = [
    countDistinct("order_id").alias("orders_count"),  # orders purchased that day
    fsum("order_revenue").alias("revenue"),  # payment-based revenue
    fsum("items_subtotal").alias("items_subtotal"),  # line-item sales (price only)
    fsum("items_sold").alias("items_sold"),  # line items across the day's orders
    favg("order_revenue").alias("avg_order_value"),  # mean order revenue; NULL revenues are skipped
]

orders_daily = (
    orders_fact
    .groupBy(col("order_purchase_date").alias("order_date"))
    .agg(*daily_metrics)
    .orderBy("order_date")  # chronological, for time-series charts
)

write_gold_table(orders_daily, "orders_daily")


# ============================================================
# GOLD 2) revenue_by_category (categorical distribution)
# - Revenue metric = items_subtotal (sum of item prices) at category level
# ============================================================

# Translation lookup, matched on the trimmed, lower-cased Portuguese category name.
translation_lookup = category_translation.select(
    col("product_category_name").alias("pt_category_name"),
    col("product_category_name_english").alias("category_english"),
)
join_key_matches = lower(trim(col("product_category_name"))) == lower(trim(col("pt_category_name")))

# Item grain: each line gets its product's category. The English label is used when a
# translation exists, otherwise the raw Portuguese name. Lines whose product has no
# category (or no product match) end up with a NULL category.
items_with_category = (
    order_items
    .select("order_id", "order_item_id", "product_id", "price")
    .join(products.select("product_id", "product_category_name"), on="product_id", how="left")
    .join(translation_lookup, on=join_key_matches, how="left")
    .withColumn("category", coalesce(col("category_english"), col("product_category_name")))
)

# Category "revenue" is the sum of item prices (items_subtotal).
revenue_by_category = (
    items_with_category
    .groupBy("category")
    .agg(
        countDistinct("order_id").alias("orders_count"),  # orders containing the category
        fcount("order_item_id").alias("items_sold"),  # line items in the category
        fsum("price").alias("items_subtotal"),
        favg("price").alias("avg_item_price"),
    )
    .orderBy(col("items_subtotal").desc())  # biggest categories first
)

write_gold_table(revenue_by_category, "revenue_by_category")


# ============================================================
# GOLD 3) revenue_by_state (categorical distribution)
# - Revenue metric = payments (order_revenue) at state level (order grain)
# ============================================================

# Per customer state, at order grain; revenue is payment-based (order_revenue).
state_metrics = [
    countDistinct("order_id").alias("orders_count"),
    countDistinct("customer_id").alias("customers_count"),
    fsum("order_revenue").alias("revenue"),
    fsum("items_subtotal").alias("items_subtotal"),  # price-only subtotal, for comparison
]

revenue_by_state = (
    orders_fact
    .groupBy("customer_state")
    .agg(*state_metrics)
    .orderBy(col("revenue").desc())  # highest-revenue states first
)

write_gold_table(revenue_by_state, "revenue_by_state")


# ============================================================
# GOLD 4) top_sellers (leaderboard)
# - Seller revenue metric = items_subtotal (sum of item prices) per seller
# ============================================================

# Seller leaderboard: revenue per seller is the sum of item prices (items_subtotal).
seller_sales = order_items.groupBy("seller_id").agg(
    countDistinct("order_id").alias("orders_count"),  # orders the seller appears in
    fcount("order_item_id").alias("items_sold"),  # line items sold by the seller
    fsum("price").alias("items_subtotal"),
    fsum("freight_value").alias("freight_total"),
    favg("price").alias("avg_item_price"),
)

seller_attributes = sellers.select("seller_id", "seller_city", "seller_state")

# Left join keeps sellers with sales even if they are missing from Silver sellers.
top_sellers = (
    seller_sales
    .join(seller_attributes, on="seller_id", how="left")
    .orderBy(col("items_subtotal").desc())
)

write_gold_table(top_sellers, "top_sellers")


# ============================================================
# GOLD 5) delivery_monthly (time series)
# - Delivery metrics use orders timestamps from Silver
# ============================================================

# Delivery performance is only measurable for orders that have a customer delivery timestamp.
delivered_orders = (
    orders
    .filter(col("order_delivered_customer_ts").isNotNull())
    .withColumn("order_month", date_trunc("month", col("order_purchase_ts")))  # month of purchase
    .withColumn(
        "delivery_days",
        datediff(col("order_delivered_customer_ts"), col("order_purchase_ts")),  # whole days, purchase -> delivery
    )
    # On time = delivered on or before the estimated delivery DAY. datediff() compares
    # calendar dates, so a delivery later on the estimated day is not counted as late.
    # A raw timestamp comparison would flag it as late if the estimate is stored as a
    # midnight timestamp (as in the raw Olist estimated_delivery_date) — confirm in Silver.
    # A NULL estimate yields a NULL comparison and therefore 0 (not on time).
    .withColumn(
        "delivered_on_time",
        when(datediff(col("order_estimated_delivery_ts"), col("order_delivered_customer_ts")) >= 0, 1).otherwise(0),
    )
)

delivery_monthly = (
    delivered_orders
    .groupBy(col("order_month"))  # one row per purchase month
    .agg(
        fcount(col("order_id")).alias("delivered_orders"),
        favg(col("delivery_days")).alias("avg_delivery_days"),
        # Fraction in [0, 1] (not x100): the mean of the 0/1 flag, i.e. on-time / delivered.
        favg(col("delivered_on_time")).alias("pct_delivered_on_time"),
    )
    .orderBy(col("order_month"))  # chronological, for time-series charts
)

write_gold_table(delivery_monthly, "delivery_monthly")


# ============================================================
# GOLD 6) payment_type_mix (categorical distribution)
# - Uses payment rows directly (one order can have multiple payment rows)
# ============================================================

# Payment rows are used directly: an order paid with several methods counts
# toward each of those payment types.
normalized_payments = order_payments.withColumn("payment_type", lower(trim(col("payment_type"))))

payment_type_mix = (
    normalized_payments
    .groupBy("payment_type")
    .agg(
        countDistinct("order_id").alias("orders_count"),  # orders using this type
        fsum("payment_value").alias("revenue"),  # total value paid with this type
        favg("payment_installments").alias("avg_installments"),
    )
    .orderBy(col("orders_count").desc())  # most popular first
)

write_gold_table(payment_type_mix, "payment_type_mix")


# ============================================================
# 5) Validation: row counts for all Gold tables
# ============================================================

# Every mart written above; each one gets a one-row count summary in the notebook output.
gold_tables = [
    "orders_daily",
    "revenue_by_category",
    "revenue_by_state",
    "top_sellers",
    "delivery_monthly",
    "payment_type_mix",
]

for t in gold_tables:
    qualified_name = f"{GOLD_SCHEMA}.{t}"
    row_count_sql = f"SELECT '{qualified_name}' AS table_name, COUNT(*) AS row_count FROM {qualified_name}"
    display(spark.sql(row_count_sql))
