In [0]:

# Template for the configuration consumed by ParallelJDBCLoad(). The file named
# by the "configuration_file" widget is parsed with ast.literal_eval, i.e. as a
# Python literal, so "#" comments like the ones below are allowed in it.
# NOTE(review): credentials are stored here in plain text; prefer a secret
# scope (e.g. dbutils.secrets) for real deployments.
cfg = {
    "config": {
        # ALL FIELDS ARE MANDATORY IN THIS DICTIONARY
        "jdbcHostname": "10.10.9.4",        # source server address
        "jdbcPort": 1433,                   # source server port
        "jdbcSourceDatabase": "SourceDB",   # source database
        "jdbcUsername": "USerName",         # source server user
        "jdbcPassword": "password",         # source server password
        # JDBC driver class. ParallelJDBCLoad builds a "jdbc:sqlserver://" URL,
        # so only the SQL Server driver matches it as written. Alternatives:
        #   "oracle.jdbc.driver.OracleDriver"   -> Oracle
        #   "com.ibm.db2.jcc.DB2Driver"         -> DB2
        #   "org.mariadb.jdbc.Driver"           -> MariaDB/MySQL
        #   "org.postgresql.Driver"             -> PostgreSQL
        "jdbcDriver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
        "jdbcSourceSchema": "triyam",       # source database schema
        # "targetDatabase": "koantek_parallel_load",  # target database (Unity Catalog setups)
        "targetSchema": "koantek_triyam_test",  # target schema for the Delta tables
        "parallelCores": 8,                 # vCPUs available; used as the JDBC numPartitions
    },
    "load": [
        {
            # ONLY "table" AND "keyField" ARE MANDATORY IN THIS DICTIONARY
            "table": "order_entry_fields",  # source table name
            "keyField": "OE_FIELD_ID",      # split column for the parallel read (int, bigint or float)
        },
        {
            "table": "oe_format_fields",
            "keyField": "OE_FORMAT_ID",
            # OPTIONAL FIELDS - each one can increase load time significantly.
            # "cluster" rewrites the table after the parallel import and, when
            # present, disables both "partition" and "zorder".
            "partition": ["ACCEPT_FLAG"],   # partitioning column(s)
            "zorder": ["FIELD_SEQ"],        # Z-Order column(s)
            "cluster": ["EPILOG_METHOD"],   # liquid clustering column(s)
        },
    ],
}
      

In [0]:

def _read_key_bounds(jdbcUrl, source_table, keyField, props):
    """Return one Row with min_value / max_value / row_count for ``keyField``.

    The aggregate is collected exactly once so the source database is queried
    a single time per table (previously it was collected three times).
    """
    return (
        spark.read.jdbc(url=jdbcUrl, table=source_table, properties=props)
        .select(
            F.min(F.col(keyField)).alias("min_value"),
            F.max(F.col(keyField)).alias("max_value"),
            F.count("*").alias("row_count"),
        )
        .collect()[0]
    )


def _write_with_retry(dft, target_table, partition, cluster, tries=3, delay_seconds=15):
    """Overwrite ``target_table`` with ``dft`` as a Delta table, retrying on failure.

    Partitioning is only applied when no liquid-clustering columns are given,
    because clustering is applied afterwards and replaces partitioning.

    Returns:
        datetime: start time of the attempt that succeeded.

    Raises:
        Exception: the last write error, re-raised after ``tries`` failures.
    """
    import time
    from datetime import datetime

    for attempt in range(tries):
        start_dt = datetime.now()
        try:
            writer = (
                dft.write.format("delta")
                .mode("overwrite")
                .option("overwriteSchema", "true")
            )
            if partition and cluster is None:
                print(str(start_dt) + f"> Partitioning {target_table} by {partition}")
                writer = writer.partitionBy(partition)
            writer.saveAsTable(target_table)
            return start_dt
        except Exception as e:
            if attempt < tries - 1:
                print(str(datetime.now()) + f"> Exception attempting to sync for {target_table} - retrying in {delay_seconds} seconds. Exception: {e}")
                time.sleep(delay_seconds)
            else:
                print(str(datetime.now()) + f"> MAX number of retries reached, Exception: {e}")
                raise


def _apply_clustering(target_table, cluster):
    """Rewrite ``target_table`` with liquid clustering on the ``cluster`` column list."""
    tmp_table = f"{target_table}_temp"
    # A temp table left behind by an earlier failed run would make CREATE fail;
    # it is stale anyway because target_table was just reloaded from source.
    spark.sql(f"DROP TABLE IF EXISTS {tmp_table}")
    spark.sql(f"CREATE TABLE {tmp_table} CLUSTER BY ({cluster}) AS SELECT * FROM {target_table}")
    spark.sql(f"DROP TABLE {target_table}")
    spark.sql(f"ALTER TABLE {tmp_table} RENAME TO {target_table}")


def ParallelJDBCLoad():
    """Copy every table listed in the configuration file from SQL Server into Delta.

    The configuration path comes from the ``configuration_file`` widget. The
    file is parsed with ``ast.literal_eval``, which evaluates literals only, so
    it is a Python dict literal (comments allowed) rather than strict JSON.

    For each entry in ``config["load"]``:
      1. Query min/max of ``keyField`` and the row count from the source table.
      2. Read the table over JDBC split into ``parallelCores`` partitions on
         ``keyField``. An empty table, or an all-NULL key, is read unsplit.
      3. Overwrite ``{targetSchema}.{table}`` as Delta, retrying up to 3 times.
         Partitioning is applied only when no clustering is requested.
      4. Optionally rewrite the table with liquid clustering, or else Z-Order it.
      5. Append a metrics row to ``{targetSchema}.load_metrics``.

    Relies on notebook globals ``spark``, ``dbutils``, ``F`` and ``Row``. These
    are presumably ``pyspark.sql.functions`` and ``pyspark.sql.Row``; confirm
    against the notebook setup.
    """
    import ast
    from datetime import datetime

    cfg_path = dbutils.widgets.get("configuration_file")  # path to the config file
    with open(cfg_path, "r") as openfile:
        json_cfg = ast.literal_eval(openfile.read())

    dbcfg = json_cfg["config"]
    jdbcSchema = dbcfg["jdbcSourceSchema"]
    targetSchema = dbcfg["targetSchema"]
    numPartitions = dbcfg["parallelCores"]
    # Only a SQL Server URL is built here, whatever driver the config names.
    jdbcUrl = (
        f"jdbc:sqlserver://{dbcfg['jdbcHostname']}:{dbcfg['jdbcPort']};"
        f"databaseName={dbcfg['jdbcSourceDatabase']};TrustServerCertificate=True"
    )
    # Credentials and driver go through JDBC connection properties instead of
    # being spliced into the URL string; this also honours "jdbcDriver".
    props = {
        "user": dbcfg["jdbcUsername"],
        "password": dbcfg["jdbcPassword"],
        "driver": dbcfg["jdbcDriver"],
    }

    for load_spec in json_cfg["load"]:
        tbl = load_spec["table"]
        keyField = load_spec["keyField"]
        partition = load_spec.get("partition") or None
        z_order = ",".join(load_spec["zorder"]) if load_spec.get("zorder") else None
        cluster = ",".join(load_spec["cluster"]) if load_spec.get("cluster") else None
        source_table = f"{jdbcSchema}.{tbl}"
        target_table = f"{targetSchema}.{tbl}"
        print(source_table)

        bounds = _read_key_bounds(jdbcUrl, source_table, keyField, props)
        row_count = int(bounds["row_count"])
        print(str(datetime.now()) + "> Starting to import " + tbl)
        if bounds["min_value"] is None:
            # Nothing to split on (empty table or all-NULL key): read in one pass.
            dft = spark.read.jdbc(url=jdbcUrl, table=source_table, properties=props)
        else:
            dft = spark.read.jdbc(
                url=jdbcUrl,
                table=source_table,
                column=keyField,
                lowerBound=int(bounds["min_value"]),
                upperBound=int(bounds["max_value"]),
                numPartitions=numPartitions,
                properties=props,
            )

        start_dt = _write_with_retry(dft, target_table, partition, cluster)

        if cluster:
            print(str(datetime.now()) + f"> Clustering {target_table} by {cluster}")
            _apply_clustering(target_table, cluster)
        if z_order and cluster is None:
            print(str(datetime.now()) + f"> Optimizing {target_table} by {z_order}")
            spark.sql(f"OPTIMIZE {target_table} ZORDER BY ({z_order})")

        end_dt = datetime.now()
        end_rowcount = spark.read.table(target_table).count()
        print(str(end_dt) + f"> Import completed for {target_table}")

        # Use datetime arithmetic directly. Round-tripping str(datetime) through
        # strptime("...%S.%f") fails whenever microsecond == 0.
        elapsed = (end_dt - start_dt).total_seconds()
        # NOTE(review): this is rows per *second*, but the target column is
        # commented as rows_per_minute; confirm which one is intended.
        rate = end_rowcount / elapsed if elapsed > 0 else 0.0

        new_row = Row(
            target_table,   # table_name      STRING
            str(start_dt),  # load_begin      TIMESTAMP
            str(end_dt),    # load_end        TIMESTAMP
            row_count,      # begin_rowcount  INT
            end_rowcount,   # end_rowcount    INT
            rate,           # rows_per_minute FLOAT (see note above)
        )
        # insertInto is positional, so the field order above must match the table.
        spark.createDataFrame([new_row]).write.insertInto(
            f"{targetSchema}.load_metrics", overwrite=False
        )
      