In [None]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructField, StructType, IntegerType,
    StringType, DateType, DecimalType
)

from databricks.labs.dqx.engine import DQEngine
from databricks.sdk import WorkspaceClient

import logging

# Set log level

In [None]:
# Send INFO-and-above records from this notebook to the root logger's handler.
logging.basicConfig(level=logging.INFO)
# py4j (the Python<->JVM bridge PySpark uses) can emit very chatty INFO records
# for gateway traffic on some runtimes; with the root logger at INFO those would
# flood the notebook/job output, so keep that library at WARNING.
logging.getLogger('py4j').setLevel(logging.WARNING)

# Create notebook parameters

In [None]:
# Register each notebook parameter as a text widget; every one defaults to '0'.
for _widget_name in ('storage_account', 'year', 'month', 'day'):
    dbutils.widgets.text(_widget_name, '0')

In [None]:
def _get_required_widget(name: str) -> str:
    """
    Read a notebook widget value, failing fast if it was never supplied.

    The parameter widgets default to '0', which is neither a usable storage
    account name nor a usable date part. Without this check an unset parameter
    would silently produce a path such as
    'abfss://silver@0.dfs.core.windows.net/transformed_data/0/0/0/' and only
    fail later, far from the cause.

    Parameter:
        name: Widget name.

    Return:
        The widget value, unchanged.

    Raise:
        ValueError: If the value is empty/blank or still the '0' default.
    """
    value = dbutils.widgets.get(name)
    if value.strip() in ('', '0'):
        raise ValueError(f"Notebook parameter '{name}' was not provided (got {value!r}).")
    return value


storage_account = _get_required_widget('storage_account')
year = _get_required_widget('year')
month = _get_required_widget('month')
day = _get_required_widget('day')

silver_file_path = f'abfss://silver@{storage_account}.dfs.core.windows.net/transformed_data/{year}/{month}/{day}/'

# create data quality instance connected to databricks workspace
dq_engine = DQEngine(WorkspaceClient())

# Define schema

In [None]:
# Expected layout of the silver parquet files, as (column name, Spark type)
# pairs in column order. Every column is declared non-nullable.
silver_schema = StructType([
    StructField(column_name, column_type, nullable=False)
    for column_name, column_type in (
        ('Sales_Person_ID', IntegerType()),
        ('Sales_Person', StringType()),
        ('Country', StringType()),
        ('Product_ID', IntegerType()),
        ('Product', StringType()),
        ('Date', DateType()),
        ('Revenue', IntegerType()),
        ('Boxes_Shipped', IntegerType()),
        ('First_Name', StringType()),
        ('Last_Name', StringType()),
        ('Revenue_Per_Box', DecimalType(10, 2)),
        ('Date_Key', StringType()),
        ('Year', IntegerType()),
        ('Quarter', IntegerType()),
        ('Month', IntegerType()),
        ('Day', IntegerType()),
        ('Start_Of_Year', DateType()),
        ('Start_Of_Quarter', DateType()),
        ('Start_Of_Month', DateType()),
    )
])

# Load common helper functions

In [None]:
%run ./utils/common_functions

# Define composite key check function

In [None]:
def check_composite_key(df: DataFrame, col: list[str]) -> None:
    """
    Check for duplicates in composite key columns, raise error if check failed.

    Rows are grouped by all key columns together; every group holding more
    than one row counts as one duplicated key. Spark's groupBy places NULL
    key values in the same group, so repeated NULL keys also count as
    duplicates.

    Parameter:
        df: Dataframe to check data quality.
        col: List of composite key columns to count duplicates.

    Return:
        None.

    Raise:
        ValueError: If ``col`` is empty. Grouping by no columns would collapse
            the whole DataFrame into one group and report false duplicates.
        AssertionError: If at least one composite key value occurs more than once.
    """

    if not col:
        raise ValueError('check_composite_key requires at least one key column')

    composite_duplicate_count = (
        df.groupBy(col)
        .count()
        .filter(F.col('count') > 1)
        .count()
    )

    # Raise explicitly rather than with an `assert` statement: asserts are
    # removed when Python runs with -O, which would silently disable the check.
    if composite_duplicate_count != 0:
        raise AssertionError(f'{composite_duplicate_count} composite duplicates found in the data')

# Define main function

In [None]:
def main() -> None:
    """
    Main function to create fact table.

    Reads the silver-layer parquet data for the date given by the notebook
    parameters, looks up surrogate keys by left-joining the gold dimension
    tables, runs the DQX checks from the checks file plus a composite-key
    uniqueness check, and merges the result into the gold fact table.

    Parameter:
        None.

    Return:
        None.

    Raise:
        Exception: Any failure is logged together with its traceback and then
            re-raised unchanged, so the notebook/job run fails.
    """

    # define variables
    fact_table_name = 'sales_catalog.gold.fact_sales'
    checks_file_path = '/pipeline_project/check/checks_gold_fact_sales.yml'
    gold_storage_path = f'abfss://gold@{storage_account}.dfs.core.windows.net/fact_sales'

    # Grain of the fact table. The uniqueness check and the MERGE match
    # condition are both built from this one list so they cannot drift apart.
    key_columns = ['Sales_Person_Key', 'Product_Key', 'Country_Key', 'Date_Key']

    # NOTE(review): the dimension lookups below are left joins, so an unmatched
    # business key yields a NULL surrogate key. `=` never matches NULL, so such
    # rows would presumably be inserted again on every rerun of the merge unless
    # the DQX checks file rejects NULL keys -- confirm that it does.
    merge_condition = ' AND '.join(f'trg.{key} = src.{key}' for key in key_columns)

    try:
        logging.info('Creating fact sales.')

        # read data from silver layer
        df_silver = read_data(spark, 'parquet', silver_schema, silver_file_path)

        # read data from dimension tables
        df_dim_sales_person = spark.read.table('sales_catalog.gold.dim_sales_person')
        df_dim_product = spark.read.table('sales_catalog.gold.dim_product')
        df_dim_country = spark.read.table('sales_catalog.gold.dim_country')
        df_dim_date = spark.read.table('sales_catalog.gold.dim_date')

        # resolve surrogate keys and keep only the fact columns
        df_total = (
            df_silver
            .join(df_dim_sales_person, on='Sales_Person_ID', how='left')
            .join(df_dim_product, on='Product_ID', how='left')
            .join(df_dim_country, on='Country', how='left')
            .join(df_dim_date, on='Date_Key', how='left')
            .select(
                df_dim_sales_person['Sales_Person_Key'],
                df_dim_product['Product_Key'],
                df_dim_country['Country_Key'],
                df_dim_date['Date_Key'],
                df_silver['Revenue'],
                df_silver['Boxes_Shipped'],
                df_silver['Revenue_Per_Box']
            )
        )

        # data quality checks
        data_quality_checks(dq_engine, checks_file_path, df_total)

        # composite key check (same columns the merge matches on)
        check_composite_key(df_total, key_columns)

        # merge data
        merge_data(spark, fact_table_name, df_total, merge_condition, gold_storage_path)

        logging.info('Finished creating fact table.')

    except Exception as e:
        # logging.exception logs at ERROR level and appends the full traceback,
        # which a plain logging.error with str(e) would drop.
        logging.exception('Error occurred while creating fact sales: %s', e)
        raise

# Run

In [None]:
# Entry point: run the pipeline when this notebook executes as the main program.
# NOTE(review): a notebook pulled in via %run appears to share the caller's
# __main__ namespace on Databricks, so this guard may not prevent main() from
# running in that case -- confirm before relying on it.
if __name__ == "__main__":
    main()