In [1]:
from datetime import date, datetime
from pyspark.sql import SparkSession

import os
import json

In [2]:
# Working directory of the kernel; used to build the JDBC jar path when the
# notebook is not run interactively.
path = os.getcwd()

# Cutoff assumed for a table that has no successful load recorded yet, kept
# both as a datetime and as "%Y-%m-%d %H:%M:%S" text.
initialLoadDate = datetime(year=2012, month=12, day=31)
initialLoad = "{:%Y-%m-%d %H:%M:%S}".format(initialLoadDate)

In [3]:
# Run parameters. The empty strings are placeholders: when fromNotebook is True
# the following cell overwrites newCutoff, source, destination and tables with
# values read from load_wwi.json.
# NOTE(review): this looks like a parameters cell meant to be overridden by an
# external runner when fromNotebook is False — confirm.
newCutoff = ""  # upper bound of the load window, parsed later as "%Y-%m-%d %H:%M:%S"
fromNotebook = True  # True: read settings from load_wwi.json and use notebook Spark options
source = ""  # later indexed with "database" and "url"
destination = ""  # later indexed with "database" and "url"
tables = ""  # later iterated; each item is indexed with "source" and "destination"
sparkMaster = "local[*]"  # passed to SparkSession.builder.master()
retriesMax = 5  # maximum attempts for each warehouse operation

In [5]:
# Path of the SQL Server JDBC driver jar that is handed to Spark via "spark.jars".
jars = ""

if fromNotebook:
    # Interactive run: every setting comes from load_wwi.json. The context
    # manager closes the file even when the JSON fails to parse.
    with open('load_wwi.json') as f:
        config = json.load(f)

    newCutoff = config["cutoff_date"]
    jars = "../resources/jars/mssql-jdbc-13.2.0.jre11.jar"
    source = config["source"]
    destination = config["destination"]
    tables = config["tables"]
else:
    # Non-interactive run: the parameters are expected to be set already; only
    # the jar location is derived here, relative to the working directory.
    jars = "{0}/resources/jars/mssql-jdbc-13.2.0.jre11.jar".format(path)

# NOTE(review): the source/destination URLs carry the database password, so
# these prints write it into the saved notebook output. Consider redacting.
print("jars", jars)
print("source", source)
print("destination", destination)
print("tables", tables)

jars ../resources/jars/mssql-jdbc-13.2.0.jre11.jar
source {'database': 'WideWorldImporters', 'url': 'jdbc:sqlserver://localhost\\MSSQLSERVER05;database=WideWorldImporters;user=sa;password=P@$$w0rd;encrypt=false'}
desstination {'database': 'WideWorldImportersDW', 'url': 'jdbc:sqlserver://localhost\\MSSQLSERVER05;database=WideWorldImportersDW;user=sa;password=P@$$w0rd;encrypt=false'}
tables [{'source': {'schema': 'Application', 'table': 'Cities', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 'Application_Cities'}}, {'source': {'schema': 'Application', 'table': 'Cities_Archive', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 'Application_Cities_Archive'}}, {'source': {'schema': 'Application', 'table': 'Countries', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 'Application_Countries'}}, {'source': {'schema': 'Application', 'table': 'Countries_Archive', 'type': 'ValidDateRange'}, 'destination': {'schema': 'dbo', 'table': 'A

In [6]:
def get_spark_session():
    """Return a SparkSession for the load, creating one if none exists yet.

    Master, application name and the JDBC jar are always set. A notebook run
    additionally pins the driver host to localhost; any other run sets the
    memory, core and network timeout options instead.
    """
    builder = (
        SparkSession.builder
            .master(sparkMaster)
            .appName("load_wwi")
            .config("spark.jars", jars)
    )

    if fromNotebook:
        extraOptions = {
            "spark.driver.host": "localhost",
        }
    else:
        extraOptions = {
            "spark.executor.memory": "4g",
            "spark.driver.memory": "4g",
            "spark.executor.cores": "6",
            "spark.cores.max": "6",
            "spark.network.timeout": "600s",
            "spark.executor.heartbeatInterval": "599s",
        }

    for optionName, optionValue in extraOptions.items():
        builder = builder.config(optionName, optionValue)

    return builder.getOrCreate()

In [7]:
def stop_spark_session():
    """Stop the currently active SparkSession; do nothing when there is none."""
    session = SparkSession.getActiveSession()

    if not session:
        return

    session.stop()

In [8]:
def get_source_db_table(sourceDatabase, sourceSchema, sourceTable, sourceType, lastCutoff):
    """Build the derived-table SQL used as the JDBC "dbtable" for one source table.

    The query selects every column of the source table plus a LoadDate column
    holding the module-level newCutoff, restricted to rows changed after
    lastCutoff and up to and including newCutoff.

    Parameters:
        sourceDatabase: name of the source database.
        sourceSchema: schema of the source table.
        sourceTable: source table name; also used as the derived-table alias.
        sourceType: 'ValidDateRange' filters on ValidFrom; any other value
            filters on LastEditedWhen.
        lastCutoff: exclusive lower bound of the window, as text.

    Returns:
        The parenthesised SELECT followed by "AS <sourceTable>".
    """
    # The two source flavours differ only in the column that records the change time.
    changeColumn = "ValidFrom" if sourceType == 'ValidDateRange' else "LastEditedWhen"

    # str.format substitutes in a single pass, so a value that happens to
    # contain placeholder-like text is never substituted a second time.
    return """
            (
                SELECT
                    *,
                    [LoadDate] = CAST('{newCutoff}' AS DATETIME)
                FROM
                    {database}.{schema}.{table}
                WHERE
                    {changeColumn} > '{lastCutoff}' AND
                    {changeColumn} <= '{newCutoff}'
            ) AS {table}
        """.format(
        database=sourceDatabase,
        schema=sourceSchema,
        table=sourceTable,
        changeColumn=changeColumn,
        lastCutoff=lastCutoff,
        newCutoff=newCutoff
    )

def load_wwi_to_wh(table, lastCutoff):
    """Copy one table's changed rows from the source database to the warehouse.

    Rows changed after lastCutoff (up to the module-level newCutoff) are read
    over JDBC and appended to the destination table. The copy is attempted up
    to retriesMax times; after a successful copy a 'Successful' entry is added
    to the load history.

    Parameters:
        table: config entry with "source" (schema, table, type) and
            "destination" (schema, table) sub-dicts.
        lastCutoff: exclusive lower bound of the load window, as text.

    Raises:
        Exception: the error of the last attempt when every attempt failed, or
            a generic error when no attempt was made (retriesMax <= 0).
    """
    sourceDbTable = get_source_db_table(
        source["database"],
        table["source"]["schema"],
        table["source"]["table"],
        table["source"]["type"],
        lastCutoff
    )

    destDbTable = "{0}.{1}.{2}".format(
        destination["database"],
        table["destination"]["schema"],
        table["destination"]["table"]
    )

    jdbcDriver = "com.microsoft.sqlserver.jdbc.SQLServerDriver"

    retries = 0
    isSuccessful = False

    while retries < retriesMax:
        retries = retries + 1

        try:
            spark = get_spark_session()

            # Purge rows already stamped with this LoadDate before every
            # attempt, not only on retries: an earlier run that died mid-write
            # (or wrote the rows but failed to record history) leaves rows that
            # this run would otherwise duplicate.
            delete_duplicates(table)

            # read
            df = (
                spark.read
                    .format("jdbc")
                    .option("url", source["url"])
                    .option("dbtable", sourceDbTable)
                    .option("driver", jdbcDriver)
                    .load()
            )

            # write
            (
                df.write
                    .format("jdbc")
                    .mode("append")
                    .option("url", destination["url"])
                    .option("dbtable", destDbTable)
                    .option("driver", jdbcDriver)
                    .save()
            )

            isSuccessful = True
            break

        except Exception:
            # Drop the session so the next attempt starts from a fresh one.
            stop_spark_session()

            if retries >= retriesMax:
                raise

    if not isSuccessful:
        raise Exception("Unable to load table to warehouse")

    # save success load history
    insert_load_history(table, 'Successful')

def delete_duplicates(table):
    """Delete destination rows stamped with the current newCutoff LoadDate.

    Opens a JDBC connection to the destination through the Spark JVM, checks
    INFORMATION_SCHEMA.TABLES for the destination table and, when it exists,
    deletes its rows whose LoadDate equals the module-level newCutoff. The
    operation is attempted up to retriesMax times.

    Parameters:
        table: config entry whose "destination" sub-dict gives schema and table.

    Raises:
        Exception: the error of the last attempt when every attempt failed, or
            a generic error when no attempt was made (retriesMax <= 0).
    """
    destDatabase = destination["database"]
    destSchema = table["destination"]["schema"]
    destTable = table["destination"]["table"]
    destDbTable = "{0}.{1}.{2}".format(destDatabase, destSchema, destTable)

    retries = 0
    isSuccessful = False

    while retries < retriesMax:

        retries = retries + 1

        try:

            spark = get_spark_session()

            connection = spark._jvm.java.sql.DriverManager.getConnection(destination["url"])

            # Release the JDBC statement and connection on every path; they
            # were previously left open, leaking one connection per call.
            try:
                statement = connection.createStatement()

                try:
                    result = statement.executeQuery("""
                        SELECT row_count = COUNT(*)
                        FROM INFORMATION_SCHEMA.TABLES
                        WHERE TABLE_CATALOG = '{0}' AND TABLE_SCHEMA = '{1}' AND TABLE_NAME = '{2}'
                    """.format(destDatabase, destSchema, destTable))

                    result.next()

                    # Only delete when the destination table already exists.
                    if result.getInt("row_count") > 0:
                        statement.executeUpdate("DELETE {0} WHERE LoadDate = '{1}'".format(destDbTable, newCutoff))
                finally:
                    statement.close()
            finally:
                connection.close()

            isSuccessful = True
            break

        except Exception:
            # Drop the session so the next attempt starts from a fresh one.
            stop_spark_session()

            if retries >= retriesMax:
                raise

    if not isSuccessful:
        raise Exception("Unable to delete duplicates")
    
def insert_load_history(table, status):
    """Append one row to the destination LoadHistory table.

    The row carries the destination table name, the module-level newCutoff as
    LoadDate, the given status and NULL details. It is produced by reading a
    one-row derived table over JDBC and appending that DataFrame to
    LoadHistory. The operation is attempted up to retriesMax times.

    Parameters:
        table: config entry whose "destination" sub-dict gives the table name.
        status: status text to record (the caller passes 'Successful').

    Raises:
        Exception: the error of the last attempt when every attempt failed, or
            a generic error when no attempt was made (retriesMax <= 0).
    """
    destTable = table["destination"]["table"]
    jdbcDriver = "com.microsoft.sqlserver.jdbc.SQLServerDriver"

    # The first SELECT returns no rows (1 <> 1) and is unioned with the literal
    # row — presumably so the result takes LoadHistory's column names and types.
    # Built once because it does not change between attempts.
    historyRow = """
                        (
                            SELECT
                                TableName,
                                LoadDate,
                                Status,
                                Details
                            FROM
                                LoadHistory
                            WHERE
                                1 <> 1

                            UNION

                            SELECT
                                [TableName] = '@TableName',
                                [LoadDate] = '@LoadDate',
                                [Status] = '@Status',
                                [Details] = NULL
                        ) AS LoadHistory
                    """.replace(
        "@TableName",
        destTable
    ).replace(
        "@LoadDate",
        newCutoff
    ).replace(
        "@Status",
        status
    )

    retries = 0
    isSuccessful = False

    while retries < retriesMax:

        retries = retries + 1

        try:
            spark = get_spark_session()

            # read
            df = (
                spark.read
                    .format("jdbc")
                    .option("url", destination["url"])
                    .option("dbtable", historyRow)
                    .option("driver", jdbcDriver)
                    .load()
            )

            # write
            (
                df.write
                    .format("jdbc")
                    .mode("append")
                    .option("url", destination["url"])
                    .option("dbtable", "LoadHistory")
                    .option("driver", jdbcDriver)
                    .save()
            )

            isSuccessful = True
            break

        except Exception:
            # Drop the session so the next attempt starts from a fresh one.
            stop_spark_session()

            if retries >= retriesMax:
                raise

    if not isSuccessful:
        raise Exception("Unable to insert load history")

def set_lastcutoff_date(table, initialLoadDate):
    """Return the LoadDate of the table's latest successful load, as text.

    Reads the newest 'Successful' LoadHistory row for the destination table
    over JDBC. When no such row exists, initialLoadDate is used instead. The
    read is attempted up to retriesMax times.

    Parameters:
        table: config entry whose "destination" sub-dict gives the table name.
        initialLoadDate: datetime to fall back on when there is no history.

    Returns:
        The cutoff formatted as "%Y-%m-%d %H:%M:%S".

    Raises:
        Exception: the error of the last attempt when every attempt failed, or
            a generic error when no attempt was made (retriesMax <= 0).
    """
    destTable = table["destination"]["table"]
    lastCutoffDate = initialLoadDate

    loadHistoryTable = """
        (
            SELECT TOP 1
                *
            FROM
                LoadHistory
            WHERE
                TableName LIKE '@TableName' AND
                Status = 'Successful'
            ORDER BY
                LoadDate DESC
        ) AS LoadHistory
    """.replace(
        "@TableName",
        destTable
    )

    retries = 0
    isSuccessful = False

    while retries < retriesMax:

        # Count the attempt. Without this the loop never reached retriesMax and
        # a persistent failure was retried forever.
        retries = retries + 1

        try:

            spark = get_spark_session()

            # read
            df = (
                spark.read
                    .format("jdbc")
                    .option("url", destination["url"])
                    .option("dbtable", loadHistoryTable)
                    .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
                    .load()
            )

            # TOP 1 yields at most one row, so a single collect() replaces the
            # earlier count() + collect() pair (two round trips).
            rows = df.collect()

            if len(rows) > 0:
                # assumes LoadDate comes back as a datetime — TODO confirm
                lastCutoffDate = rows[0]["LoadDate"]

            isSuccessful = True
            break

        except Exception:
            # Drop the session so the next attempt starts from a fresh one.
            stop_spark_session()

            if retries >= retriesMax:
                raise

    if not isSuccessful:
        raise Exception("Unable to set lastcutoff date")

    return lastCutoffDate.strftime("%Y-%m-%d %H:%M:%S")

# Load every configured table whose last successful cutoff is older than the
# requested one; tables already loaded up to newCutoff are skipped.
for table in tables:
    lastCutoff = set_lastcutoff_date(table, initialLoadDate)

    # Both cutoffs are text here, so compare them as datetimes.
    if datetime.strptime(newCutoff, "%Y-%m-%d %H:%M:%S") <= datetime.strptime(lastCutoff, "%Y-%m-%d %H:%M:%S"):
        continue

    load_wwi_to_wh(table, lastCutoff)
        