In [0]:
%run ./01-config

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Folder for raw input data. base_dir_data is not defined in this section;
# presumably it is set by ./01-config (run above) — confirm there.
# landing_zone is not referenced elsewhere in this section.
landing_zone = base_dir_data + "/raw"
# Parent folder for every streaming query's checkpointLocation below.
# base_dir_checkpoint presumably also comes from ./01-config.
checkpoint_base = base_dir_checkpoint + "/checkpoints"



def upserter(df_micro_batch, batch_id, merge_query, temp_view):
    """Run a SQL statement (typically a MERGE) against one streaming micro-batch.

    Used as the body of ``foreachBatch`` sinks: the micro-batch is registered
    as the temp view ``temp_view`` and ``merge_query`` (which should read from
    that view) is executed in the same Spark session that owns the view.

    Args:
        df_micro_batch: DataFrame holding the rows of the current micro-batch.
        batch_id: Micro-batch id supplied by Structured Streaming; only logged.
        merge_query: SQL text to execute, e.g. ``MERGE INTO ... USING <temp_view>``.
        temp_view: Name under which the micro-batch is exposed to SQL.
    """
    df_micro_batch.createOrReplaceTempView(temp_view)
    # The temp view is registered in the micro-batch's own session, so the SQL
    # must run there too. Prefer the public ``DataFrame.sparkSession`` property
    # (PySpark >= 3.3, and the only option on Spark Connect, where ``_jdf`` does
    # not exist); fall back to the JVM session on older runtimes.
    session = getattr(df_micro_batch, "sparkSession", None)
    if session is not None:
        session.sql(merge_query)
    else:
        df_micro_batch._jdf.sparkSession().sql(merge_query)
    print(f"Batch {batch_id} for {temp_view} processed.")



def upsert_user_profile_microbatch(df_micro_batch, batch_id, merge_query):
    """Merge the newest profile change per user from one CDC micro-batch.

    Only rows whose ``update_type`` is ``"new"`` or ``"update"`` are kept; all
    other update types are dropped here. Of the remaining rows, exactly one
    per ``user_id`` (the one with the greatest ``updated``) is exposed as the
    temp view ``user_profile_cdc``, and ``merge_query`` is run against it.

    Args:
        df_micro_batch: Micro-batch DataFrame with at least ``user_id``,
            ``updated`` and ``update_type`` columns.
        batch_id: Micro-batch id supplied by Structured Streaming; only logged.
        merge_query: SQL MERGE that reads from ``user_profile_cdc``.
    """
    latest_first = Window.partitionBy("user_id").orderBy(F.col("updated").desc())

    # row_number (not rank) guarantees a single row per user_id even if two
    # changes carry the same ``updated`` value. A MERGE fails when more than
    # one source row matches the same target row.
    df_micro_batch = (
        df_micro_batch.filter(F.col("update_type").isin(["new", "update"]))
        .withColumn("row_num", F.row_number().over(latest_first))
        .filter("row_num == 1")
        .drop("row_num")
    )

    df_micro_batch.createOrReplaceTempView("user_profile_cdc")

    # Run the MERGE in the session that owns the temp view. Prefer the public
    # property; fall back to the JVM session on runtimes that lack it.
    session = getattr(df_micro_batch, "sparkSession", None)
    if session is not None:
        session.sql(merge_query)
    else:
        df_micro_batch._jdf.sparkSession().sql(merge_query)

    print(f"Batch {batch_id} processed.")




def upsert_users(once=True, processing_time="15 seconds", startingVersion=0):
    """Insert newly registered users from bronze into the ``users`` silver table.

    The ``registered_users_bz`` Delta table is read as a stream. The
    registration time is cast to a timestamp, and rows are de-duplicated on
    ``(user_id, device_id)`` under a 30-second watermark. Each micro-batch is
    merged on ``user_id`` with insert-only semantics, so existing users are
    never modified.

    Args:
        once: When ``once == True`` the query uses an ``availableNow`` trigger
            (process what is available, then stop). Otherwise it uses a
            processing-time trigger.
        processing_time: Trigger interval used for the processing-time trigger.
        startingVersion: Delta version of the bronze table to start reading from.

    Returns:
        The started ``StreamingQuery``.
    """
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.users a
    USING users_delta b
    ON a.user_id = b.user_id
    WHEN NOT MATCHED THEN INSERT *

    """

    registrations = (
        spark.readStream
        .option("startingVersion", startingVersion)
        .option("ignoreDeletes", "true")
        .table(f"{catalog}.{db_name}.registered_users_bz")
    )
    registrations = registrations.selectExpr(
        "user_id", "device_id", "mac_address", "cast(registration_timestamp as timestamp)"
    )
    # Streaming de-duplication; the watermark bounds how long keys stay in state.
    registrations = (
        registrations
        .withWatermark("registration_timestamp", "30 seconds")
        .dropDuplicates(["user_id", "device_id"])
    )

    def merge_batch(batch_df, batch_id):
        upserter(batch_df, batch_id, merge_query, "users_delta")

    writer = (
        registrations.writeStream
        .foreachBatch(merge_batch)
        .outputMode("update")
        .option("checkpointLocation", f"{checkpoint_base}/users")
        .queryName("users_upsert_stream")
    )

    if once == True:
        trigger_kwargs = {"availableNow": True}
    else:
        trigger_kwargs = {"processingTime": processing_time}
    return writer.trigger(**trigger_kwargs).start()
        




def upsert_gym_logs(once=True, processing_time="15 seconds", startingVersion=0):
    """Upsert gym login/logout sessions from bronze into the ``gym_logs`` table.

    Bronze ``gym_logins_bz`` rows have ``login``/``logout`` cast to timestamps.
    They are then de-duplicated on ``(mac_address, gym, login)`` under a
    30-second watermark on ``login``.

    Unmatched sessions are inserted. For an existing session, the stored
    ``logout`` is replaced when the incoming logout is after ``login`` and
    either the stored ``logout`` is NULL or the incoming logout is later.

    Args:
        once: Truthy runs an ``availableNow`` trigger (drain and stop);
            otherwise a processing-time trigger is used.
        processing_time: Trigger interval for the processing-time trigger.
        startingVersion: Delta version of the bronze table to start from.

    Returns:
        The started ``StreamingQuery``.
    """
    # `b.logout > a.logout` evaluates to NULL (not TRUE) when the stored logout
    # is NULL, which silently skipped filling in a logout that was never
    # recorded. The IS NULL branch makes that case update too.
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.gym_logs a
    USING gym_logs_delta b
    ON a.mac_address=b.mac_address AND a.gym=b.gym AND a.login=b.login
    WHEN MATCHED AND b.logout > a.login AND (a.logout IS NULL OR b.logout > a.logout)
        THEN UPDATE SET logout = b.logout
    WHEN NOT MATCHED THEN INSERT *
    """

    df_delta = (spark.readStream
                    .option("startingVersion", startingVersion)
                    .option("ignoreDeletes", True)
                    .table(f"{catalog}.{db_name}.gym_logins_bz")
                    .selectExpr("mac_address", "gym", "cast(login as timestamp)", "cast(logout as timestamp)")
                    .withWatermark("login", "30 seconds")
                    .dropDuplicates(["mac_address", "gym", "login"])
               )

    # Keep the existing checkpoint path ("gyms_logs") so running queries resume
    # from their saved state.
    stream_writer = (df_delta.writeStream
                        .foreachBatch(lambda df, batch_id: upserter(df, batch_id, merge_query, "gym_logs_delta"))
                        .outputMode("update")
                        .option("checkpointLocation", checkpoint_base + "/gyms_logs")
                        .queryName("gym_logs_upsert_stream")
                    )

    if once:
        return stream_writer.trigger(availableNow=True).start()
    return stream_writer.trigger(processingTime=processing_time).start()






def upsert_user_profile(once=True, processing_time="15 seconds", startingVersion=0):
    """Apply user-profile CDC events from the Kafka bronze table to ``user_profile``.

    Processing steps:
      - Keep records with ``topic = 'user_info'`` and parse their JSON payload.
      - Parse ``dob`` with the ``MM/dd/yyyy`` pattern.
      - Flatten the ``address`` struct.
      - Cast the numeric ``timestamp`` field to the ``updated`` timestamp.
      - De-duplicate on ``(user_id, updated)`` under a 30-second watermark.

    Each micro-batch goes to ``upsert_user_profile_microbatch``. That merges
    the newest change per user: update when newer, insert when absent.

    Args:
        once: Truthy runs an ``availableNow`` trigger (drain and stop);
            otherwise a processing-time trigger is used.
        processing_time: Trigger interval for the processing-time trigger.
        startingVersion: Delta version of ``kafka_multiplex_bz`` to start
            reading from. It was previously ignored (hard-coded to 0); the
            default of 0 keeps that behaviour for existing callers.

    Returns:
        The started ``StreamingQuery``.
    """
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.user_profile a
    USING user_profile_cdc b
    ON a.user_id=b.user_id
    WHEN MATCHED AND a.updated < b.updated
        THEN UPDATE SET *
    WHEN NOT MATCHED
        THEN INSERT *
    """

    schema = """
    user_id bigint, update_type STRING, timestamp FLOAT, 
    dob STRING, sex STRING, gender STRING, first_name STRING, last_name STRING, 
    address STRUCT<street_address: STRING, city: STRING, state: STRING, zip: INT>
    """

    df_cdc = (
        spark.readStream
            .option("startingVersion", startingVersion)
            .option("ignoreDeletes", True)
            .table(f"{catalog}.{db_name}.kafka_multiplex_bz")
            .filter("topic = 'user_info'")
            .select(F.from_json(F.col("value").cast("string"), schema).alias("v"))
            .select("v.*")
            .select(
                "user_id",
                F.to_date("dob", "MM/dd/yyyy").alias("dob"),
                "sex", "gender", "first_name", "last_name",
                "address.*",
                F.col("timestamp").cast("timestamp").alias("updated"),
                "update_type"
            )
            .withWatermark("updated", "30 seconds")
            .dropDuplicates(["user_id", "updated"])
    )

    stream_writer = (
        df_cdc.writeStream
            .foreachBatch(lambda df, batch_id: upsert_user_profile_microbatch(df, batch_id, merge_query))
            .outputMode("update")
            .option("checkpointLocation", checkpoint_base + "/user_profile")
            .queryName("user_profile_stream")
    )

    if once:
        return stream_writer.trigger(availableNow=True).start()
    return stream_writer.trigger(processingTime=processing_time).start()




def upsert_workouts(once=True, processing_time="15 seconds", startingVersion=0):
    """Insert workout events from the Kafka bronze table into ``workouts``.

    Processing steps:
      - Parse the JSON payload of records with ``topic = 'workout'``.
      - Cast the numeric ``timestamp`` field to the ``time`` timestamp.
      - De-duplicate on ``(user_id, time)`` under a 30-second watermark.

    Each micro-batch is merged insert-only on ``(user_id, time)``.

    Args:
        once: When ``once == True`` an ``availableNow`` trigger is used;
            otherwise a processing-time trigger.
        processing_time: Trigger interval for the processing-time trigger.
        startingVersion: Delta version of ``kafka_multiplex_bz`` to start from.

    Returns:
        The started ``StreamingQuery``.
    """
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.workouts a
    USING workouts_delta b
    ON a.user_id=b.user_id AND a.time=b.time
    WHEN NOT MATCHED THEN INSERT *
    """

    schema = "user_id INT, workout_id INT, timestamp FLOAT, action STRING, session_id INT"

    kafka_rows = (
        spark.readStream
        .option("startingVersion", startingVersion)
        .option("ignoreDeletes", True)
        .table(f"{catalog}.{db_name}.kafka_multiplex_bz")
        .filter("topic = 'workout'")
    )
    payload = F.from_json(F.col("value").cast("string"), schema).alias("v")
    events = kafka_rows.select(payload).select("v.*")
    events = events.select(
        "user_id",
        "workout_id",
        F.col("timestamp").cast("timestamp").alias("time"),
        "action",
        "session_id",
    )
    events = events.withWatermark("time", "30 seconds").dropDuplicates(["user_id", "time"])

    def merge_batch(batch_df, batch_id):
        upserter(batch_df, batch_id, merge_query, "workouts_delta")

    writer = (
        events.writeStream
        .foreachBatch(merge_batch)
        .outputMode("update")
        .option("checkpointLocation", checkpoint_base + "/workouts")
        .queryName("workouts_upsert_stream")
    )

    trigger_kwargs = {"availableNow": True} if once == True else {"processingTime": processing_time}
    return writer.trigger(**trigger_kwargs).start()




def upsert_heart_rate(once=True, processing_time="15 seconds", startingVersion=0):
    """Insert heart-rate readings from the Kafka bronze table into ``heart_rate``.

    Processing steps:
      - Parse the JSON payload of records with ``topic = 'bpm'``.
      - Add a boolean ``valid`` flag: True only when ``heartrate > 0``.
      - De-duplicate on ``(device_id, time)`` under a 30-second watermark.

    Each micro-batch is merged insert-only on ``(device_id, time)``.

    Args:
        once: Truthy runs an ``availableNow`` trigger (drain and stop);
            otherwise a processing-time trigger is used.
        processing_time: Trigger interval for the processing-time trigger.
        startingVersion: Delta version of ``kafka_multiplex_bz`` to start from.

    Returns:
        The started ``StreamingQuery``.
    """
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.heart_rate a
    USING heart_rate_delta b
    ON a.device_id=b.device_id AND a.time=b.time
    WHEN NOT MATCHED THEN INSERT *
    """

    schema = "device_id LONG, time TIMESTAMP, heartrate DOUBLE"

    # `heartrate > 0` is NULL for a NULL reading; coalesce turns that into False
    # so missing readings are flagged invalid. The previous
    # when(<= 0, False).otherwise(True) marked NULL readings as valid, letting
    # them through the downstream `valid = True` filter.
    valid_flag = F.coalesce(F.col("v.heartrate") > 0, F.lit(False)).alias("valid")

    df_delta = (spark.readStream
                    .option("startingVersion", startingVersion)
                    .option("ignoreDeletes", True)
                    .table(f"{catalog}.{db_name}.kafka_multiplex_bz")
                    .filter("topic = 'bpm'")
                    .select(F.from_json(F.col("value").cast("string"), schema).alias("v"))
                    .select("v.*", valid_flag)
                    .withWatermark("time", "30 seconds")
                    .dropDuplicates(["device_id", "time"])
               )

    stream_writer = (df_delta.writeStream
                        .foreachBatch(lambda df, batch_id: upserter(df, batch_id, merge_query, "heart_rate_delta"))
                        .outputMode("update")
                        .option("checkpointLocation", checkpoint_base + "/heart_rate")
                        .queryName("heart_rate_upsert_stream")
                    )

    if once:
        return stream_writer.trigger(availableNow=True).start()
    return stream_writer.trigger(processingTime=processing_time).start()




from pyspark.sql.functions import floor, months_between, current_date, when, col

def age_bins(dob_col):
    """Return a Column that labels the age bracket of the birth date ``dob_col``.

    Age is computed in whole years as
    ``floor(months_between(current_date(), dob) / 12)``. The labels are:
      - ``"under 18"`` for ages below 18.
      - Half-open ranges ``"18-25"`` = [18, 25), ``"25-35"`` = [25, 35),
        ..., ``"85-95"`` = [85, 95).
      - ``"95+"`` for 95 and above.
      - ``"invalid age"`` when the age is NULL (e.g. a NULL ``dob``), since
        no branch matches.
    """
    years = floor(months_between(current_date(), dob_col) / 12)

    # (exclusive upper bound, label). CASE WHEN branches are tried in order, so
    # each branch only needs its upper bound; the lower bound is implied by
    # every earlier branch having failed.
    brackets = [
        (18, "under 18"),
        (25, "18-25"),
        (35, "25-35"),
        (45, "35-45"),
        (55, "45-55"),
        (65, "55-65"),
        (75, "65-75"),
        (85, "75-85"),
        (95, "85-95"),
    ]
    first_bound, first_label = brackets[0]
    bucket = when(years < first_bound, first_label)
    for bound, label in brackets[1:]:
        bucket = bucket.when(years < bound, label)
    return bucket.when(years >= 95, "95+").otherwise("invalid age")



def upsert_user_bins(once=True, processing_time="15 seconds", startingVersion=0):
    """Stream ``user_profile`` changes into the demographic ``user_bins`` table.

    Each profile row is reduced to ``user_id``, an ``age`` bracket from
    ``age_bins(dob)``, ``gender``, ``city`` and ``state``. It is then merged
    on ``user_id``: update all columns when matched, insert otherwise.

    Args:
        once: When ``once == True`` an ``availableNow`` trigger is used;
            otherwise a processing-time trigger.
        processing_time: Trigger interval for the processing-time trigger.
        startingVersion: Delta version of ``user_profile`` to start from.

    Returns:
        The started ``StreamingQuery``.
    """
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.user_bins a
    USING user_bins_delta b
    ON a.user_id=b.user_id
    WHEN MATCHED 
    THEN UPDATE SET *
    WHEN NOT MATCHED THEN INSERT *
    """

    # Static snapshot of registered user ids from the silver `users` table.
    registered_ids = spark.table(f"{catalog}.{db_name}.users").select("user_id")

    # user_profile is maintained by a MERGE ... UPDATE, so its commits rewrite
    # existing files. ignoreChanges lets the stream keep reading past those
    # commits instead of failing.
    # NOTE(review): ignoreChanges may re-emit unchanged rows from rewritten
    # files; if two such commits land in one micro-batch, the MERGE could see
    # duplicate user_ids — verify.
    profile_changes = (
        spark.readStream
        .option("startingVersion", startingVersion)
        .option("ignoreChanges", True)
        .table(f"{catalog}.{db_name}.user_profile")
    )

    # Stateless stream-static join against the users snapshot.
    # NOTE(review): with how="left" every profile row is kept whether or not
    # the user is registered, so this join does not filter anything — confirm
    # whether an inner join was intended.
    binned = (
        profile_changes
        .join(registered_ids, on="user_id", how="left")
        .select("user_id", age_bins(col("dob")).alias("age"), "gender", "city", "state")
    )

    def merge_batch(batch_df, batch_id):
        upserter(batch_df, batch_id, merge_query, "user_bins_delta")

    writer = (
        binned.writeStream
        .foreachBatch(merge_batch)
        .outputMode("update")
        .option("checkpointLocation", checkpoint_base + "/user_bins")
        .queryName("user_bins_upsert_stream")
    )

    trigger_kwargs = {"availableNow": True} if once == True else {"processingTime": processing_time}
    return writer.trigger(**trigger_kwargs).start()



def upsert_completed_workouts(once=True, processing_time="15 seconds", startingVersion=0):
    """Pair workout ``start`` and ``stop`` events into ``completed_workouts``.

    Two streams are read from the silver ``workouts`` table: ``action = 'start'``
    and ``action = 'stop'``. Each has a 30-second watermark. They are joined
    on ``(user_id, workout_id, session_id)`` when the stop falls in
    ``[start_time, start_time + 3 hours)``. Matches are merged insert-only on
    the same key.

    Args:
        once: Truthy runs an ``availableNow`` trigger (drain and stop);
            otherwise a processing-time trigger is used.
        processing_time: Trigger interval for the processing-time trigger.
        startingVersion: Delta version of ``workouts`` both streams start from.

    Returns:
        The started ``StreamingQuery``.
    """
    # Insert-only: a (user, workout, session) already present is not re-inserted.
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.completed_workouts a
    USING completed_workouts_delta b
    ON a.user_id=b.user_id AND a.workout_id = b.workout_id AND a.session_id=b.session_id
    WHEN NOT MATCHED THEN INSERT *
    """

    df_start = (spark.readStream
                    .option("startingVersion", startingVersion)
                    .option("ignoreDeletes", True)
                    .table(f"{catalog}.{db_name}.workouts")
                    .filter("action = 'start'")
                    .selectExpr("user_id", "workout_id", "session_id", "time as start_time")
                    .withWatermark("start_time", "30 seconds")
               )

    df_stop = (spark.readStream
                    .option("startingVersion", startingVersion)
                    .option("ignoreDeletes", True)
                    .table(f"{catalog}.{db_name}.workouts")
                    .filter("action = 'stop'")
                    .selectExpr("user_id", "workout_id", "session_id", "time as end_time")
                    .withWatermark("end_time", "30 seconds")
              )

    # Stream-stream join state is only cleaned up when the condition bounds
    # each side's event time relative to the other side:
    #   - end_time < start_time + 3h: a buffered start row can be dropped once
    #     the stop-side watermark passes start_time + 3h.
    #   - end_time >= start_time: a buffered stop row can be dropped once the
    #     start-side watermark passes end_time. Without this lower bound stop
    #     rows were kept in state indefinitely. It also rejects a "stop" that
    #     precedes its start.
    join_condition = [
        df_start.user_id == df_stop.user_id,
        df_start.workout_id == df_stop.workout_id,
        df_start.session_id == df_stop.session_id,
        df_stop.end_time >= df_start.start_time,
        df_stop.end_time < df_start.start_time + F.expr('interval 3 hour'),
    ]

    df_delta = (df_start.join(df_stop, join_condition)
                    .select(df_start.user_id, df_start.workout_id, df_start.session_id,
                            df_start.start_time, df_stop.end_time)
               )

    # Stream-stream joins only support append output mode.
    stream_writer = (df_delta.writeStream
                        .foreachBatch(lambda df, batch_id: upserter(df, batch_id, merge_query, "completed_workouts_delta"))
                        .outputMode("append")
                        .option("checkpointLocation", checkpoint_base + "/completed_workouts_delta")
                        .queryName("completed_workouts_upsert_stream")
                    )

    if once:
        return stream_writer.trigger(availableNow=True).start()
    return stream_writer.trigger(processingTime=processing_time).start()

    

def upsert_workout_bpm(once=True, processing_time="15 seconds", startingVersion=0):
    """Attach valid heart-rate readings to the completed workout they occur in.

    Two streams are joined:
      - Completed workouts (``completed_workouts``), joined to a static
        snapshot of ``users`` to obtain each user's ``device_id``.
      - Valid heart-rate readings (``heart_rate`` where ``valid = True``).

    A reading joins a workout when the device matches and
    ``start_time < time <= end_time``. The extra
    ``end_time < time + 3 hours`` condition bounds the join in event time.
    Rows are merged insert-only on ``(user_id, workout_id, session_id, time)``.

    Args:
        once: Truthy runs an ``availableNow`` trigger (drain and stop);
            otherwise a processing-time trigger is used.
        processing_time: Trigger interval for the processing-time trigger.
        startingVersion: Delta version both source streams start from. It was
            previously ignored (hard-coded to 0); the default keeps that
            behaviour.

    Returns:
        The started ``StreamingQuery``.
    """
    # Insert-only: a reading already stored for this workout session at this
    # time is not inserted again.
    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.workout_bpm a
    USING workout_bpm_delta b
    ON a.user_id=b.user_id AND a.workout_id = b.workout_id AND a.session_id=b.session_id AND a.time=b.time
    WHEN NOT MATCHED THEN INSERT *
    """

    # Static snapshot of users, used to map user_id -> device_id.
    df_users = spark.read.table(f"{catalog}.{db_name}.users")

    df_completed_workouts = (
        spark.readStream
            .option("startingVersion", startingVersion)
            .option("ignoreDeletes", True)
            .table(f"{catalog}.{db_name}.completed_workouts")
            .join(df_users, "user_id")
            .selectExpr("user_id", "device_id", "workout_id", "session_id", "start_time", "end_time")
            .withWatermark("end_time", "30 seconds")
    )

    df_bpm = (
        spark.readStream
            .option("startingVersion", startingVersion)
            .option("ignoreDeletes", True)
            .table(f"{catalog}.{db_name}.heart_rate")
            .filter("valid = True")
            .selectExpr("device_id", "time", "heartrate")
            .withWatermark("time", "30 seconds")
    )

    join_condition = [
        df_completed_workouts.device_id == df_bpm.device_id,
        df_bpm.time > df_completed_workouts.start_time,
        df_bpm.time <= df_completed_workouts.end_time,
        df_completed_workouts.end_time < df_bpm.time + F.expr("interval 3 hours"),
    ]

    df_delta = (
        df_bpm.join(df_completed_workouts, join_condition)
            .select("user_id", "workout_id", "session_id", "start_time", "end_time", "time", "heartrate")
    )

    # Stream-stream joins only support append output mode.
    stream_writer = (
        df_delta.writeStream
            .foreachBatch(lambda df, batch_id: upserter(df, batch_id, merge_query, "workout_bpm_delta"))
            .outputMode("append")
            .option("checkpointLocation", checkpoint_base + "/workout_bpm")
            .queryName("workout_bpm_upsert_stream")
    )

    if once:
        return stream_writer.trigger(availableNow=True).start()
    return stream_writer.trigger(processingTime=processing_time).start()



def await_queries(once):
    """Block until every currently active streaming query has terminated.

    The wait happens only when ``once`` is truthy. In this notebook, a falsy
    ``once`` means queries were started with a processing-time trigger and
    keep running, so waiting on them would block indefinitely.
    """
    if not once:
        return
    for query in spark.streams.active:
        query.awaitTermination()


def upsert_silver(once=True, processing_time="5 seconds"):
    """Start every silver-layer upsert in three dependency-ordered layers.

    Each layer's streams are started, then (when ``once`` is truthy) awaited
    before the next layer starts. Elapsed seconds since the start are printed
    after each layer.

    Args:
        once: Passed to every upsert and to ``await_queries``.
        processing_time: Trigger interval passed to every upsert.
    """
    import time

    began = int(time.time())

    def seconds_so_far():
        return int(time.time()) - began

    print(f"\n Executing silver layer upsert ...")

    # Layer 1: streams fed directly from bronze tables.
    for start_stream in (upsert_users, upsert_gym_logs, upsert_user_profile,
                         upsert_workouts, upsert_heart_rate):
        start_stream(once, processing_time)
    await_queries(once)
    print(f"Completed silver layer 1 upsert in {seconds_so_far()} seconds")

    # Layer 2: reads silver tables written in layer 1 (users, user_profile, workouts).
    for start_stream in (upsert_user_bins, upsert_completed_workouts):
        start_stream(once, processing_time)
    await_queries(once)
    print(f" Completed silver layer 2 upsert in {seconds_so_far()} seconds")

    # Layer 3: joins completed_workouts (layer 2) with heart_rate and users (layer 1).
    upsert_workout_bpm(once, processing_time)
    await_queries(once)
    print(f" Completed silver layer 3 upsert in {seconds_so_far()} seconds")




def assert_count(catalog, db_name, table_name, expected_count, filter="true"):
    """Check that a table holds exactly ``expected_count`` rows matching ``filter``.

    Args:
        catalog: Catalog containing the table.
        db_name: Schema/database containing the table.
        table_name: Table to count.
        expected_count: Required number of matching rows.
        filter: SQL predicate passed to ``DataFrame.where``; ``"true"``
            counts every row. (The name shadows the builtin but is kept
            for keyword callers.)

    Raises:
        AssertionError: If the count differs. It is raised explicitly rather
            than via an ``assert`` statement so the check still runs when
            Python is started with ``-O``.
    """
    print(f"Validating record counts in {table_name}...", end='')
    actual_count = spark.read.table(f"{catalog}.{db_name}.{table_name}").where(filter).count()
    if actual_count != expected_count:
        raise AssertionError(
            f"Expected {expected_count:,} records, found {actual_count:,} in {table_name} where {filter}"
        )
    print(f"Found {actual_count:,} / Expected {expected_count:,} records where {filter}: Success")


def validate_silver(catalog, db_name, sets):
    """Check the row counts of every silver table after ``sets`` data loads.

    Expected counts:
      - Most tables expect the single-set count when ``sets == 1`` and the
        two-set count for any other value.
      - ``heart_rate`` scales linearly as ``sets * 253801``.

    Elapsed seconds are printed after each layer. ``assert_count`` raises on
    the first mismatch.
    """
    import time

    began = int(time.time())
    print(f"\n Starting Silver layer validation...")

    # Silver layer 1: (table, count for one set, count otherwise)
    layer_1 = [
        ("users", 5, 10),
        ("gym_logs", 8, 16),
        ("user_profile", 5, 10),
        ("workouts", 16, 32),
    ]
    for table, one_set, two_sets in layer_1:
        assert_count(catalog, db_name, table, one_set if sets == 1 else two_sets)
    assert_count(catalog, db_name, "heart_rate", sets * 253801)
    print(f"Silver layer 1 validation done in {int(time.time()) - began} seconds\n")

    # Silver layer 2
    layer_2 = [
        ("user_bins", 5, 10),
        ("completed_workouts", 8, 16),
    ]
    for table, one_set, two_sets in layer_2:
        assert_count(catalog, db_name, table, one_set if sets == 1 else two_sets)
    print(f"Silver layer 2 validation done in {int(time.time()) - began} seconds\n")

    # Silver layer 3
    assert_count(catalog, db_name, "workout_bpm", 3968 if sets == 1 else 8192)
    print(f"Silver layer 3 validation done in {int(time.time()) - began} seconds ")


    
