In [1]:
from datetime import date, datetime
from pyspark.sql import SparkSession

import os
import json

In [2]:
# Directory the kernel was started in; used further down to build absolute
# resource paths when the notebook is not run interactively.
path = os.getcwd()

# Fixed timestamp of the initial load, kept both as a datetime and as a
# "YYYY-MM-DD HH:MM:SS" string.
initialLoadDate = datetime(year=2012, month=12, day=31)
initialLoad = format(initialLoadDate, "%Y-%m-%d %H:%M:%S")

In [3]:
# Run parameters. The empty defaults below are placeholders: when fromNotebook
# is True the next cell overwrites newCutoff, source, destination and tables
# from load_wwi.json.
# NOTE(review): for non-notebook runs these presumably get injected by the
# caller (e.g. a parameters cell override) -- confirm against the job launcher.
newCutoff = ""            # upper bound used in the count queries' date filters
fromNotebook = True       # selects the interactive configuration path
source = ""               # becomes a dict with "database" and "url" keys
destination = ""          # becomes a dict with "database" and "url" keys
tables = ""               # becomes a list of source/destination table mappings
sparkMaster = "local[*]"  # passed to SparkSession.builder.master()
retriesMax = 1            # maximum number of attempts made by verify_counts()

In [4]:
# Resolve the JDBC driver jar and, for interactive runs, the job configuration.
jars = ""

if fromNotebook:
    # Interactive run: everything comes from the local JSON config file.
    # The context manager closes the file even if json.load() raises.
    with open('load_wwi.json') as config_file:
        config = json.load(config_file)

    newCutoff = config["cutoff_date"]
    jars = "../resources/jars/mssql-jdbc-13.2.0.jre11.jar"
    source = config["source"]
    destination = config["destination"]
    tables = config["tables"]
else:
    # Non-interactive run: resource paths are anchored at the working directory;
    # newCutoff/source/destination/tables keep whatever the parameters cell holds.
    googleCredentials = "{0}/resources/account_keys/dbt-tutorial-462014.json".format(path)
    jars = "{0}/resources/jars/mssql-jdbc-13.2.0.jre11.jar".format(path)

# NOTE(review): source["url"] and destination["url"] embed the database user and
# password, so these prints write credentials into the saved notebook output.
# Consider printing only the database names before sharing/committing.
print("jars", jars)
print("source", source)
print("destination", destination)
print("tables", tables)

jars ../resources/jars/mssql-jdbc-13.2.0.jre11.jar
source {'database': 'WideWorldImporters', 'url': 'jdbc:sqlserver://localhost\\MSSQLSERVER05;database=WideWorldImporters;user=sa;password=P@$$w0rd;encrypt=false'}
desstination {'database': 'WideWorldImportersDW2', 'url': 'jdbc:sqlserver://localhost\\MSSQLSERVER05;database=WideWorldImportersDW2;user=sa;password=P@$$w0rd;encrypt=false'}
tables [{'source': {'schema': 'Application', 'table': 'Cities', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 'Application_Cities'}}, {'source': {'schema': 'Application', 'table': 'Cities_Archive', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 'Application_Cities_Archive'}}, {'source': {'schema': 'Application', 'table': 'Countries', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 'Application_Countries'}}, {'source': {'schema': 'Application', 'table': 'Countries_Archive', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 

In [5]:
def get_spark_session():
    """Return the SparkSession for this job, creating it when none exists.

    Both modes share the master URL, the application name ("load_wwi") and the
    JDBC driver jar. Interactive runs (fromNotebook) additionally pin the
    driver host to localhost; all other runs add the memory, core and timeout
    settings listed below.

    Returns:
        The session produced by ``SparkSession.builder.getOrCreate()``.
    """
    builder = SparkSession.builder

    if fromNotebook:
        builder = builder.config("spark.driver.host", "localhost")

    builder = (
        builder
            .master(sparkMaster)
            .appName("load_wwi")
            .config("spark.jars", jars)
    )

    if not fromNotebook:
        # NOTE(review): the heartbeat interval is only one second below the
        # network timeout -- confirm this is intentional.
        cluster_options = (
            ("spark.executor.memory", "4g"),
            ("spark.driver.memory", "4g"),
            ("spark.executor.cores", "6"),
            ("spark.cores.max", "6"),
            ("spark.network.timeout", "600s"),
            ("spark.executor.heartbeatInterval", "599s"),
        )
        for option_name, option_value in cluster_options:
            builder = builder.config(option_name, option_value)

    return builder.getOrCreate()

In [6]:
def stop_spark_session():
    """Stop the currently active SparkSession; do nothing when there is none."""
    session = SparkSession.getActiveSession()

    if not session:
        return

    session.stop()

In [7]:
def verify_counts():
    """Check that every configured table has the same row count at source and destination.

    For each entry in ``tables`` one count query is generated per side, both
    limited to rows dated on or before ``newCutoff``. The source side filters
    on ``ValidFrom`` for tables of type "ValidDateRange" and on
    ``LastEditedWhen`` for every other type; the destination side filters on
    ``LoadDate``. The per-table queries are combined with UNION, read through
    JDBC, joined on ``TableName`` and compared.

    The read/compare step is attempted up to ``retriesMax`` times. After a
    failed attempt the active Spark session is stopped so that the next
    attempt starts from a fresh one.

    Raises:
        Exception: if no tables are configured, if any table's counts differ
            (the message lists each mismatching table), or if no attempt was
            made because ``retriesMax`` is not positive. An error raised by
            the final attempt is re-raised unchanged.
    """
    # Without tables the generated SQL would be "() AS ...", which only fails
    # later with an opaque JDBC syntax error.
    if not tables:
        raise Exception("Unable to verify counts: no tables are configured")

    sourceCountSqls = []
    destinationCountSqls = []

    # NOTE(review): identifiers and the cutoff are spliced into the SQL text;
    # this is only safe while the configuration comes from a trusted source.
    for table in tables:
        # Exact match on the table type. The previous ">=" was a lexicographic
        # string comparison, which would also pick ValidFrom for any type name
        # that merely sorts after "ValidDateRange".
        dateColumn = "ValidFrom" if table["source"]["type"] == "ValidDateRange" else "LastEditedWhen"

        sourceCountSqls.append(
            """
            SELECT
                [TableName] = '@Table',
                [Origin] = 'Source',
                [SourceCount] = 
                (
                    SELECT 
                        COUNT(*) 
                    FROM 
                        @Database.@Schema.@Table 
                    WHERE 
                        @DateColumn <= '@CutoffDate'
                )
            """
            .replace("@Database", source["database"])
            .replace("@Schema", table["source"]["schema"])
            .replace("@Table", table["source"]["table"])
            .replace("@CutoffDate", newCutoff)
            .replace("@DateColumn", dateColumn)
        )

        # The destination row is labelled with the *source* table name so that
        # both result sets can be joined on TableName.
        destinationCountSqls.append(
            """
            SELECT
                [TableName] = '@OriginTable',
                [Origin] = 'Destination',
                [DestinationCount] = 
                (
                    SELECT 
                        COUNT(*) 
                    FROM 
                        @Database.@Schema.@Table 
                    WHERE 
                        LoadDate <= '@CutoffDate'
                )
            """
            .replace("@Database", destination["database"])
            .replace("@Schema", table["destination"]["schema"])
            .replace("@Table", table["destination"]["table"])
            .replace("@OriginTable", table["source"]["table"])
            .replace("@CutoffDate", newCutoff)
        )

    # The JDBC "dbtable" option accepts a parenthesised subquery with an alias.
    sourceSql = "(" + "\nUNION\n".join(sourceCountSqls) + ") AS SourceCount"
    destinationSql = "(" + "\nUNION\n".join(destinationCountSqls) + ") AS DestinationCount"

    isSuccessful = False
    mismatchedCountTables = []

    for attempt in range(1, retriesMax + 1):
        try:
            spark = get_spark_session()

            df_source = (
                spark.read
                    .format("jdbc")
                    .option("url", source["url"])
                    .option("dbtable", sourceSql)
                    .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
                    .load()
            )

            df_destination = (
                spark.read
                    .format("jdbc")
                    .option("url", destination["url"])
                    .option("dbtable", destinationSql)
                    .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
                    .load()
            )

            # Keep only the columns needed for the comparison (both inputs
            # carry an "Origin" column) and bring the small result to pandas.
            # Missing counts are turned into 0 *before* comparing: in Spark a
            # comparison against NULL yields NULL, so filtering there would
            # silently drop tables that have no destination row.
            df_counts = (
                df_source
                    .join(df_destination, on="TableName", how="left")
                    .select("TableName", "SourceCount", "DestinationCount")
                    .toPandas()
                    .fillna(0)
            )

            df_mismatched = df_counts[df_counts["SourceCount"] != df_counts["DestinationCount"]]

            mismatchedCountTables = [
                f"{row.TableName}: Source={row.SourceCount}, Destination={row.DestinationCount}"
                for row in df_mismatched.itertuples(index=False)
            ]

            isSuccessful = True
            break

        except Exception:
            # Drop the possibly broken session so a retry starts clean.
            stop_spark_session()

            if attempt >= retriesMax:
                raise

    if mismatchedCountTables:
        raise Exception(
            "The following tables do not have matching counts:\n"
            + "\n".join(mismatchedCountTables)
        )

    # Only reachable without success when the loop never ran (retriesMax <= 0).
    if not isSuccessful:
        raise Exception("Unable to verify counts")

verify_counts()
        