In [None]:
import logging

from databricks.labs.dqx.engine import DQEngine
from databricks.sdk import WorkspaceClient
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructField, StructType, IntegerType,
    StringType, DateType, DecimalType
)

# Set log level

In [None]:
# Configure the root logger at INFO level; main() reports its progress through logging.info.
logging.basicConfig(level=logging.INFO)

# Create parameters

In [None]:
# Declare one text widget per notebook parameter, each with the placeholder default '0'.
for widget_name in ('storage_account', 'year', 'month', 'day'):
    dbutils.widgets.text(widget_name, '0')

In [None]:
# Read the widget values in the same order the widgets were declared.
storage_account, year, month, day = (
    dbutils.widgets.get(widget_name)
    for widget_name in ('storage_account', 'year', 'month', 'day')
)

# Folder in the silver container that holds the transformed data for the requested date.
silver_file_path = f'abfss://silver@{storage_account}.dfs.core.windows.net/transformed_data/{year}/{month}/{day}/'

# Data quality engine bound to a client of the Databricks workspace.
dq_engine = DQEngine(WorkspaceClient())

# Define schema

In [None]:
# Schema of the silver layer data: every column is declared non-nullable,
# so only the (name, type) pairs are listed and the flag is applied once.
silver_schema = StructType([
    StructField(column_name, column_type, False)
    for column_name, column_type in (
        ('Sales_Person_ID', IntegerType()),
        ('Sales_Person', StringType()),
        ('Country', StringType()),
        ('Product_ID', IntegerType()),
        ('Product', StringType()),
        ('Date', DateType()),
        ('Revenue', IntegerType()),
        ('Boxes_Shipped', IntegerType()),
        ('First_Name', StringType()),
        ('Last_Name', StringType()),
        ('Revenue_Per_Box', DecimalType(10, 2)),
        ('Date_Key', StringType()),
        ('Year', IntegerType()),
        ('Quarter', IntegerType()),
        ('Month', IntegerType()),
        ('Day', IntegerType()),
        ('Start_Of_Year', DateType()),
        ('Start_Of_Quarter', DateType()),
        ('Start_Of_Month', DateType()),
    )
])

# Run common functions

In [None]:
%run ./utils/common_functions

# Define dimension functions

In [None]:
def create_dim_sales_person(
    spark: SparkSession,
    silver_schema: StructType,
    silver_file_path: str,
    dq_engine: DQEngine,
    storage_account: str
) -> None:
    """
    Create sales person dimension table.

    Takes the distinct sales person attributes from the silver data, left-joins
    them to the current dimension rows on Sales_Person_ID, keeps the existing
    surrogate key where one is found and generates a key for the remaining rows,
    then hands the combined frame to data_quality_checks and merge_data.

    Parameter:
        spark: Spark session.
        silver_schema: Schema of silver layer data.
        silver_file_path: File path in silver layer storage.
        dq_engine: Data quality instance.
        storage_account: Storage account name.

    Return:
        None.
    """

    # define variables
    dim_table_name = 'sales_catalog.gold.dim_sales_person'
    checks_file_path = '/pipeline_project/check/checks_gold_dim_sales_person.yml'
    merge_condition = 'trg.Sales_Person_Key = src.Sales_Person_Key'
    gold_storage_path = f'abfss://gold@{storage_account}.dfs.core.windows.net/dim_sales_person'
    business_columns = ['Sales_Person_ID', 'Sales_Person', 'First_Name', 'Last_Name']

    # read data from silver layer
    df_silver = read_data(spark, 'parquet', silver_schema, silver_file_path)

    # select necessary columns, one row per distinct combination
    df_silver = df_silver.select(*business_columns).distinct()

    # select dimension table data, select empty table if not existed
    df_dimension = read_dim_data(
        spark,
        dim_table_name,
        'Sales_Person_Key, Sales_Person_ID, Sales_Person, First_Name, Last_Name',
        silver_file_path
    )

    # filter new records and existing records:
    # the left join leaves Sales_Person_Key null for IDs missing from the dimension
    df_all = df_silver.join(df_dimension, on='Sales_Person_ID', how='left') \
        .select(
            df_dimension['Sales_Person_Key'],
            df_silver['Sales_Person_ID'],
            df_silver['Sales_Person'],
            df_silver['First_Name'],
            df_silver['Last_Name']
        )

    df_existed = df_all.filter(F.col('Sales_Person_Key').isNotNull())

    df_new = df_all.filter(F.col('Sales_Person_Key').isNull())

    # get maximum surrogate key
    max_key = get_surrogate_key(spark, dim_table_name, df_dimension, 'Sales_Person_Key')

    # generate surrogate key for new records
    # row_number over a total ordering of the distinct business columns gives
    # dense keys that are identical every time the plan is re-evaluated (the
    # frame is consumed by both the quality checks and the merge).
    # monotonically_increasing_id is neither consecutive nor deterministic.
    # The un-partitioned window moves the new rows to a single partition;
    # presumably the number of new dimension members per run is small.
    # NOTE(review): as before, the first new key equals max_key (offset 0) -
    # confirm get_surrogate_key returns the next free key, not the current maximum.
    key_window = Window.orderBy(*business_columns)
    df_new = df_new.withColumn(
        'Sales_Person_Key',
        (F.row_number().over(key_window) - 1 + max_key).cast('long')
    )

    # combine new records with surrogate key and existing records
    # (positional union is safe: withColumn replaced the key column in place)
    df_total = df_existed.union(df_new)

    # data quality checks
    data_quality_checks(dq_engine, checks_file_path, df_total)

    # merge data
    merge_data(spark, dim_table_name, df_total, merge_condition, gold_storage_path)

In [None]:
def create_dim_product(
    spark: SparkSession,
    silver_schema: StructType,
    silver_file_path: str,
    dq_engine: DQEngine,
    storage_account: str
) -> None:
    """
    Create product dimension table.

    Takes the distinct (Product_ID, Product) pairs from the silver data,
    left-joins them to the current dimension rows on Product_ID, keeps the
    existing surrogate key where one is found and generates a key for the
    remaining rows, then hands the combined frame to data_quality_checks and
    merge_data.

    Parameter:
        spark: Spark session.
        silver_schema: Schema of silver layer data.
        silver_file_path: File path in silver layer storage.
        dq_engine: Data quality instance.
        storage_account: Storage account name.

    Return:
        None.
    """

    # define variables
    dim_table_name = 'sales_catalog.gold.dim_product'
    checks_file_path = '/pipeline_project/check/checks_gold_dim_product.yml'
    merge_condition = 'trg.Product_Key = src.Product_Key'
    gold_storage_path = f'abfss://gold@{storage_account}.dfs.core.windows.net/dim_product'
    business_columns = ['Product_ID', 'Product']

    # read data from silver layer
    df_silver = read_data(spark, 'parquet', silver_schema, silver_file_path)

    # select necessary columns, one row per distinct combination
    df_silver = df_silver.select(*business_columns).distinct()

    # select dimension table data, select empty table if not existed
    df_dimension = read_dim_data(
        spark,
        dim_table_name,
        'Product_Key, Product_ID, Product',
        silver_file_path
    )

    # filter new records and existing records:
    # the left join leaves Product_Key null for IDs missing from the dimension
    df_all = df_silver.join(df_dimension, on='Product_ID', how='left') \
        .select(
            df_dimension['Product_Key'],
            df_silver['Product_ID'],
            df_silver['Product']
        )

    df_existed = df_all.filter(F.col('Product_Key').isNotNull())

    df_new = df_all.filter(F.col('Product_Key').isNull())

    # get maximum surrogate key
    max_key = get_surrogate_key(spark, dim_table_name, df_dimension, 'Product_Key')

    # generate surrogate key for new records
    # row_number over a total ordering of the distinct business columns gives
    # dense keys that are identical every time the plan is re-evaluated (the
    # frame is consumed by both the quality checks and the merge).
    # monotonically_increasing_id is neither consecutive nor deterministic.
    # The un-partitioned window moves the new rows to a single partition;
    # presumably the number of new dimension members per run is small.
    # NOTE(review): as before, the first new key equals max_key (offset 0) -
    # confirm get_surrogate_key returns the next free key, not the current maximum.
    key_window = Window.orderBy(*business_columns)
    df_new = df_new.withColumn(
        'Product_Key',
        (F.row_number().over(key_window) - 1 + max_key).cast('long')
    )

    # combine new records with surrogate key and existing records
    # (positional union is safe: withColumn replaced the key column in place)
    df_total = df_existed.union(df_new)

    # data quality checks
    data_quality_checks(dq_engine, checks_file_path, df_total)

    # merge data
    merge_data(spark, dim_table_name, df_total, merge_condition, gold_storage_path)

In [None]:
def create_dim_country(
    spark: SparkSession,
    silver_schema: StructType,
    silver_file_path: str,
    dq_engine: DQEngine,
    storage_account: str
) -> None:
    """
    Create country dimension table.

    Takes the distinct Country values from the silver data, left-joins them to
    the current dimension rows on Country, keeps the existing surrogate key
    where one is found and generates a key for the remaining rows, then hands
    the combined frame to data_quality_checks and merge_data.

    Parameter:
        spark: Spark session.
        silver_schema: Schema of silver layer data.
        silver_file_path: File path in silver layer storage.
        dq_engine: Data quality instance.
        storage_account: Storage account name.

    Return:
        None.
    """

    # define variables
    dim_table_name = 'sales_catalog.gold.dim_country'
    checks_file_path = '/pipeline_project/check/checks_gold_dim_country.yml'
    merge_condition = 'trg.Country_Key = src.Country_Key'
    gold_storage_path = f'abfss://gold@{storage_account}.dfs.core.windows.net/dim_country'

    # read data from silver layer
    df_silver = read_data(spark, 'parquet', silver_schema, silver_file_path)

    # select necessary columns, one row per distinct country
    df_silver = df_silver.select('Country').distinct()

    # select dimension table data, select empty table if not existed
    df_dimension = read_dim_data(
        spark,
        dim_table_name,
        'Country_Key, Country',
        silver_file_path
    )

    # filter new records and existing records:
    # the left join leaves Country_Key null for countries missing from the dimension
    df_all = df_silver.join(df_dimension, on='Country', how='left') \
        .select(
            df_dimension['Country_Key'],
            df_silver['Country']
        )

    df_existed = df_all.filter(F.col('Country_Key').isNotNull())

    df_new = df_all.filter(F.col('Country_Key').isNull())

    # get maximum surrogate key
    max_key = get_surrogate_key(spark, dim_table_name, df_dimension, 'Country_Key')

    # generate surrogate key for new records
    # row_number over the distinct Country values gives dense keys that are
    # identical every time the plan is re-evaluated (the frame is consumed by
    # both the quality checks and the merge). monotonically_increasing_id is
    # neither consecutive nor deterministic.
    # The un-partitioned window moves the new rows to a single partition;
    # presumably the number of new dimension members per run is small.
    # NOTE(review): as before, the first new key equals max_key (offset 0) -
    # confirm get_surrogate_key returns the next free key, not the current maximum.
    key_window = Window.orderBy('Country')
    df_new = df_new.withColumn(
        'Country_Key',
        (F.row_number().over(key_window) - 1 + max_key).cast('long')
    )

    # combine new records with surrogate key and existing records
    # (positional union is safe: withColumn replaced the key column in place)
    df_total = df_existed.union(df_new)

    # data quality checks
    data_quality_checks(dq_engine, checks_file_path, df_total)

    # merge data
    merge_data(spark, dim_table_name, df_total, merge_condition, gold_storage_path)

In [None]:
def create_dim_date(
    spark: SparkSession,
    silver_schema: StructType,
    silver_file_path: str,
    dq_engine: DQEngine,
    storage_account: str
) -> None:
    """
    Create date dimension table.

    The date dimension needs no generated key: the distinct calendar columns
    of the silver data, including Date_Key, are passed straight to the data
    quality checks and the merge.

    Parameter:
        spark: Spark session.
        silver_schema: Schema of silver layer data.
        silver_file_path: File path in silver layer storage.
        dq_engine: Data quality instance.
        storage_account: Storage account name.

    Return:
        None.
    """

    # target table, quality rules, merge key and storage location
    target_table = 'sales_catalog.gold.dim_date'
    quality_rules_path = '/pipeline_project/check/checks_gold_dim_date.yml'
    key_match = 'trg.Date_Key = src.Date_Key'
    target_path = f'abfss://gold@{storage_account}.dfs.core.windows.net/dim_date'

    # calendar attributes carried into the dimension
    calendar_columns = [
        'Date_Key',
        'Date',
        'Year',
        'Quarter',
        'Month',
        'Day',
        'Start_Of_Year',
        'Start_Of_Quarter',
        'Start_Of_Month',
    ]

    # load the silver data and reduce it to one row per distinct date
    silver_rows = read_data(spark, 'parquet', silver_schema, silver_file_path)
    date_rows = silver_rows.select(*calendar_columns).distinct()

    # validate, then merge on Date_Key
    data_quality_checks(dq_engine, quality_rules_path, date_rows)
    merge_data(spark, target_table, date_rows, key_match, target_path)


# Define main function

In [None]:
def main() -> None:
    """
    Main function to create dimension tables.

    Builds the sales person, product, country and date dimensions in that
    order, passing the notebook-level spark, silver_schema, silver_file_path,
    dq_engine and storage_account to every builder. The first failure is
    logged together with its traceback and re-raised to the caller; the
    remaining dimensions are not attempted.

    Parameter:
        None.

    Return:
        None.
    """

    try:
        # every builder shares the same signature, so drive them from one table
        dimension_builders = (
            ('sales person', create_dim_sales_person),
            ('product', create_dim_product),
            ('country', create_dim_country),
            ('date', create_dim_date),
        )

        for dimension_label, build_dimension in dimension_builders:
            logging.info('Creating %s dimension.', dimension_label)
            build_dimension(spark, silver_schema, silver_file_path, dq_engine, storage_account)

        logging.info('Finished creating all dimension tables.')

    except Exception as e:
        # logging.exception logs at ERROR level and appends the traceback,
        # which a plain logging.error call with the message alone would drop
        logging.exception('Error occurred: %s', e)
        raise

# Run

In [None]:
# Entry point: build all dimension tables only when this code runs as the top-level program.
if __name__ == '__main__':
    main()