In [0]:
# Referential Integrity Check Before fact_sales Merge
# Before merging fact_sales, check if all customer_id values exist in dim_customer.
# Identify orphaned customer_id values in fact_sales
# If any records appear in this query, it means customer_id values exist in fact_sales but not in dim_customer.
# You can reject, log, or hold these records until their dimensions arrive.


from pyspark.sql import functions as F

# Load the two inputs of the customer_id integrity check:
#   orders_df       - the Silver Delta dataset stored under .../Silver/SilverCDetails
#   dim_customer_df - the registered customer dimension table
orders_df = spark.read.load(
    "/mnt/Prajwal/Retail_sales_usecase/Silver/SilverCDetails",
    format="delta",
)
dim_customer_df = spark.table("retail.customerdimension")

# Preview both inputs before running the check.
display(orders_df)
display(dim_customer_df)

In [0]:
# Find customer_id values present in orders_df that have no matching row in
# dim_customer_df. A left-anti join keeps exactly the left-side rows without a
# match (rows with a NULL customer_id never match and are therefore kept too).
orphaned_customer_ids = orders_df.join(
    dim_customer_df,
    on=orders_df.customer_id == dim_customer_df.customer_id,
    how="left_anti",
).select(orders_df.customer_id)

# Reduce to the distinct set of orphaned keys
orphaned_customer_ids_distinct = orphaned_customer_ids.distinct()

# Print the orphaned customer_id values
orphaned_customer_ids_distinct.show()

# Report how many distinct orphaned keys were found
orphaned_count = orphaned_customer_ids_distinct.count()
print(f"Number of orphaned records: {orphaned_count}")

In [0]:
# Load the two inputs of the product_id integrity check:
#   orders_df      - the Silver Delta dataset stored under .../Silver/SilverPDetails
#   dim_product_df - the registered product dimension table
orders_df = spark.read.load(
    "/mnt/Prajwal/Retail_sales_usecase/Silver/SilverPDetails",
    format="delta",
)
dim_product_df = spark.table("retail.productdimension")

# Find product_id values present in orders_df that have no matching row in
# dim_product_df. A left-anti join keeps exactly the left-side rows without a
# match (rows with a NULL product_id never match and are therefore kept too).
orphaned_product_ids = orders_df.join(
    dim_product_df,
    on=orders_df.product_id == dim_product_df.product_id,
    how="left_anti",
).select(orders_df.product_id)

# Reduce to the distinct set of orphaned keys
orphaned_product_ids_distinct = orphaned_product_ids.distinct()

# Print the orphaned product_id values
orphaned_product_ids_distinct.show()

# Report how many distinct orphaned keys were found
orphaned_count = orphaned_product_ids_distinct.count()
print(f"Number of orphaned records: {orphaned_count}")

# **FACT TABLE**

In [0]:
# Broadcast Join (Performance Boost) - broadcast(product_df) - Avoids expensive shuffle joins
# Late Arriving Data Handling - fillna(-1) for customer_id - Assigns a default key for missing dimensions
# Partitioning Strategy - repartition("date_key") - Partitions data efficiently for faster reads
# ZORDER Optimization - OPTIMIZE FactSales ZORDER BY (customer_id, product_id); - Speeds up queries on frequent filters
# Caching - Avoid recomputation of joins using persist, Speed up writes and Free up memory

from pyspark.sql.functions import *
from pyspark.sql.window import Window
from delta.tables import DeltaTable
from pyspark import StorageLevel

# ---------------------------
# Load Silver Data
# Each Silver dataset is a Delta table addressed by its storage path.
# ---------------------------
orders_df = spark.read.load(
    "/mnt/Prajwal/Retail_sales_usecase/Silver/silverODetails", format="delta"
)
products_df = spark.read.load(
    "/mnt/Prajwal/Retail_sales_usecase/Silver/SilverPDetails", format="delta"
)
customer_df = spark.read.load(
    "/mnt/Prajwal/Retail_sales_usecase/Silver/SilverCDetails", format="delta"
)
customerInsight_df = spark.read.load(
    "/mnt/Prajwal/Retail_sales_usecase/Silver/SilverCInsights", format="delta"
)

# ---------------------------
# Incremental Load: Exclude already processed order_ids
# ---------------------------
# Check explicitly whether the Gold Delta table exists instead of wrapping the
# whole step in a bare `except:`. A bare except would also hide unrelated
# failures (missing order_id column, permission errors, typos) and silently
# fall back to a full load.
gold_path = "/mnt/Prajwal/Retail_sales_usecase/Gold"

if DeltaTable.isDeltaTable(spark, gold_path):
    gold_fact_df = spark.read.format("delta").load(gold_path)
    existing_order_ids = gold_fact_df.select("order_id").distinct()
    # left_anti keeps only orders whose order_id is not yet present in Gold
    orders_df = orders_df.join(existing_order_ids, on="order_id", how="left_anti")
else:
    print("Gold fact table doesn't exist. Full load will be performed.")

In [0]:
# Broadcast Join Optimization for Small Dimensions
# ---------------------------
# broadcast() only attaches a hint; it assumes the product table is small
# enough to be shipped to every executor.
products_df = broadcast(products_df)

# ---------------------------
# Handle Late Arriving Dimensions
# ---------------------------
# Copy the key into customer_id_late, then replace NULL keys with the default
# surrogate key -1. The fill is restricted to the key columns: an unrestricted
# fillna(-1) would also overwrite NULLs in every other numeric column of
# customer_df. Columns in `subset` whose type does not match the fill value
# are ignored by Spark, as they were before.
customer_df = (
    customer_df
    .withColumn("customer_id_late", col("customer_id"))
    .fillna(-1, subset=["customer_id", "customer_id_late"])
)

In [0]:
%python
from pyspark.sql.functions import udf, col
from pyspark.sql.types import TimestampType
from datetime import datetime

# ---------------------------
# Repartition Orders Data to Reduce Shuffle Before Join
# Hash-partitioning on customer_id places all rows with the same customer_id
# in the same partition, which is intended to cut down data movement when the
# orders are later joined with customer_df.
# The partition count (100, 10, ...) should be tuned to the data size.
# ---------------------------
orders_df = orders_df.repartition(100, col("customer_id"))


def parse_with_year(date_str):
    """Parse an ``M/D/YY H:MM`` or ``M/D/YYYY H:MM`` string into a datetime.

    The two-digit-year format is tried first, then the four-digit-year format.
    If the parsed year lies more than five years after the current year, the
    year is shifted back by 100 (this applies whichever format matched).

    Args:
        date_str: Raw date text; surrounding whitespace is ignored.

    Returns:
        The parsed ``datetime``, or ``None`` when the input is empty/None,
        matches neither format, or any other error occurs while parsing.
    """
    try:
        if not date_str:
            return None

        cleaned = date_str.strip()

        # Try the two-digit-year layout first, then the four-digit one.
        parsed = None
        for fmt in ('%m/%d/%y %H:%M', '%m/%d/%Y %H:%M'):
            try:
                parsed = datetime.strptime(cleaned, fmt)
                break
            except ValueError:
                continue
        if parsed is None:
            return None

        # Years too far in the future are assumed to belong to the previous century.
        latest_allowed_year = datetime.now().year + 5
        if parsed.year > latest_allowed_year:
            parsed = parsed.replace(year=parsed.year - 100)

        return parsed
    except Exception:
        # Best effort: any unexpected input (e.g. a non-string) yields None.
        return None

# Wrap parse_with_year as a Spark UDF that returns a timestamp
parse_udf = udf(parse_with_year, returnType=TimestampType())

# ---------------------------
# Replace the raw order_date text with the parsed timestamp before any
# further transformation (values that cannot be parsed become NULL)
# ---------------------------
parsed_order_date = parse_udf(col("order_date"))
orders_df = orders_df.withColumn("order_date", parsed_order_date)

display(orders_df)

In [0]:
# ---------------------------
# Join with Product info
# ---------------------------
# Only the product attributes needed downstream are brought in; the left join
# keeps every order even when its product_id has no match.
orders_enriched = orders_df.join(
    products_df.select("product_id", "price", "in_stock"),
    on="product_id", how="left"
)

display(orders_enriched)

In [0]:
# ---------------------------
# Derived columns
#   order_value   - order_amount as decimal(10,2)
#   quarter, year - calendar parts of order_date
#   quarter_year  - label of the form Q<quarter>-<year>
#   in_stock_flag - "Y" when in_stock equals "Yes", otherwise "N"
# ---------------------------
order_value_col = col("order_amount").cast("decimal(10,2)")
in_stock_flag_col = when(col("in_stock") == "Yes", "Y").otherwise("N")

orders_enriched = (
    orders_enriched
    .withColumn("order_value", order_value_col)
    .withColumn("quarter", quarter("order_date"))
    .withColumn("year", year("order_date"))
    .withColumn("quarter_year", concat(lit("Q"), col("quarter"), lit("-"), col("year")))
    .withColumn("in_stock_flag", in_stock_flag_col)
)

In [0]:
 #---------------------------
# Keep only the customer key from customer_df (state and country are not
# carried into the fact table).
# ---------------------------
# The key list is de-duplicated: a repeated customer_id on the right side of
# the join (for example several NULL keys that were all defaulted to -1)
# would otherwise multiply the matching fact rows.
customer_df_renamed = customer_df.select(
    "customer_id"
).dropDuplicates(["customer_id"])

# ---------------------------
# Join customer dimension
# ---------------------------
# Left join so that orders without a matching customer are kept.
fact_df = orders_enriched.join(
    customer_df_renamed,
    on="customer_id", how="left"
)
display(fact_df)

In [0]:
# ---------------------------
# High Value Flag
# Total order_value per customer is computed from the rows currently in
# fact_df and joined back; customers whose total exceeds 10000 are flagged "Y".
# ---------------------------
customer_total_spend = (
    fact_df
    .groupBy("customer_id")
    .agg(sum("order_value").alias("total_spent"))
)

is_high_value = when(col("total_spent") > 10000, "Y").otherwise("N")
fact_df = (
    fact_df
    .join(customer_total_spend, on="customer_id", how="left")
    .withColumn("Is_High_Value_Customer", is_high_value)
)

In [0]:
# ---------------------------
# Generate Surrogate Key
# NOTE(review): monotonically_increasing_id() values are generated per run, so
# keys from separate incremental loads may overlap - confirm this is acceptable.
# ---------------------------
fact_df = fact_df.withColumn("Sales_Key", monotonically_increasing_id())

# ---------------------------
# Select Final Columns (order_count is intentionally not included)
# ---------------------------
final_columns = [
    "Sales_Key",
    "customer_id",
    "product_id",
    "order_id",
    "order_date",
    "order_channel",
    "store_code",
    "order_value",
    "price",
    "quarter",
    "quarter_year",
    "Is_High_Value_Customer",
    "in_stock_flag",
]
fact_final = fact_df.select(*final_columns)
display(fact_final)

In [0]:
# ---------------------------
# Persist the fact dataframe
# ---------------------------
# fact_df is the result of the product and customer joins; marking it as
# persisted lets later actions reuse it instead of recomputing those joins.
# MEMORY_AND_DISK keeps partitions in memory and spills to disk when needed.
fact_storage_level = StorageLevel.MEMORY_AND_DISK
fact_df = fact_df.persist(fact_storage_level)
print("Fact table persisted after joins.")

In [0]:
# ---------------------------
# Partition Strategy: derive date_key from order_date (YYYYMMDD as INT)
# ---------------------------
# order_date is formatted as an 8-character yyyyMMdd string and that string is
# cast to an integer, e.g. 2024-03-09 -> 20240309.
date_key_col = date_format(col("order_date"), "yyyyMMdd").cast("int")
fact_df = fact_df.withColumn("date_key", date_key_col)


In [0]:
# ---------------------------
# Add Effective Date and Expiry Info
# ---------------------------
# effective_date - load timestamp of this run
# expiry_date    - not yet known, so NULL; cast to timestamp because a bare
#                  lit(None) produces a NullType column, which has no usable
#                  storage type and would not line up with effective_date
# is_current     - every freshly loaded row starts as the current version
fact_df = fact_df.withColumn("effective_date", current_timestamp()) \
                 .withColumn("expiry_date", lit(None).cast("timestamp")) \
                 .withColumn("is_current", lit(True))

In [0]:
# ---------------------------
# Coalesce to reduce small files before writing
# coalesce() merges existing partitions without a full shuffle; it is applied
# after the transformations so that fewer, larger files are produced on write.
# ---------------------------
coalesced_partitions = 10
fact_df = fact_df.coalesce(coalesced_partitions)

In [0]:
# ---------------------------
# Cache fact_df ahead of the Gold write
# cache() only marks the DataFrame for caching; nothing is materialised until
# an action runs on it.
# NOTE(review): fact_df was already persisted in an earlier cell, and the Gold
# write cell writes fact_final rather than fact_df - confirm this extra cache
# is still needed.
# ---------------------------
fact_df.cache()

In [0]:
 #---------------------------
# Write to Gold Layer
# ---------------------------
# First load : no Delta table at gold_path yet -> write fact_final as a new
#              Delta table partitioned by quarter_year.
# Later loads: upsert (MERGE) fact_final into the existing table, matching
#              rows on order_id; matches are updated, the rest are inserted.
# Note: fact_final was selected from fact_df in an earlier cell, so columns
# added to fact_df after that point (date_key, effective_date, expiry_date,
# is_current) are not part of what is written here.

gold_path = "/mnt/Prajwal/Retail_sales_usecase/Gold"

if not DeltaTable.isDeltaTable(spark, gold_path):
    (
        fact_final.write
        .format("delta")
        .mode("overwrite")
        .partitionBy("quarter_year")
        .save(gold_path)
    )
else:
    delta_table = DeltaTable.forPath(spark, gold_path)

    upsert = (
        delta_table.alias("target")
        .merge(fact_final.alias("source"), "target.order_id = source.order_id")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
    )
    upsert.execute()


In [0]:
# Backfilling
# Looks at fact rows whose customer_id is the default key -1 (unknown at load
# time) and updates them from the customer dimension when the merge condition
# matches. The merge only UPDATES existing fact rows: a backfill must never
# create new fact rows, so there is no "when not matched insert" clause
# (it would add a near-empty fact row for every non-matching dimension
# customer on every run).

from delta.tables import DeltaTable
from pyspark.sql.functions import col

# Fact and dimension are addressed by their metastore table names
fact_table = DeltaTable.forName(spark, "retail.Fact_Sales")
dim_customer_df = spark.table("retail.customerdimension")

# Count rows where customer_id = -1 before merge
before_count = fact_table.toDF().filter(col("customer_id") == "-1").count()

# Perform the merge to backfill missing customer_id.
# NOTE(review): the condition requires fact.customer_id to be '-1' AND equal
# to dim.customer_id, so it can only match a dimension row whose key is also
# -1, and the update then writes the same value back. A real backfill needs
# a different key (e.g. a natural/business key) to locate the late-arriving
# customer - confirm the intended match column.
fact_table.alias("fact") \
    .merge(
        dim_customer_df.alias("dim"),
        "fact.customer_id = '-1' AND fact.customer_id = dim.customer_id"
    ) \
    .whenMatchedUpdate(set={
        "fact.customer_id": "dim.customer_id"
    }) \
    .execute()

# Count rows that still carry customer_id = -1 after the merge
after_count = fact_table.toDF().filter(col("customer_id") == "-1").count()

# Rows that no longer carry the default key were backfilled
updated_rows = before_count - after_count

# Print affected counts
print(f"Records updated: {updated_rows}")
print(f"Total affected records (before merge): {before_count}")
print(f"Remaining records with customer_id = -1 (after merge): {after_count}")