In [0]:
%run ./01-config

In [0]:
from pyspark.sql import functions as F
# Directory holding raw input data, located under the data root.
# NOTE(review): base_dir_data is not defined in this file — presumably it is
# provided by the `%run ./01-config` cell; confirm.
landing_zone = base_dir_data + "/raw"
# Parent directory for streaming checkpoints; streaming queries in this file
# store their checkpoint in a sub-directory of it.
# NOTE(review): base_dir_checkpoint is likewise expected to come from the config cell.
checkpoint_base = base_dir_checkpoint + "/checkpoints"


def upserter(df_micro_batch, batch_id, merge_query, temp_view):
    """Merge one streaming micro-batch into its target table.

    Intended as the body of a ``foreachBatch`` callback: the micro-batch is
    exposed to SQL under ``temp_view`` and ``merge_query`` (which is expected
    to reference that view) is executed against it.

    Args:
        df_micro_batch: DataFrame holding the rows of the current micro-batch.
        batch_id: Identifier of the micro-batch; only used in the log line.
        merge_query: SQL statement to execute for this batch.
        temp_view: Name under which the micro-batch is registered as a
            temporary view.
    """
    df_micro_batch.createOrReplaceTempView(temp_view)
    # Run the statement on the session that owns the micro-batch so the
    # temporary view registered above is visible to it. The public
    # `sparkSession` property replaces the private `_jdf` JVM handle, which is
    # not available when running through Spark Connect.
    df_micro_batch.sparkSession.sql(merge_query)
    print(f"Batch {batch_id} for {temp_view} processed.")



def upsert_workout_bpm_summary(once=True, processing_time="15 seconds", startingVersion=0):
    """Start the streaming upsert that feeds the ``workout_bpm_summary`` table.

    The stream reads the ``workout_bpm`` table, aggregates heart-rate readings
    per (user_id, workout_id, session_id, end_time) under a 30-second
    watermark on ``end_time``, joins the aggregates with the static
    ``user_bins`` table on ``user_id`` and merges every micro-batch into
    ``workout_bpm_summary``. The merge only inserts: rows whose
    (user_id, workout_id, session_id) key already exists in the target are
    left untouched.

    Args:
        once: When truthy the query uses the ``availableNow`` trigger;
            otherwise it runs on a ``processingTime`` trigger.
        processing_time: Trigger interval used when ``once`` is falsy.
        startingVersion: Value passed as the ``startingVersion`` option of the
            streaming read.

    Returns:
        The query handle returned by ``DataStreamWriter.start()``.
    """
    # Single source of truth for the view name shared by the MERGE statement
    # and the foreachBatch callback.
    temp_view = "workout_bpm_summary_delta"

    merge_query = f"""
    MERGE INTO {catalog}.{db_name}.workout_bpm_summary a
    USING {temp_view} b
    ON a.user_id=b.user_id AND a.workout_id = b.workout_id AND a.session_id=b.session_id
    WHEN NOT MATCHED THEN INSERT *
    """

    # Static (batch) lookup table joined against every micro-batch.
    df_users = spark.read.table(f"{catalog}.{db_name}.user_bins")

    df_delta = (
        spark.readStream
        .option("startingVersion", startingVersion)
        .table(f"{catalog}.{db_name}.workout_bpm")
        .withWatermark("end_time", "30 seconds")
        .groupBy("user_id", "workout_id", "session_id", "end_time")
        .agg(
            F.min("heartrate").alias("min_bpm"),
            F.mean("heartrate").alias("avg_bpm"),
            F.max("heartrate").alias("max_bpm"),
            F.count("heartrate").alias("num_recordings"),
        )
        .join(df_users, ["user_id"])
        .select(
            "workout_id", "session_id", "user_id", "age", "gender", "city", "state",
            "min_bpm", "avg_bpm", "max_bpm", "num_recordings",
        )
    )

    stream_writer = (
        df_delta.writeStream
        .foreachBatch(
            lambda batch_df, batch_id: upserter(batch_df, batch_id, merge_query, temp_view)
        )
        .outputMode("append")
        .option("checkpointLocation", f"{checkpoint_base}/workout_bpm_summary")
        .queryName("workout_bpm_summary_upsert_stream")
    )

    # Plain truthiness keeps this in step with upsert_gold, which decides
    # whether to wait for termination with `if once:`.
    if once:
        return stream_writer.trigger(availableNow=True).start()
    return stream_writer.trigger(processingTime=processing_time).start()
    


    
def upsert_gold(once=True, processing_time="5 seconds"):
    """Run the gold-layer upsert and report how long it took.

    Starts the ``workout_bpm_summary`` upsert stream. When ``once`` is truthy
    the call blocks until that query terminates; otherwise the query is left
    running and the function returns right after starting it.

    Args:
        once: Forwarded to ``upsert_workout_bpm_summary``; also decides
            whether this function waits for the query to finish.
        processing_time: Trigger interval forwarded to
            ``upsert_workout_bpm_summary``.
    """
    import time

    start = int(time.time())
    print("\nExecuting gold layer upsert ...")
    query = upsert_workout_bpm_summary(once, processing_time)
    if once:
        # Wait only for the query started here. Waiting on every active stream
        # in the session would block forever if an unrelated continuous
        # stream happened to be running.
        query.awaitTermination()
    print(f"Completed gold layer upsert {int(time.time()) - start} seconds")



def assert_count(catalog, db_name, table_name, expected_count, filter="true"):
    """Assert that a table holds the expected number of rows.

    Args:
        catalog: Catalog that contains the table.
        db_name: Database (schema) that contains the table.
        table_name: Name of the table to count.
        expected_count: Number of rows that must match ``filter``.
        filter: Condition passed to ``DataFrame.where`` before counting;
            the default ``"true"`` counts every row.

    Raises:
        AssertionError: If the actual count differs from ``expected_count``.
    """
    print(f"Validating record counts in {table_name}...", end='')

    qualified_name = f"{catalog}.{db_name}.{table_name}"
    matching_rows = spark.read.table(qualified_name).where(filter)
    actual_count = matching_rows.count()

    assert actual_count == expected_count, (
        f"Expected {expected_count:,} records, found {actual_count:,} in {table_name} where {filter}"
    )

    print(f"Found {actual_count:,} / Expected {expected_count:,} records where {filter}: Success")




def assert_rows(catalog, db_name, test_data_dir, location, table_name, sets):
    """Assert that a table holds exactly the rows of a reference parquet file.

    The reference data is read from
    ``{test_data_dir}/{location}_{sets}.parquet`` and compared with the full
    contents of ``{catalog}.{db_name}.{table_name}``.

    Args:
        catalog: Catalog that contains the table.
        db_name: Database (schema) that contains the table.
        test_data_dir: Directory holding the reference parquet files.
        location: File-name prefix of the reference data set.
        table_name: Name of the table to validate.
        sets: Suffix selecting which reference file to load.

    Raises:
        AssertionError: If the table contents differ from the reference data.
    """
    from collections import Counter

    print(f"Validating records in {table_name}...", end='')
    expected_rows = spark.read.format("parquet").load(f"{test_data_dir}/{location}_{sets}.parquet").collect()
    actual_rows = spark.read.table(f"{catalog}.{db_name}.{table_name}").collect()

    # Neither read imposes an ordering, so compare the rows as multisets:
    # the same rows arriving in a different order must not fail the check,
    # while missing, extra or duplicated rows still do.
    try:
        rows_match = Counter(expected_rows) == Counter(actual_rows)
    except TypeError:
        # Rows holding unhashable values (e.g. lists) cannot be counted;
        # fall back to the positional comparison.
        rows_match = expected_rows == actual_rows

    assert rows_match, (
        f"\n Data mismatch in {table_name}\n"
        f"- Expected: {len(expected_rows)} rows\n"
        f"- Actual:   {len(actual_rows)} rows"
    )
    print(f"Expected data matches with the actual data in {table_name}: Success")


def validate_gold(catalog, db_name, test_data_dir, sets):
    """Validate the gold-layer tables and print the elapsed time.

    Always compares ``gym_summary`` against the ``7-gym_summary`` reference
    data for the given ``sets``. When ``sets`` is greater than 1 it also
    checks that ``workout_bpm_summary`` contains exactly 2 rows.

    Args:
        catalog: Catalog that contains the tables.
        db_name: Database (schema) that contains the tables.
        test_data_dir: Directory holding the reference parquet files.
        sets: Data-set number that has been loaded so far.
    """
    import time

    started_at = int(time.time())
    print("\nValidating gold layer records...")

    assert_rows(catalog, db_name, test_data_dir, "7-gym_summary", "gym_summary", sets)

    # The summary row count is only checked beyond the first data set.
    if sets > 1:
        assert_count(catalog, db_name, "workout_bpm_summary", 2)

    elapsed = int(time.time()) - started_at
    print(f"Gold layer validation completed in {elapsed} seconds")



   

