In [0]:
%run ./01-config

In [0]:
class Upserter():
    """Run a fixed SQL statement against every streaming micro-batch.

    `upsertdata` has the `(batch_df, batch_id)` signature expected by
    `DataStreamWriter.foreachBatch`: it exposes the micro-batch as a temp
    view and then executes `merge_query`, which is expected to reference
    that view by name.
    """

    def __init__(self,temp_view_name,merge_query):
        """Store the view name and the SQL text to run per micro-batch.

        Args:
            temp_view_name: Name under which each micro-batch is registered
                as a temporary view.
            merge_query: SQL text executed once per micro-batch.
        """
        self.temp_view_name = temp_view_name
        self.merge_query = merge_query

    def upsertdata(self,batch_data,batch_id):
        """Register the micro-batch as a temp view and execute the merge query.

        Args:
            batch_data: Micro-batch DataFrame handed over by `foreachBatch`.
            batch_id: Micro-batch identifier supplied by `foreachBatch`; unused.
        """
        batch_data.createOrReplaceTempView(self.temp_view_name)
        # The query is run on the session that owns the micro-batch (where the
        # temp view was just registered) rather than on the global `spark`.
        # Prefer the public `DataFrame.sparkSession` property (PySpark 3.3+);
        # the private `_jdf` handle is kept only as a fallback for older
        # runtimes that do not expose it.
        session = getattr(batch_data, "sparkSession", None)
        if session is None:
            session = batch_data._jdf.sparkSession()
        session.sql(self.merge_query)

In [0]:
class Gold():
    """Gold-layer job: builds `workout_bpm_summary` from silver tables and validates gold tables.

    Depends on two names that are not defined in this cell: `Config`
    (presumably provided by the `01-config` notebook run at the top of this
    file -- confirm) and the notebook-global `spark` session.
    """

    def __init__(self,env):
        """Load settings from `Config` and switch the session to the gold schema.

        Args:
            env: Catalog name used to qualify the silver and gold tables.
        """
        conf = Config()
        self.checkpoint_base = conf.base_dir_checkpoint + "/checkpoints"
        self.base_data_dir_test = conf.base_dir_data + "/test_data"
        self.catalog = env
        self.silver_db = conf.silver_db_name
        self.gold_db = conf.gold_db_name
        spark.sql(f"USE {self.catalog}.{self.gold_db}")

    def _gold_table(self, table_name):
        """Return `table_name` qualified as `<catalog>.<gold_db>.<table>`; dotted names pass through unchanged."""
        if "." in table_name:
            return table_name
        return f"{self.catalog}.{self.gold_db}.{table_name}"

    def upsert_workout_bpm_summary(self,once=False,process_time="10 seconds",startingVersion=0):
        """Start the streaming query that merges heart-rate aggregates into `workout_bpm_summary`.

        The stream reads silver `workout_bpm`, computes min/avg/max/count of
        `heartrate` per (workout_id, session_id, user_id, end_time), joins the
        result with a static read of silver `user_bins` on `user_id`, and hands
        each micro-batch to an `Upserter` that runs a MERGE inserting only rows
        whose (user_id, session_id, workout_id) key is not yet in the target.

        Args:
            once: If True, use the `availableNow` trigger; otherwise use a
                processing-time trigger with `process_time`.
            process_time: Processing-time trigger interval used when `once` is False.
            startingVersion: Value passed as the `startingVersion` read option.

        Returns:
            The started StreamingQuery, so callers can await or stop it.
        """
        import pyspark.sql.functions as f

        merge_query=f"""
        MERGE INTO {self.catalog}.{self.gold_db}.workout_bpm_summary a USING workout_bpm_summary_delta b ON
        a.user_id = b.user_id AND a.session_id = b.session_id AND a.workout_id = b.workout_id
        WHEN NOT MATCHED THEN INSERT * 
        """

        # The temp view name must match the `USING` source named in merge_query.
        data_upserter = Upserter("workout_bpm_summary_delta",merge_query)

        user_bins_tbl=spark.read.table(f"{self.catalog}.{self.silver_db}.user_bins")

        read_stream = (spark.readStream
                       .option("startingVersion",startingVersion)
                       .table(f"{self.catalog}.{self.silver_db}.workout_bpm")
                       .withWatermark("end_time","30 seconds")
                       .groupBy("workout_id","session_id","user_id","end_time")
                       .agg(f.min("heartrate").alias("min_bpm"),f.avg("heartrate").alias("avg_bpm"),f.max("heartrate").alias("max_bpm")
                            ,f.count("heartrate").alias("num_recordings"))
                       .join(user_bins_tbl,["user_id"])
                       .select("workout_id","session_id","user_id","age","gender","city","state","min_bpm","avg_bpm","max_bpm","num_recordings")
                       )
        stream_writer = (read_stream.writeStream
                                 .foreachBatch(data_upserter.upsertdata)
                                 .outputMode("append")
                                 .option("checkpointLocation", f"{self.checkpoint_base}/workout_bpm_summary")
                                 .queryName("workout_bpm_summary_upsert_stream")
                        )

        trigger_args = {"availableNow": True} if once else {"processingTime": process_time}
        # Return the query handle instead of discarding it so the caller can
        # wait on exactly this stream.
        return stream_writer.trigger(**trigger_args).start()

    def upsert(self,once=True,process_time="10 seconds"):
        """Run the gold-layer upsert and report the elapsed wall-clock seconds.

        Args:
            once: If True, process what is available and block until the
                query terminates; if False, start the query and return
                without waiting.
            process_time: Processing-time trigger interval used when `once` is False.
        """
        import time
        start = int(time.time())
        print("\nExecuting gold layer upsert ...")
        query = self.upsert_workout_bpm_summary(once=once,process_time=process_time)
        if once:
            # Wait only for the query started here. Waiting on every entry of
            # `spark.streams.active` would block indefinitely whenever an
            # unrelated continuous stream is running in the same session.
            query.awaitTermination()
        print(f"Completed gold layer upsert {int(time.time()) - start} seconds")

    def assert_count(self,table_name,expected_count,filter="true"):
        """Assert that a gold table holds `expected_count` rows matching `filter`.

        Args:
            table_name: Gold table name (qualified with catalog and gold schema if undotted).
            expected_count: Required number of rows.
            filter: SQL predicate passed to `where`; defaults to all rows.

        Raises:
            AssertionError: If the actual count differs from `expected_count`.
        """
        actual_count = spark.read.table(self._gold_table(table_name)).where(filter).count()
        assert expected_count == actual_count, f"Expected {expected_count:,} records, found {actual_count:,} in {table_name} where {filter}" 
        print(f"Found {actual_count:,} / Expected {expected_count:,} records where {filter}: Success")

    def assert_rows(self, location, table_name, sets):
        """Assert that a gold table's rows equal the expected parquet test data.

        Expected rows are loaded from
        `<base_data_dir_test>/<location>_<sets>.parquet`.

        Args:
            location: File-name prefix of the expected-data parquet file.
            table_name: Gold table name (qualified with catalog and gold schema if undotted).
            sets: Suffix selecting which expected-data file to load.

        Raises:
            AssertionError: If the collected rows differ.
        """
        print(f"Validating records in {table_name}...", end='')
        expected_rows = spark.read.format("parquet").load(f"{self.base_data_dir_test}/{location}_{sets}.parquet").collect()
        # Qualify the table explicitly (as assert_count does) so the check does
        # not depend on the session's current schema, which any later `USE`
        # statement in the same session would change.
        actual_rows = spark.table(self._gold_table(table_name)).collect()
        # NOTE(review): this list comparison is order-sensitive and neither
        # read specifies an ordering -- confirm both sides return rows in a
        # stable order.
        assert expected_rows == actual_rows, f"Expected data mismatches with the actual data in {table_name}"
        print(f"Expected data matches with the actual data in {table_name}: Success")

    def validate(self, sets):
        """Validate gold-layer tables against the expected data for `sets`.

        Always compares `gym_summary` row-by-row; when `sets` is greater
        than 1 it additionally requires exactly 2 rows in `workout_bpm_summary`.

        Args:
            sets: Number identifying the test-data iteration to validate.
        """
        import time
        start = int(time.time())
        print("\nValidating gold layer records..." )
        self.assert_rows("7-gym_summary", "gym_summary", sets)
        if sets>1:
            self.assert_count("workout_bpm_summary", 2)
        print(f"Gold layer validation completed in {int(time.time()) - start} seconds")

