In [1]:
from datetime import date, datetime
from pyspark.sql import SparkSession

import os
import json

In [2]:
path = os.getcwd()
initialLoadDate = datetime(2012, 12, 31, 0, 0, 0)
initialLoad = initialLoadDate.strftime("%Y-%m-%d %H:%M:%S")

In [3]:
newCutoff = ""
fromNotebook = True
destination = ""
tables = ""
sparkMaster = "local[*]"
retriesMax = 5
wh_process = "warehouse_wwi"
purge_existing="1"

In [4]:
jars = ""

if fromNotebook:
    f = open('warehouse_wwi.json',)
    config = json.load(f)
    f.close()

    newCutoff = config["cutoff_date"]
    jars = "../resources/jars/mssql-jdbc-13.2.0.jre11.jar,../jars/spark-bigquery-with-dependencies_2.12-0.42.2.jar"
    destination = config["destination"]
    tables = config["tables"]
else:
    jars = "{0}/resources/jars/mssql-jdbc-13.2.0.jre11.jar,{0}/jars/spark-bigquery-with-dependencies_2.12-0.42.2.jar".format(path)

print("jars", jars)
print("desstination", destination)
print("tables", tables)

jars ../resources/jars/mssql-jdbc-13.2.0.jre11.jar,../jars/spark-bigquery-with-dependencies_2.12-0.42.2.jar
desstination {'database': 'WideWorldImportersDW2', 'url': 'jdbc:sqlserver://localhost\\MSSQLSERVER05;database=WideWorldImportersDW2;user=sa;password=P@$$w0rd;encrypt=false'}
tables [{'table': 'DimCity', 'schema': 'dbo', 'storedProcedure': '[dbo].[ProcessDimCity]', 'validationSPs': [{'validationSP': 'dbo.ValidateDimCityData'}], 'notNullFields': 'CityKey,WWICityID,City,StateProvince,Country,Continent,SalesTerritory,Region,Subregion,LatestRecordedPopulation,LoadDate', 'uniqueFields': 'CityKey,WWICityID', 'notNullUniqueFields': '', 'foreignKeyFields': []}, {'table': 'DimCustomer', 'schema': 'dbo', 'storedProcedure': '[dbo].[ProcessDimCustomer]', 'validationSPs': [{'validationSP': 'dbo.ValidateDimCustomerData'}], 'notNullFields': 'CustomerKey,WWICustomerID,WWIDeliveryCityID,Customer,BillToCustomer,Category,BuyingGroup,PrimaryContact,PostalCode,LoadDate', 'uniqueFields': 'CustomerKey,W

In [5]:
def get_spark_session():
    spark = None

    if fromNotebook:
        spark = (
            SparkSession.builder 
                .config(
                    "spark.driver.host", 
                    "localhost"
                )
                .master(sparkMaster)
                .appName("load_wwi")
                .config("spark.jars", jars)    
                .getOrCreate()
        )
    else:
        spark = (
            SparkSession.builder 
                .master(sparkMaster)
                .appName("load_wwi")
                .config("spark.jars", jars)
                .config("spark.executor.memory", "4g")
                .config("spark.driver.memory", "4g")
                .config("spark.executor.cores", "6")
                .config("spark.cores.max", "6")
                .config("spark.network.timeout", "600s")
                .config("spark.executor.heartbeatInterval", "599s")
                .getOrCreate()
        )
    
    # Establish connecttion
    df = (
        spark.read
            .format("jdbc")
            .option("url", destination["url"])
            .option("dbtable", "(SELECT TempCol = 1) AS Temp")
            .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
            .load()
    )

    return spark

In [6]:
def stop_spark_session():
    active_spark_session = SparkSession.getActiveSession()
    
    if active_spark_session:
        active_spark_session.stop()

In [7]:
def validate_table(table):
    tableValidations = []

    # validation sps
    if len(table["validationSPs"]) > 0:
        for validationSP in table["validationSPs"]: 
            tableValidations.append(f"""
                EXEC {validationSP["validationSP"]} 
                    @Process='{wh_process}', 
                    @CutoffDate='{newCutoff}', 
                    @Schema='{table["schema"]}', 
                    @Table='{table["table"]}',
                    @PurgeExisting='{purge_existing}'
            """)

    # not null fields
    if len(table["notNullFields"]) > 0:
        tableValidations.append(f"""
            EXEC ValidateNotNullFields 
                @Process='{wh_process}', 
                @CutoffDate='{newCutoff}', 
                @Schema='{table["schema"]}', 
                @Table='{table["table"]}',
                @NotNullFields='{table["notNullFields"]}',
                @PurgeExisting='{purge_existing}'
        """)

    # unique fields
    if len(table["uniqueFields"]) > 0:
        tableValidations.append(f"""
            EXEC ValidateUniqueFields 
                @Process='{wh_process}', 
                @CutoffDate='{newCutoff}', 
                @Schema='{table["schema"]}', 
                @Table='{table["table"]}',
                @UniqueFields='{table["uniqueFields"]}',
                @PurgeExisting='{purge_existing}'
        """)

    # not null unique fields
    if len(table["notNullUniqueFields"]) > 0:
        tableValidations.append(f"""
            EXEC ValidateNotNullUniqueFields 
                @Process='{wh_process}', 
                @CutoffDate='{newCutoff}', 
                @Schema='{table["schema"]}', 
                @Table='{table["table"]}',
                @NotNullUniqueFields='{table["notNullUniqueFields"]}',
                @PurgeExisting='{purge_existing}'
        """)

    # foreign key fields
    if len(table["foreignKeyFields"]) > 0:
        tableValidations.append(f"""
            EXEC ValidateForeignKeyFields 
                @Process='{wh_process}', 
                @CutoffDate='{newCutoff}', 
                @Schema='{table["schema"]}', 
                @Table='{table["table"]}',
                @ForeignKeyFields='{json.dumps(table["foreignKeyFields"])}',
                @PurgeExisting='{purge_existing}'
        """)

    sql = ";".join(tableValidations)

    # print(sql)

    retries = 0
    isSuccessful = False
    
    while retries < retriesMax:

        retries = retries + 1

        try:
            
            spark = get_spark_session()

            connection = spark._jvm.java.sql.DriverManager.getConnection(destination["url"])
            
            statement = connection.createStatement()

            statement.executeUpdate(sql)

            retries = retriesMax

            isSuccessful = True
            
        except Exception as ex:
            stop_spark_session()

            if retries >= retriesMax:
                raise ex

    if isSuccessful == False:
        raise Exception("Unable to verify counts")

def raise_validation_error():
    error_count = 0
    retries = 0
    isSuccessful = False

    while retries < retriesMax:

        retries = retries + 1

        try:

                spark = get_spark_session()

                dbTable = f"""
                    (
                        SELECT 
                            error_count = COUNT(*) 
                        FROM 
                            dbo.DataValidationErrors
                        WHERE
                            Process = '{wh_process}' AND
                            CutoffDate = '{newCutoff}'
                    ) AS error
                """

                # read 
                df = (
                    spark.read
                        .format("jdbc")
                        .option("url", destination["url"])
                        .option("dbtable", dbTable)
                        .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
                        .load()
                )

                df.show()

                error_count = int(df.toPandas()["error_count"][0])

                retries = retriesMax

                isSuccessful = True
            
        except Exception as ex:
            stop_spark_session()

            if retries >= retriesMax:
                raise ex

    if isSuccessful == False:
        raise Exception("Unable to raise validation error")
    else:
        if error_count > 0:
            raise Exception(F"{error_count} validation error(s) found.  Please check DataValidationErrors table.")

for table in tables:
    validate_table(table)

raise_validation_error()


+-----------+
|error_count|
+-----------+
|          0|
+-----------+

