In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.functions import monotonically_increasing_id, col, count, row_number, lit, current_timestamp, coalesce, max
from pyspark.sql.window import Window
from delta.tables import DeltaTable

In [0]:
# Root folders of the silver (cleansed) and gold (curated) layers on the mount.
# Both values end with "/" so table folder names can be appended directly.
silver_path, gold_path = (
    "/mnt/mock_prajwal/example/silver/",
    "/mnt/mock_prajwal/example/gold/",
)

In [0]:
# Read the silver-layer customer master (stored as Delta) and preview it.
cust_master_location = silver_path + "CustMaster"
df = spark.read.load(cust_master_location, format="delta")
display(df)

In [0]:
# Print the column names, types and nullability of the silver customer frame
# so they can be compared with the Dim_Customer definition created further down.
df.printSchema()

In [0]:
from pyspark.sql.functions import col
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

# Business attributes carried from the silver customer master into the dimension.
customer_columns = [
    "customer_id",
    "customer_name",
    "dob",
    "gender",
    "email",
    "phone_number",
    "address",
    "city",
    "state",
    "country",
    "customer_segment",
    "join_date",
]

# Un-partitioned window: rows are numbered globally in customer_id order.
window_spec = Window.orderBy("customer_id")

# Keep one row per customer_id, then attach a 1-based sequential customer_key.
df_selected = (
    df.select(*customer_columns)
    .dropDuplicates(["customer_id"])
    .withColumn("customer_key", row_number().over(window_spec))
)

display(df_selected)

In [0]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType

# (name, type, nullable) for every Dim_Customer column: the twelve business
# attributes followed by the surrogate key and the SCD2 audit columns.
dim_customer_fields = [
    ("customer_id", IntegerType(), True),
    ("customer_name", StringType(), True),
    ("dob", DateType(), True),
    ("gender", StringType(), True),
    ("email", StringType(), True),
    ("phone_number", StringType(), True),
    ("address", StringType(), True),
    ("city", StringType(), True),
    ("state", StringType(), True),
    ("country", StringType(), True),
    ("customer_segment", StringType(), True),
    ("join_date", DateType(), True),
    ("customer_key", IntegerType(), False),
    ("start_date", TimestampType(), False),
    ("end_date", TimestampType(), True),
    ("is_active", BooleanType(), False),
    ("last_modified", TimestampType(), False),
]

# Spark-side description of the same layout as the DDL below. The SQL
# statements in this cell do not reference it; the table is created from the DDL.
schema = StructType(
    [StructField(name, dtype, nullable) for name, dtype, nullable in dim_customer_fields]
)

# Make sure the database exists before creating the table inside it.
create_schema_sql = "CREATE SCHEMA IF NOT EXISTS Prajwal_Telecom"
spark.sql(create_schema_sql)

# External Delta table for the customer dimension; a no-op when it already exists.
create_table_sql = """
    CREATE TABLE IF NOT EXISTS Prajwal_Telecom.Dim_Customer (
        customer_id INT,
        customer_name STRING,
        dob DATE,
        gender STRING,
        email STRING,
        phone_number STRING,
        address STRING,
        city STRING,
        state STRING,
        country STRING,
        customer_segment STRING,
        join_date DATE,
        customer_key INT NOT NULL,
        start_date TIMESTAMP NOT NULL,
        end_date TIMESTAMP,
        is_active BOOLEAN NOT NULL,
        last_modified TIMESTAMP NOT NULL
    )
        USING DELTA 
    LOCATION "/mnt/mock_prajwal/example/gold/Dim_Customer"
"""
spark.sql(create_table_sql)

In [0]:
from pyspark.sql.functions import coalesce, max, lit, row_number, col, current_timestamp
from pyspark.sql.window import Window
from delta.tables import DeltaTable

# SCD Type 2 load of Prajwal_Telecom.Dim_Customer from df_selected.
#
# Step 1 (MERGE):  expire the active version of every customer whose tracked
#                  attributes differ from the incoming row.
# Step 2 (APPEND): insert a fresh active version for every source customer
#                  that has no active row left (brand-new and just-expired ones).

# Load target table
target_df = spark.read.table("Prajwal_Telecom.Dim_Customer")

# Highest surrogate key issued so far (0 when the table is empty)
if 'customer_key' in target_df.columns:
    max_key = target_df.agg(coalesce(max("customer_key"), lit(0))).collect()[0][0]
else:
    max_key = 0

# Freeze ONE timestamp for the whole run. current_timestamp() is evaluated once
# per query, so using it directly would give the expired version's end_date
# (MERGE) and the new version's start_date (APPEND) two different values.
# NOTE(review): the value round-trips through a Python datetime; confirm the
# driver/session time-zone settings are acceptable for this.
run_ts = spark.range(1).select(current_timestamp().alias("run_ts")).first()["run_ts"]
run_ts_col = lit(run_ts)

# Add audit columns to source
source_df_audit_col = (
    df_selected
    .withColumn("start_date", run_ts_col)
    .withColumn("end_date", lit(None).cast("timestamp"))
    .withColumn("is_active", lit(True))
    .withColumn("last_modified", run_ts_col)
)

# Convert target to DeltaTable
target_table = DeltaTable.forName(spark, "Prajwal_Telecom.Dim_Customer")

# Merge condition (on business key and active rows)
merge_condition = "target.customer_id = source.customer_id AND target.is_active = true"

# Attributes whose change opens a new version (join_date is deliberately left
# out, exactly as before).
tracked_columns = [
    "customer_name",
    "dob",
    "gender",
    "email",
    "phone_number",
    "address",
    "city",
    "state",
    "country",
    "customer_segment",
]

# `a != b` evaluates to NULL when either side is NULL, so a NULL -> value or
# value -> NULL change would never be detected. `<=>` is null-safe equality.
change_condition = " OR ".join(
    f"NOT (target.{column} <=> source.{column})" for column in tracked_columns
)

# Perform Expiration
(
    target_table.alias("target")
    .merge(source_df_audit_col.alias("source"), merge_condition)
    .whenMatchedUpdate(
        condition=change_condition,
        set={
            "end_date": run_ts_col,
            "is_active": lit(False),
            "last_modified": run_ts_col,
        },
    )
    .execute()
)

# Customers that still have an active row after the expiration step
updated_target_df = (
    spark.read.table("Prajwal_Telecom.Dim_Customer")
    .filter("is_active=true")
    .select("customer_id")
)

# Define window specification
window_spec = Window.orderBy("customer_id")

# Rows to insert: new business keys plus customers expired just above.
# Surrogate keys are assigned only to these rows, so every run consumes exactly
# as many INT key values as it inserts instead of one per source row.
# localCheckpoint() materializes the result and cuts its lineage: insert_df is
# otherwise lazy, and re-evaluating it after the append below (e.g. a later
# insert_df.count()) would re-run the anti-join against the already-updated
# table and come back empty.
insert_df = (
    source_df_audit_col
    .join(updated_target_df, on="customer_id", how="left_anti")
    .withColumn("rn", row_number().over(window_spec))
    .withColumn("customer_key", col("rn") + lit(max_key))
    .drop("rn")
    .localCheckpoint()
)

insert_df.write.format("delta").mode("append").option("mergeSchema", "true").save("/mnt/mock_prajwal/example/gold/Dim_Customer")

In [0]:
# Read the gold customer dimension back from its Delta location and preview it.
dim_customer_location = "/mnt/mock_prajwal/example/gold/Dim_Customer"
dim_customer = spark.read.load(dim_customer_location, format="delta")
display(dim_customer)

In [0]:
# Side-by-side row counts of the silver source and the gold dimension.
# The gold figure counts every row stored in Dim_Customer, whatever its is_active flag.
silver_df = spark.read.load(silver_path + "CustMaster", format="delta")
gold_df = spark.read.load("/mnt/mock_prajwal/example/gold/Dim_Customer", format="delta")

silver_count, gold_count = silver_df.count(), gold_df.count()

# One-row frame so the two numbers render as a table.
layer_counts_df = spark.createDataFrame(
    [(silver_count, gold_count)],
    schema=["Silver Layer Count", "Gold Layer Count"],
)
display(layer_counts_df)

In [0]:
# Report how many rows the most recent SCD2 load expired and inserted.
#
# The numbers are taken from the Delta commit history instead of being
# recomputed from the data:
#   * `last_modified = current_timestamp()` compares against the moment the
#     count query itself runs, so it never matches a stored value and the
#     updated count was always 0;
#   * `insert_df` is lazy, so counting it after the append re-runs its
#     anti-join against a table that already contains the inserted rows.
# NOTE(review): assumes the newest MERGE and WRITE commits on the table are the
# ones produced by the load cell above - confirm nothing else writes to it.
recent_commits = sorted(
    target_table.history(20).select("version", "operation", "operationMetrics").collect(),
    key=lambda commit: commit["version"],
    reverse=True,
)


def _latest_metric(operation, metric):
    """Return `metric` (as int) from the newest commit of type `operation`.

    Falls back to 0 when no such commit is among the inspected history
    entries or the commit does not carry that metric.
    """
    for commit in recent_commits:
        if commit["operation"] == operation:
            return int((commit["operationMetrics"] or {}).get(metric, 0))
    return 0


# Count the number of updated records (active versions expired by the MERGE)
updated_count = _latest_metric("MERGE", "numTargetRowsUpdated")

# Count the number of inserted records (rows written by the append)
inserted_count = _latest_metric("WRITE", "numOutputRows")

# Display the counts
display(spark.createDataFrame([(updated_count, inserted_count)], ["Updated Records", "Inserted Records"]))