In [None]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StructType
from delta.tables import DeltaTable

from databricks.labs.dqx.engine import DQEngine

# Define functions

In [None]:
def read_data(spark: SparkSession, file_format: str, schema: StructType, file_path: str) -> DataFrame:
    """
    Load a file-based dataset from data lake storage into a dataframe.

    Parameter:
        spark: Active Spark session whose reader is used.
        file_format: Format name passed to the reader (e.g. 'parquet').
        schema: Explicit schema applied to the reader.
        file_path: Storage location handed to the reader's load call.

    Return:
        The loaded data as a dataframe.
    """

    # Configure the reader (format + explicit schema) before pointing it at the path.
    reader = spark.read.format(file_format).schema(schema)
    return reader.load(file_path)

In [None]:
def read_dim_data(spark: SparkSession, dim_table: str, col: str, file_path: str) -> DataFrame:
    """
    Return the current dimension rows, or an empty frame shaped from silver data.

    When `dim_table` is already registered in the catalog (incremental load),
    the `col` expression is selected from it. Otherwise (full load), a query
    with `WHERE 1 = 0` runs against the parquet files at `file_path`, so no
    rows are returned and only the column layout is kept.

    In the empty-frame query `col` is spliced in after `SELECT 1 AS`. If `col`
    is a comma-separated list, only the first name is aliased to the literal 1.
    The remaining names must exist in the parquet data.
    NOTE(review): presumably the first column is the surrogate key, which is
    not present in the silver data; confirm against callers.

    Parameter:
        spark: Spark session.
        dim_table: Dimension table name.
        col: Column list (SQL text) to select.
        file_path: Silver data file path (parquet) to take column layout from.

    Return:
        Dimension table data as dataframe (empty when the table does not exist).
    """

    # Incremental load: the dimension table exists, so read its current rows.
    if spark.catalog.tableExists(dim_table):
        existing_rows_sql = f'''
                                SELECT {col}
                                FROM {dim_table}
                            '''
        return spark.sql(existing_rows_sql)

    # Full load: produce a zero-row frame whose columns come from the silver data.
    empty_frame_sql = f'''
            SELECT 1 AS {col}
            FROM parquet.`{file_path}`
            WHERE 1 = 0
        '''
    return spark.sql(empty_frame_sql)

In [None]:
def get_surrogate_key(spark: SparkSession, dim_table: str, df_dimension: DataFrame, col: str) -> int:
    """
    Compute the next surrogate key to assign to new dimension records.

    Despite the function name, the value returned is one greater than the
    current maximum key. It is the first key that can be handed out.

    Parameter:
        spark: Spark session.
        dim_table: Dimension table name (checked for existence in the catalog).
        df_dimension: Dimension table dataframe to scan for the current maximum.
        col: Surrogate key column.

    Return:
        max(col) + 1 when the table exists (a NULL/absent maximum counts as 0),
        otherwise 1.
    """

    # Brand-new dimension: no keys have been issued yet, so numbering starts at 1.
    if not spark.catalog.tableExists(dim_table):
        return 1

    # max() over no rows (or only NULLs) comes back as None; treat that as 0.
    current_max = df_dimension.select(F.max(col)).collect()[0][0]
    return (current_max or 0) + 1

In [None]:
def data_quality_checks(dq_engine: DQEngine, checks_file_path: str, df: DataFrame) -> None:
    """
    Run data quality checks on a dataframe using a checks file; fail if any row errors.

    The checks are loaded from a workspace file and applied with
    `apply_checks_by_metadata`. Rows whose `_errors` column is non-null are
    counted. Only `_errors` is inspected.

    Parameter:
        dq_engine: Data quality engine instance.
        checks_file_path: Workspace file path for the checks file.
        df: Dataframe to check data quality.

    Return:
        None.

    Raise:
        AssertionError: if at least one row has a non-null `_errors` value.
    """

    # load checks definitions from the workspace file
    checks = dq_engine.load_checks_from_workspace_file(workspace_path=checks_file_path)

    # apply checks; the result carries an `_errors` column per row
    df_check = dq_engine.apply_checks_by_metadata(df, checks)

    # number of rows flagged with at least one error
    error_count = df_check.filter(F.col('_errors').isNotNull()).count()

    # Raise explicitly rather than using an `assert` statement: asserts are
    # stripped when Python runs with -O, which would silently bypass this gate.
    # AssertionError is kept so existing callers see the same exception type.
    if error_count != 0:
        raise AssertionError(f'{error_count} errors found in the data')

In [None]:
def merge_data(
    spark: SparkSession,
    table_name: str,
    source_df: DataFrame,
    merge_condition: str,
    storage_path: str
) -> None:
    """
    Upsert a dataframe into a Delta table, creating the table on first use.

    If `table_name` is not yet in the catalog, `source_df` is written as a
    Delta table (mode 'overwrite') at `storage_path` and registered under
    `table_name`. Otherwise a Delta MERGE runs. Matched target rows have all
    columns updated from the source, and unmatched source rows are inserted.

    Parameter:
        spark: Spark Session.
        table_name: Delta table name.
        source_df: Source dataframe to merge into the table.
        merge_condition: SQL join condition between source and target; it must
            refer to the target as `trg` and the source as `src`.
        storage_path: Storage path used when the table is created.

    Return:
        None.
    """

    # First run: nothing registered yet, so create the table at the storage path.
    if not spark.catalog.tableExists(table_name):
        table_writer = source_df.write.format('delta').mode('overwrite').option('path', storage_path)
        table_writer.saveAsTable(table_name)
        return

    # Later runs: upsert into the existing table using the `trg`/`src` aliases.
    target = DeltaTable.forName(spark, table_name).alias('trg')
    merge_builder = target.merge(source=source_df.alias('src'), condition=merge_condition)
    merge_builder.whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()