### Ingestor Class

In [0]:
import math
from delta.tables import *
from pyspark.sql.types import *
import pyspark.sql.functions as f
from pyspark.sql.window import Window
from datetime import datetime, date, timedelta

class Ingestor(object):
    def __init__(self, table_set="none"):
        '''
        Sets the storage locations, loads the configuration of `table_set` and retrieves the secrets from the secret scope.
        '''
        # Storage coordinates
        self.storage_account = "northwinddl"
        self.external_location = "prod"
        self.catalog = "northwind"

        # Every layer lives in its own folder below the container root
        container_root = f"abfss://{self.external_location}@{self.storage_account}.dfs.core.windows.net"
        self.container_path = container_root
        self.raw_path = f"{container_root}/raw"
        self.bronze_path = f"{container_root}/bronze"
        self.silver_path = f"{container_root}/silver"
        self.gold_path = f"{container_root}/gold"

        # read_config() reads the catalog, the silver path and the table set, so they must be assigned first
        self.table_set = table_set
        self.config = self.read_config()

        # DDL fragment (type + comment) of each processing column, keyed by column name
        self.processing_columns = dict([
            ("_file_path", "STRING COMMENT 'The path from which the data was ingested'"),
            ("_file_name", "STRING COMMENT 'The name of the file from which the data was ingested'"),
            ("_file_size", "BIGINT COMMENT 'The size of the file from which the data was ingested'"),
            ("_file_modification_time", "TIMESTAMP COMMENT 'The last modified timestamp of the file from which the data was ingested'"),
            ("_loaded_at_utc", "TIMESTAMP COMMENT 'The timestamp at which the data was ingested'"),
            ("_last_updated_at_utc", "TIMESTAMP COMMENT 'The timestamp at which the data was last updated'"),
        ])

        # Secrets are read from a single secret scope
        secret_scope = 'northwind-scope'
        self.client_id = dbutils.secrets.get(scope=secret_scope, key='northwinddl-client-id')
        self.tenant_id = dbutils.secrets.get(scope=secret_scope, key='northwinddl-tenant-id')
        self.client_secret = dbutils.secrets.get(scope=secret_scope, key='northwinddl-client-secret')

        # Job timer, see start_timer() / stop_timer()
        self.timer_start = None
        self.timer_end = None

    def ts_print(self, message):
        '''
        Prints the input message with the current timestamp prepended.
        '''
        stamp = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
        print(f"[{stamp}] >>> {message}")

    def start_timer(self):
        '''
        Starts the timer to measure the duration of the job.
        '''
        self.timer_start = datetime.now()

    def stop_timer(self):
        '''
        Stops the timer to measure the duration of the job.
        '''
        self.timer_end = datetime.now()

    def calculate_duration(self):
        '''
        Calculates the duration of the job.
        '''
        if self.timer_start is None or self.timer_end is None:
            return None
        else:
            duration = (self.timer_end - self.timer_start).total_seconds()
            hours = math.floor(duration / 60 / 60)
            mins = math.floor((duration - (hours * 60 * 60)) / 60)
            seconds = math.floor(duration % 60)
            return f"Total duration: {hours}h {mins}m {seconds}s"

    def overwrite_config(self):
        '''
        Reloads the configuration tables from the silver container to ensure they are up-to-date during table processing.

        For `config_tables` and `config_fields` (in that order) the CSV file in
        `<silver_path>/config` is read with an inferred schema, written over the
        Delta location of the same name, and registered in the silver schema of
        the catalog if it is not registered yet.
        '''
        config_dir = f"{self.silver_path}/config"

        # Both configuration tables follow the exact same CSV -> Delta -> catalog flow
        for config_name in ("config_tables", "config_fields"):
            delta_path = f"{config_dir}/{config_name}"

            df = (
                spark.read
                .format('csv')
                .option('header', True)
                .option('inferSchema', True)
                .load(f"{delta_path}.csv")
            )
            (
                df.write
                .format('delta')
                .mode('overwrite')
                .option('overwriteSchema', True)
                .save(delta_path)
            )

            # Register the external location as a table; a no-op when it already exists
            sql_query = f"""
                CREATE TABLE IF NOT EXISTS {self.catalog}.silver.{config_name}
                USING DELTA
                LOCATION '{delta_path}'
            """
            spark.sql(sql_query)

            self.ts_print(f"{config_name} successfully overwritten.")

    def read_config(self):
        '''
        Loads the table and column configurations into a dataframe.

        The configuration tables are (re)loaded from their CSV files first when
        either of them is not registered in the silver schema. The returned
        dataframe joins `config_tables` and `config_fields` on `table_set`, is
        restricted to this instance's table set and carries an extra
        `is_primary_key` column ('Y'/'N').
        '''
        required_tables = ("config_tables", "config_fields")
        required_list = ", ".join(f"'{name}'" for name in required_tables)

        # Match the exact table names: in a LIKE pattern `_` is a single-character wildcard,
        # so 'config_%' would also count unrelated tables and could hide a missing one
        registered = (
            spark
            .sql(f"SHOW TABLES IN {self.catalog}.silver")
            .filter(f"tableName IN ({required_list})")
        )

        # If both config tables aren't loaded to silver, then load them from the csvs in the silver directory
        if registered.count() < len(required_tables):
            self.overwrite_config()

        # A field is a primary key when its layer-specific name appears in the `;`-separated primary_keys list
        sql_query = f"""
            SELECT
                *,
                CASE
                    WHEN array_contains(split(primary_keys, ";"), CASE WHEN layer = 'bronze' THEN f.field_raw ELSE f.field_silver END) THEN 'Y'
                    ELSE 'N'
                END AS is_primary_key
            FROM
                {self.catalog}.silver.config_tables AS t
                INNER JOIN {self.catalog}.silver.config_fields AS f USING (table_set)
            WHERE
                table_set = '{self.table_set}'
        """

        return spark.sql(sql_query)

    def table_exists(self, layer, table):
        '''
        Checks if the specified table exists in the specified schema of the configured catalog.

        Returns True when `SHOW TABLES` lists a table with exactly that name, False otherwise.
        '''
        # Use the configured catalog rather than a hard-coded name
        matches = (
            spark
            .sql(f"SHOW TABLES IN {self.catalog}.{layer}")
            .filter(f"tableName == '{table}'")
        )
        return matches.count() > 0

    def data_type_lookup(self, data_type):
        '''
        Retrieves the corresponding spark data type when instantiating the dataframe.

        Returns None when `data_type` is not one of the supported names.
        '''
        data_types = {
            # NOTE(review): DECIMAL is read as a float here while create_table passes the
            # configured type name straight into the DDL - confirm this is intended
            'DECIMAL': FloatType(),
            'FLOAT': FloatType(),
            'INT': IntegerType(),
            # 64-bit type: IntegerType is 32-bit and cannot hold the full LONG range
            'LONG': LongType(),
            'STRING': StringType(),
            'TIMESTAMP': TimestampType()
        }
        return data_types.get(data_type)

    def is_nullable_lookup(self, primary_keys, field):
        '''
        Tells whether a column may contain nulls: every column is nullable unless it is
        listed in the `;`-separated primary key string.
        '''
        # Without primary keys nothing is constrained
        if primary_keys is None:
            return True

        key_fields = primary_keys.split(';')
        return field not in key_fields

    def get_schema(self, layer, load_type):
        '''
        Builds the StructType used to instantiate the dataframe of the table being processed.

        One field is added per configuration row of `layer`: named after `field_raw` for
        bronze and `field_silver` otherwise, typed from `type_silver`, non-nullable when it
        is a primary key, and carrying the field description as comment metadata.
        `load_type` is accepted but not used.
        '''
        name_column = 'field_raw' if layer == 'bronze' else 'field_silver'
        layer_rows = self.config.filter(f"layer = '{layer}'").collect()

        struct = StructType()
        for entry in layer_rows:
            column_name = entry[name_column]
            struct.add(
                column_name,
                self.data_type_lookup(entry['type_silver']),
                self.is_nullable_lookup(entry['primary_keys'], column_name),
                {"comment": entry['field_description']}
            )

        return struct

    def create_table(self, layer, table, table_desc, path, tbl_properties, config, load_type):
        '''
        Creates the table in the target schema from the specified external location.

        layer: schema in which the table is created; it also selects the field names
            (`field_raw` for bronze, `field_silver` otherwise) and the processing columns.
        table / table_desc: name and comment of the table.
        path: external location of the Delta table.
        tbl_properties: dict of table properties to set on creation.
        config: collected configuration rows, one per column.
        load_type: load type of the table; "STREAM" adds the update-tracking columns.
        '''
        def escape_literal(value):
            # Escape backslashes first, then single quotes, so free text cannot end the SQL string early
            return str(value).replace("\\", "\\\\").replace("'", "\\'")

        def quote_identifier(name):
            # Backticks keep column names with spaces or special characters valid (same convention as the merge condition)
            return "`" + str(name).replace("`", "``") + "`"

        self.ts_print(f"Creating table `{self.catalog}.{layer}.{table}`...")

        # Business columns from the configuration
        name_column = "field_raw" if layer == "bronze" else "field_silver"
        column_definitions = []
        for row in config:
            definition = f"{quote_identifier(row[name_column])} {row['type_silver']}"
            comment = row["field_description"]
            # A missing description would otherwise end up as the literal comment 'None'
            if comment is not None:
                definition += f" COMMENT '{escape_literal(comment)}'"
            column_definitions.append(definition)

        # Processing columns, depending on layer and load type
        is_stream = "STREAM" in load_type
        if layer == "bronze":
            metadata_columns = ["_file_path", "_file_name", "_file_size", "_file_modification_time", "_loaded_at_utc"]
            if is_stream:
                metadata_columns.append("_last_updated_at_utc")
        elif layer == "silver" and is_stream:
            metadata_columns = ["_loaded_at_utc", "_last_updated_at_utc"]
        else:
            metadata_columns = []

        for name in metadata_columns:
            column_definitions.append(f"{name} {self.processing_columns.get(name)}")

        column_sql = ",".join(column_definitions)

        # An empty TBLPROPERTIES () clause is not valid SQL, so leave the clause out when there is nothing to set
        properties_sql = ",".join(
            f"'{escape_literal(k)}'='{escape_literal(v)}'" for k, v in tbl_properties.items()
        )
        properties_clause = f"TBLPROPERTIES ({properties_sql})" if properties_sql else ""

        sql_query = f"""
            CREATE TABLE IF NOT EXISTS {self.catalog}.{layer}.{table} ({column_sql})
            USING DELTA
            COMMENT '{escape_literal(table_desc)}'
            {properties_clause}
            LOCATION '{path}'
        """
        spark.sql(sql_query)

        self.ts_print(f"Successfully created table `{self.catalog}.{layer}.{table}`.")

    def alter_table_properties(self, layer, table, tbl_properties):
        '''
        Alter the properties of the table if there is a mismatch between the existing properties and the config.

        Only properties whose current value differs from the configured one (or that are
        not set at all) are altered.
        '''
        table_name = f"{self.catalog}.{layer}.{table}"

        # SHOW TBLPROPERTIES returns one key/value row per property, which avoids parsing the
        # bracketed, comma-separated string of DESC TABLE EXTENDED (separator and embedded
        # commas make that parsing unreliable, and the row may be missing altogether)
        existing_properties = {
            row["key"]: row["value"]
            for row in spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
        }

        for k, v in tbl_properties.items():
            # Property values are stored as strings, so compare against the string form
            if existing_properties.get(k) != str(v):
                self.ts_print(f"Setting TBLPROPERTY: '{k}'='{v}'")
                sql_query = f"ALTER TABLE {table_name} SET TBLPROPERTIES ('{k}'='{v}')"
                spark.sql(sql_query)
    
    def update_column_names(self, df, config):
        '''
        Renames the columns of the dataframe from their bronze (raw) names to their silver names,
        one configuration row at a time.
        '''
        renamed = df
        for mapping in config:
            raw_name, silver_name = mapping['field_raw'], mapping['field_silver']
            renamed = renamed.withColumnRenamed(raw_name, silver_name)

        return renamed

    def add_metadata_columns(self, df, path, layer, load_type):
        '''
        Appends the processing columns that apply to the given layer and load type.

        Bronze tables get the file details plus `_loaded_at_utc`; streamed bronze tables
        additionally get `_last_updated_at_utc`. Streamed silver tables only get the two
        timestamps. Any other combination returns the dataframe unchanged.
        '''
        is_full = "FULL" in load_type
        is_stream = "STREAM" in load_type

        # Resolve where the file details come from
        if is_full:
            # The first entry of the directory listing is applied to every row
            first_file = dbutils.fs.ls(path)[0]
            file_path, file_name, file_size, file_mod_time = (f.lit(first_file[i]) for i in range(4))
        elif is_stream:
            # Per-row details exposed by the file source
            file_path = f.col("_metadata.file_path")
            file_name = f.col("_metadata.file_name")
            file_size = f.col("_metadata.file_size")
            file_mod_time = f.col('_metadata.file_modification_time')

        # Decide which columns to add, in the order they must appear
        new_columns = []
        if layer == "bronze" and (is_full or is_stream):
            new_columns = [
                ("_file_path", file_path),
                ("_file_name", file_name),
                ("_file_size", file_size),
                ("_file_modification_time", file_mod_time),
                ("_loaded_at_utc", f.current_timestamp()),
            ]
            # Only streamed tables track the last update
            if not is_full:
                new_columns.append(("_last_updated_at_utc", f.current_timestamp()))
        elif layer == "silver" and is_stream:
            new_columns = [
                ("_loaded_at_utc", f.current_timestamp()),
                ("_last_updated_at_utc", f.current_timestamp()),
            ]

        for column_name, column_value in new_columns:
            df = df.withColumn(column_name, column_value)

        return df

    def update_col_dtypes(self, df, config):
        '''
        Updates column values according to business logic.

        Flag columns - those whose silver name starts with `is_` or `has_` - are rewritten
        in place (under their raw name) to 'Y' when the value equals True and 'N' otherwise.
        All other columns are left untouched.
        '''
        flag_prefixes = ('is_', 'has_')

        for row in config:
            if row['field_silver'].startswith(flag_prefixes):
                raw_name = row['field_raw']
                df = df.withColumn(
                    raw_name,
                    f.when(f.col(raw_name) == True, 'Y').otherwise('N')
                )

        return df

    def get_primary_keys(self, columns, layer):
        '''
        Looks up the `primary_keys` value of the bronze configuration for the given columns.

        `columns` are matched against `field_raw` for the bronze layer and against
        `field_silver` otherwise; the configuration rows searched are always the bronze
        ones. The `primary_keys` value of the first matching row is returned.
        '''
        name_column = "field_raw" if layer == "bronze" else "field_silver"

        in_requested_columns = f.col(name_column).isin(columns)
        is_bronze_row = f.col("layer") == "bronze"
        matching_rows = self.config.filter(in_requested_columns & is_bronze_row).collect()

        return matching_rows[0]["primary_keys"]

    def write_upsert(self, source_data, target_path, primary_keys, config, layer, load_type):
        '''
        Upserts (MERGE) the records of the source dataframe into the Delta table at the target location.

        source_data: batch of records to merge; its columns carry the raw (bronze) names.
        target_path: location of the target Delta table.
        primary_keys: `;`-separated primary keys of the target layer.
        config: configuration dataframe of the target layer (one row per column).
        layer: target layer; selects the target column names (`field_raw` for bronze, `field_silver` otherwise).
        load_type: load type of the table; decides which processing columns are copied from the source.
        '''
        self.ts_print("Executing UPSERT...")

        # Collect the configuration once and reuse it for the keys and the field mappings
        config_rows = config.collect()
        target_col = "field_raw" if layer == "bronze" else "field_silver"
        batch_columns = source_data.columns

        # A change data feed batch contains the pre-update image of every updated row. It shares the
        # primary key and `_loaded_at_utc` with the post-update image, so keeping it could make the
        # de-duplication below pick the stale values.
        if "_change_type" in batch_columns:
            source_data = source_data.filter(f.col("_change_type") != "update_preimage")

        # De-duplicate records in the case that the same record is present multiple times in this batch.
        # The extra sort keys only break ties between rows with the same `_loaded_at_utc`:
        # the most recent commit / the most recently modified file wins.
        partition_by_cols = self.get_primary_keys(primary_keys.split(";"), layer)
        ordering = [f.col("_loaded_at_utc").desc()]
        for tie_breaker in ("_commit_version", "_file_modification_time"):
            if tie_breaker in batch_columns:
                ordering.append(f.col(tie_breaker).desc())
        dedup_window = Window.partitionBy(partition_by_cols.split(";")).orderBy(*ordering)

        # Underscore-prefixed helper name so a business column (e.g. one called `rank`) is never overwritten and dropped
        rank_col = "_dedup_rank"
        source_data = (
            source_data
            .withColumn(rank_col, f.row_number().over(dedup_window))
            .filter(f.col(rank_col) == 1)
            .drop(rank_col)
        )

        # Match target and source on every primary key.
        # `1 = 1` is kept as the neutral first condition; use ` backtick to safeguard column names that contain spaces.
        key_conditions = [
            f"target.`{row[target_col]}` = source.`{row['field_raw']}`"
            for row in config_rows
            if row["is_primary_key"] == "Y"
        ]
        join_condition = " AND ".join(["1 = 1"] + key_conditions)

        # Specify field mapping for updates and inserts (source columns quoted like in the join condition)
        update_fields = {row[target_col]: f"source.`{row['field_raw']}`" for row in config_rows}
        insert_fields = dict(update_fields)

        # Default timestamps; overridden below where the source already carries them
        insert_fields["_loaded_at_utc"] = "current_timestamp()"
        insert_fields["_last_updated_at_utc"] = "current_timestamp()"
        update_fields["_last_updated_at_utc"] = "current_timestamp()"

        if layer == "bronze":
            insert_fields["_file_path"] = "source._file_path"
            insert_fields["_file_name"] = "source._file_name"
            insert_fields["_file_size"] = "source._file_size"
            insert_fields["_file_modification_time"] = "source._file_modification_time"
            insert_fields["_loaded_at_utc"] = "source._loaded_at_utc"

        if layer == "bronze" and "STREAM" in load_type:
            insert_fields["_last_updated_at_utc"] = "source._last_updated_at_utc"

        if layer == "silver" and "STREAM" in load_type:
            insert_fields["_loaded_at_utc"] = "source._loaded_at_utc"
            insert_fields["_last_updated_at_utc"] = "source._last_updated_at_utc"

        # Execute merge
        target_table = DeltaTable.forPath(spark, target_path)
        (
            target_table.alias("target")
            .merge(source_data.alias("source"), join_condition)
            .whenMatchedUpdate(set=update_fields)
            .whenNotMatchedInsert(values=insert_fields)
            .execute()
        )

    def write_bronze(self):
        '''
        Writes the raw data to the bronze container and creates a table in the bronze schema.

        Everything is driven by the first bronze configuration row of the table set:
        - a load type starting with "FULL" reads all raw files, applies the business rules and
          the processing columns, and overwrites the bronze location;
        - a load type starting with "STREAM" reads the raw files through a `cloudFiles` stream
          and merges every micro-batch into the bronze location via write_upsert().
        With any other load type nothing is loaded, but the completion message is still printed.
        '''
        self.ts_print(f"⏩ Executing bronze layer ingestion...")

        # Get config
        config_df = self.config.filter("layer = 'bronze'")
        config = config_df.collect()

        # Set variables (table-level settings are read from the first configuration row)
        table = config[0]['table']
        source_table = config[0]['source_table']
        # The file format is the last space-separated token of the load type, lower-cased
        file_type = (config[0]['load_type']).split(" ")[-1].lower()
        raw_path = f"{self.raw_path}/{source_table}"
        bronze_path =f"{self.bronze_path}/{table}"
        table_desc = config[0]['table_description']
        primary_keys = config[0]['primary_keys']
        load_type = config[0]['load_type']
        schema = self.get_schema('bronze', load_type)
        is_cdf = config[0]['is_cdf']

        # Set table properties
        tbl_properties = {}
        tbl_properties["delta.enableChangeDataFeed"] = "true" if is_cdf == "Y" else "false"
        tbl_properties["primary_keys"] = primary_keys

        # Full load to overwrite
        if load_type.startswith("FULL"):

            self.ts_print("Initiating full load...")

            # Read data from raw
            df = (
                spark.read
                .format(file_type)
                .option("header", True)
                .schema(schema)
                .load(raw_path)
            )

            # Update column values according to business logic
            # NOTE(review): this step is only applied in the full load branch, not in the stream branch - confirm this is intended
            df = self.update_col_dtypes(df, config)

            # Add metadata processing columns
            df = self.add_metadata_columns(df, raw_path, 'bronze', load_type)

            # Create the table and external location if they don't exist
            if not self.table_exists('bronze', table):
                self.create_table('bronze', table, table_desc, bronze_path, tbl_properties, config, load_type)

            # Alter table properties that don't match the config
            self.alter_table_properties("bronze", table, tbl_properties)

            # Write data to bronze (data and schema at the location are replaced)
            df.write.format("delta").mode("overwrite").option("overwriteSchema", True).save(bronze_path)

            self.ts_print(f"Raw files from {raw_path} successfully written to {bronze_path}.")

        # Incremental load for UPSERT
        elif load_type.startswith("STREAM"):

            self.ts_print("Initiating stream...")

            # Create the table and external location if they don't exist
            # (done before the stream starts because write_upsert merges into the existing Delta location)
            if not self.table_exists('bronze', table):
                self.create_table('bronze', table, table_desc, bronze_path, tbl_properties, config, load_type)

            # Alter table properties that don't match the config
            self.alter_table_properties("bronze", table, tbl_properties)

            # Read unprocessed files from raw
            df = (
                spark.readStream
                .format("cloudFiles")
                .option("cloudFiles.format", file_type)
                .option("cloudFiles.useIncrementalListing", True) # Read files alphabetically (yyyMMdd_HHmmss format)
                .option("cloudFiles.schemaLocation", f"{bronze_path}/_schema")
                .option("cloudFiles.backfillInterval", "1 day")
                .option("header", True)
                .option("overwriteSchema", False)
                .option("mergeSchema", False)
                .schema(schema)
                .load(raw_path)
            )

            # Add metadata processing columns
            df = self.add_metadata_columns(df, raw_path, 'bronze', load_type)

            # Execute UPSERT for all batches of changes (the batch id is not used)
            micro_batch = lambda batch_df, batch_id: self.write_upsert(batch_df, bronze_path, primary_keys, config_df, "bronze", load_type)
            query = (
                df.writeStream
                .foreachBatch(micro_batch)
                .option("checkpointLocation", f"{bronze_path}/_checkpoint")
                .trigger(once=True)
                .start(bronze_path)
            )

            # Wait for the stream to complete before continuing because it is asynchronous
            query.awaitTermination()

            self.ts_print(f"Raw files from {raw_path} successfully written to {bronze_path}.")

        self.ts_print(f"⏩ Bronze layer ingestion complete.")
        print('')
    
    def write_silver(self):
        '''
        Writes the bronze data to the silver container and creates a table in the silver schema.

        Everything is driven by the first silver configuration row of the table set:
        - a load type starting with "FULL" reads the whole bronze table, drops the processing
          columns, renames the columns to their silver names and overwrites the silver location;
        - a load type starting with "STREAM" reads the change data feed of the bronze table and
          merges every micro-batch into the silver location via write_upsert().
        '''
        self.ts_print(f"⏩ Executing silver layer ingestion...")

        # Set config
        config_df = self.config.filter("layer = 'silver'")
        config = config_df.collect()

        # Set variables (table-level settings are read from the first configuration row).
        # No dataframe schema is built here: the silver data is read from the bronze table as is.
        table = config[0]['table']
        source_table = config[0]['source_table']
        table_desc = config[0]['table_description']
        bronze_path = f"{self.bronze_path}/{source_table}"
        silver_path = f"{self.silver_path}/{table}"
        load_type = config[0]['load_type']
        primary_keys = config[0]['primary_keys']

        # Set table properties
        tbl_properties = {}
        tbl_properties["primary_keys"] = primary_keys

        # Full load to overwrite
        if load_type.startswith("FULL"):

            self.ts_print("Initiating full load...")

            # Read data from bronze table
            df = (
                spark.read
                .option("header", True)
                .table(f"{self.catalog}.bronze.{source_table}")
            )

            # Drop processing columns (every column prefixed with an underscore) in a single projection
            processing_columns = [name for name in df.columns if name.startswith('_')]
            df = df.drop(*processing_columns)

            # Update column names to silver names
            df = self.update_column_names(df, config)

            # Create the table and external location if they don't exist
            if not self.table_exists('silver', table):
                self.create_table('silver', table, table_desc, silver_path, tbl_properties, config, load_type)

            # Alter table properties that don't match the config
            self.alter_table_properties("silver", table, tbl_properties)

            # Write data to silver (data and schema at the location are replaced)
            df.write.format("delta").mode("overwrite").option("overwriteSchema", True).save(silver_path)

        # Incremental load for CDF
        elif load_type.startswith("STREAM"):

            self.ts_print("Initiating stream...")

            # Create the table and external location if they don't exist
            # (done before the stream starts because write_upsert merges into the existing Delta location)
            if not self.table_exists('silver', table):
                self.create_table('silver', table, table_desc, silver_path, tbl_properties, config, load_type)

            # Alter table properties that don't match the config
            self.alter_table_properties("silver", table, tbl_properties)

            # Read bronze data via change data feed
            df = (
                spark.readStream
                .format("delta")
                .option("readChangeFeed", "true")
                .option("startingVersion", 0)
                .table(f"{self.catalog}.bronze.{source_table}")
            )

            # Execute UPSERT for all batches of changes (the batch id is not used)
            micro_batch = lambda batch_df, batch_id: self.write_upsert(batch_df, silver_path, primary_keys, config_df, "silver", load_type)
            query = (
                df.writeStream
                .foreachBatch(micro_batch)
                .option("checkpointLocation", f"{silver_path}/_checkpoint")
                .trigger(once=True)
                .start(silver_path)
            )

            # Wait for the stream to complete before continuing because it is asynchronous
            query.awaitTermination()

        self.ts_print(f"Bronze data from {bronze_path} successfully written to {silver_path}.")

        self.ts_print(f"⏩ Silver layer ingestion complete.")
        print('')