In [0]:
%run ./01-config

In [0]:
class Bronze():
    """Ingest raw landing-zone files into the bronze Delta tables using Auto Loader.

    Relies on a notebook-global ``spark`` session and on ``Config``, which is
    presumably provided by the ``%run ./01-config`` cell above.
    """

    def __init__(self, env):
        """Resolve storage paths and schema names, then make bronze the session default schema.

        Args:
            env: catalog name; used as the first part of every three-level table name.
        """
        conf = Config()
        self.landing_zone = conf.base_dir_data + "/raw"
        self.checkpoint_base = conf.base_dir_checkpoint + "/checkpoints"
        self.catalog = env
        self.bronze_db = conf.bronze_db_name
        self.silver_db = conf.silver_db_name
        spark.sql(f"USE {self.catalog}.{self.bronze_db}")

    def _bronze_table(self, name):
        """Return the fully qualified ``catalog.bronze_db.name`` table identifier."""
        return f"{self.catalog}.{self.bronze_db}.{name}"

    def _read_landing(self, source_dir, file_format, schema_ddl, header=False):
        """Build an Auto Loader stream over ``<landing_zone>/<source_dir>/``.

        Adds two audit columns: ``load_time`` (current timestamp) and
        ``source_file`` (the file each row was read from).

        Args:
            source_dir: sub-directory of the landing zone to watch.
            file_format: value for ``cloudFiles.format`` (e.g. ``"csv"``, ``"json"``).
            schema_ddl: DDL string describing the file columns.
            header: when True, sets the CSV ``header`` option.
        """
        import pyspark.sql.functions as f
        reader = (spark.readStream.format("cloudFiles")
                  .option("cloudFiles.format", file_format)
                  # NOTE(review): Auto Loader documents this limit as
                  # `cloudFiles.maxFilesPerTrigger`; confirm the unprefixed key is honoured.
                  .option("maxfilesPerTrigger", 1)
                  .schema(schema_ddl))
        if header:
            reader = reader.option("header", "true")
        return (reader.load(f"{self.landing_zone}/{source_dir}/")
                .withColumn("load_time", f.current_timestamp())
                # NOTE(review): input_file_name() may be unsupported on some Unity
                # Catalog compute; `_metadata.file_path` is the suggested replacement.
                .withColumn("source_file", f.input_file_name()))

    def _start_stream(self, df, query_name, table, pool, once, process_time):
        """Start an append-mode Delta write of *df* into bronze table *table*.

        Args:
            df: streaming DataFrame to write.
            query_name: streaming query name; also the checkpoint sub-directory.
            table: unqualified bronze table name.
            pool: fair-scheduler pool for the query's jobs.
            once: True -> ``availableNow`` trigger (drain available files, then stop);
                False -> processing-time trigger every *process_time*.
            process_time: trigger interval used when *once* is False.

        Returns:
            The started ``StreamingQuery``.
        """
        writer = (df.writeStream.format("delta")
                  .outputMode("append")
                  .option("checkpointLocation", f"{self.checkpoint_base}/{query_name}")
                  .queryName(query_name))
        if once:
            writer = writer.trigger(availableNow=True)
        else:
            writer = writer.trigger(processingTime=process_time)
        # The pool is a thread-local property picked up when the query starts,
        # so it has to be set *before* toTable(), not after.
        spark.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
        return writer.toTable(self._bronze_table(table))

    def load_registered_users_bz(self, once=False, process_time="5 seconds"):
        """Stream ``registered_users`` CSV files into ``registered_users_bz`` (pool ``bronze_p2``).

        Args:
            once: if True, process all available files and stop; otherwise run continuously.
            process_time: trigger interval when *once* is False.

        Returns:
            The started ``StreamingQuery``.
        """
        schema_ddl = "user_id LONG,	device_id LONG,	mac_address	STRING, registration_timestamp DOUBLE"
        df = self._read_landing("registered_users", "csv", schema_ddl, header=True)
        return self._start_stream(df, "load_registered_users_bz", "registered_users_bz",
                                  "bronze_p2", once, process_time)

    def load_gym_logins_bz(self, once=False, process_time="5 seconds"):
        """Stream ``gym_logins`` CSV files into ``gym_logins_bz`` (pool ``bronze_p2``).

        Args:
            once: if True, process all available files and stop; otherwise run continuously.
            process_time: trigger interval when *once* is False.

        Returns:
            The started ``StreamingQuery``.
        """
        schema_ddl = "mac_address STRING, gym LONG, login DOUBLE, logout DOUBLE"
        df = self._read_landing("gym_logins", "csv", schema_ddl, header=True)
        return self._start_stream(df, "load_gym_logins_bz", "gym_logins_bz",
                                  "bronze_p2", once, process_time)

    def load_kafka_multiplex_bz(self, once=False, process_time="5 seconds"):
        """Stream ``kafka_multiplex`` JSON files into ``kafka_multiplex_bz`` (pool ``bronze_p1``).

        Each record is left-joined to the silver ``date_lookup`` table
        (``date``, ``week_part``) on the calendar date of its ``timestamp``.

        Args:
            once: if True, process all available files and stop; otherwise run continuously.
            process_time: trigger interval when *once* is False.

        Returns:
            The started ``StreamingQuery``.
        """
        import pyspark.sql.functions as f
        schema_ddl = "key STRING,value STRING,topic STRING,partition BIGINT,offset BIGINT,timestamp BIGINT"
        date_lookup = (spark.read.table(f"{self.catalog}.{self.silver_db}.date_lookup")
                       .select("date", "week_part"))
        # `timestamp` is divided by 1000 before the cast, i.e. treated as epoch milliseconds.
        event_date = f.to_date((f.col("timestamp") / 1000).cast("TIMESTAMP"))
        df = (self._read_landing("kafka_multiplex", "json", schema_ddl)
              .join(f.broadcast(date_lookup), event_date == date_lookup.date, "left"))
        return self._start_stream(df, "load_kafka_multiplex_bz", "kafka_multiplex_bz",
                                  "bronze_p1", once, process_time)

    def load_bronze(self, once=True, process_time="5 seconds"):
        """Start all three bronze ingestion streams.

        Args:
            once: if True, run each stream in ``availableNow`` mode and block until
                the streams started here have finished.
            process_time: trigger interval when *once* is False.
        """
        import time
        start = time.time()
        print("loading data into bronze layer")
        queries = [
            self.load_gym_logins_bz(once, process_time),
            self.load_registered_users_bz(once, process_time),
            self.load_kafka_multiplex_bz(once, process_time),
        ]
        if once:
            # Wait only on our own queries: waiting on every active stream would
            # hang if an unrelated continuous query is running in the session.
            for query in queries:
                query.awaitTermination()
        end = time.time()
        print(f"time taken for loading data into bronze {int(end - start)} seconds")

    def assert_count(self, table_name, expected_count, filter="true"):
        """Assert that bronze table *table_name* has *expected_count* rows matching *filter*.

        Args:
            table_name: unqualified bronze table name.
            expected_count: required row count.
            filter: SQL predicate passed to ``where``; defaults to all rows.

        Raises:
            AssertionError: when the counts differ.
        """
        actual_count = spark.read.table(self._bronze_table(table_name)).where(filter).count()
        print(f"actual_count : {actual_count}")
        assert actual_count == expected_count, f"Expected {expected_count} rows in {table_name}, found {actual_count}"
        print(f"Expected {expected_count} rows in {table_name}, found {actual_count}")

    def validate(self, sets):
        """Check bronze row counts against the expected totals after *sets* data sets are loaded.

        The hard-coded expectations cover ``sets == 1``; any other value uses
        the second column of totals (and ``sets * 253801`` for ``bpm``).
        """
        import time
        start = time.time()
        print("validating bronze layer records")
        self.assert_count("registered_users_bz", 5 if sets == 1 else 10)
        self.assert_count("gym_logins_bz", 8 if sets == 1 else 16)
        self.assert_count("kafka_multiplex_bz", 7 if sets == 1 else 13, "topic='user_info'")
        self.assert_count("kafka_multiplex_bz", 16 if sets == 1 else 32, "topic='workout'")
        self.assert_count("kafka_multiplex_bz", sets * 253801, "topic='bpm'")
        end = time.time()
        print(f"time taken for validating data into bronze {int(end) - int(start)} seconds")

