In [0]:
"""
Dec 01, 2020: When sqlrefresh true then newly added 'sqlrefresh_query' parameter can refresh only 7 days prior data from SQL server. 
Note: If sqlrefresh_query = "" or sqlrefresh_query is not provided in inputjson then entire SQL table will be refreshed into ADLS.

Dec 02, 2020: Whether sqlrefresh is true or false, the SQL table is saved into a temporary ADLS table.
overwrite: incoming data directly overwrites the ADLS and SQL data.
append, sqlrefresh false: incoming data is appended to the ADLS and SQL data.
append, sqlrefresh true: SQL data is merged into ADLS, and then incoming data is merged into ADLS and SQL.
upsert: whether sqlrefresh is true or false, the SQL table data is loaded into a temp table.
  sqlrefresh false: incoming data is merged into ADLS and then merged into the internal temp table (SQL refresh data) named outputDB.outputTable+"_temp". The temp table extract is loaded into SQL.
  sqlrefresh true: incoming data is merged into ADLS and then merged into the temp table (SQL refresh data). The temp table extract is loaded into SQL. After this, the temp table extract is merged into the ADLS table.
  - The temp table is dropped once the load is completed.
Then the final SQL table data is merged into ADLS.
#Dec, 12, 2020: 1. sqlSchema parameter indicates schema name of the table. if not provided or blank then default "dbo" schema will be used. 
                2. truncate parameter is added for truncate in sql table write. It is added incase of overwrite and if truncate option should be 'False'.
*****
"""

In [0]:
##This is to load all required libraries
from functools import reduce
from typing import Dict, List, Any
from datetime import datetime
from pyspark.sql.window import Window
from pyspark.sql.types import ArrayType, StringType, MapType, StructField, StructType, FloatType
from pyspark.sql import *
from delta.tables import *
import pkg_resources
import os, glob 
import pyspark.sql.functions as F
import pyspark.sql.types as T
import importlib
import logging
import pkgutil
import json


In [0]:
%run ../functions/Functions_delta

In [0]:
# `trans_name` is optional: it only exists when something that ran earlier in
# this session defined it (presumably a calling notebook -- TODO confirm).
# Only a missing name is tolerated; any other failure should surface.
try:
  trans_name
except NameError:
  trans_name = ""

# Path of the job-definition JSON, supplied through the "input_file" widget.
input_file = dbutils.widgets.get("input_file")

# Read and parse the job definition.
with open(input_file, 'r') as file:
  data = file.read()
jsonObject = json.loads(data)

# Join runs read the "jointransform" section; every other run reads "transform".
inputs: Dict[str, Any] = jsonObject["jointransform" if trans_name == "jointransform" else "transform"]

# Optional keys are read with explicit defaults instead of bare `except:` blocks,
# so unrelated errors are no longer silently swallowed.

# sort_key selects the latest record per business key (avoids duplicate records).
# A missing, empty, or unsized (e.g. null) value falls back to DateInserted.
sort_key = inputs.get("sort_key")
if not hasattr(sort_key, "__len__") or len(sort_key) == 0:
  sort_key = ['DateInserted']

# Mandatory: a missing "loadstrategy" raises KeyError.
loadstrategy = inputs["loadstrategy"]

# business_key parameter is not mandatory for loadstrategies APPEND and OVERWRITE, unless you want
# to remove duplicates in the same load. For UPSERT, business_key is mandatory.
# NOTE: the JSON key really is spelled "bussiness_keys"; it is part of the input contract.
business_key = inputs.get("bussiness_keys", "")
if business_key == []:
  business_key = ""

# Columns to partition the extract by (empty list = no partitioning).
extractPartition = inputs.get("extractPartition", [])

# Field delimiter for the extract.
extractDelimiter = inputs.get('extractDelimiter', ',')

# If flag copytoTemp is true, then data to be dumped to SQL would be written to a temp location.
# str() lets a JSON boolean `true` count as true as well as the string "true";
# previously a boolean raised inside the try block and was silently read as False.
copytoTemp = str(inputs.get('copytoTemp', 'false')).lower() == 'true'

# Mandatory keys are indexed directly (KeyError when missing); optional keys
# use explicit defaults instead of bare `except:` blocks.
non_align_columns = inputs["non_align_columns"]

inputTable = inputs["inputTable"]
outputDB = inputs["outputDB"]  # database name 'transform' for transform notebook, 'datamart' for join notebook
outputTable = inputs["outputTable"]
outputExtract = inputs["outputExtract"]  # save as table or xml

# Columns dropped before data is written out (see the .drop(remove_keys) calls later).
remove_keys = inputs.get('remove_keys', "")

## This is only applied to Mongo data coming for Neutrinos, should not be used for any other use case
unpivot = inputs.get("unpivot", 'False')

outputFileformat = inputs.get("outputFileformat", "")

adlsPath = inputs["adlsPath"]
adlsfolder = inputs["adlsfolder"]  # folder inside transform container on adls eg. sap/
hour_partition = inputs["hour_partition"]
# A blank date_column is represented by the sentinel "N".
date_column = inputs["date_column"].strip() or "N"

# Date format: defaults when the key is absent or not a string.
_dateFormat_raw = inputs.get("dateFormat")
dateFormat = _dateFormat_raw.strip() if isinstance(_dateFormat_raw, str) else "yyyy-MM-dd"

## This variable is the timestamp format you want to use
tsFormat = inputs.get("tsFormat", "")

xmlpath = "/mnt/"+outputDB+"/" + adlsfolder+outputTable +"_xml"  # path where xml file will be saved

# repartition is compared against the string "True" further down, so a JSON
# boolean is converted to its string form instead of being silently forced to
# "False" (which is what the old bare except did when .strip() failed).
_repartition_raw = inputs.get("repartition")
if isinstance(_repartition_raw, bool):
  repartition = str(_repartition_raw)
elif isinstance(_repartition_raw, str):
  repartition = _repartition_raw.strip() or "False"
else:
  repartition = "False"
 

In [0]:
# ADLS locations used by the load:
#   stage          -> folder of the input table in the "deltastage" container
#   transform_path -> folder of the output table in the outputDB container (transform/datamart)
stage = "".join(["abfss://deltastage@", adlsPath, "/", inputTable])
transform_path = "".join(["abfss://", outputDB, "@", adlsPath, "/", adlsfolder, outputTable])

# One timestamp captured here and reused for the whole run.
_timestamp: datetime = datetime.now()

In [0]:
# If partition_keys are not given in the input json, they are taken from the
# partitions of the deltastage input table (only applicable when a single table
# is used as the source).
partition_keys = inputs["partition_keys"]
if isinstance(partition_keys, list):
  # Work on a copy: the appends below would otherwise mutate the list held inside
  # `inputs`, so re-running this cell would add "hour" (and any discovered
  # partition columns) a second time.
  partition_keys = list(partition_keys)

if len(partition_keys) == 0:
  try:
    targetPartition = spark.sql("show partitions deltastage."+ inputTable).collect()
    # Only the first partition row is inspected, e.g. "year=2020/month=12".
    targetPartitions = targetPartition[0]["partition"]
    partitions1 = targetPartitions.split("/")
    for column in partitions1:
      partition_keys.append(column.split("=")[0])
  except Exception:
    # Best effort: the table may not exist, may be unpartitioned, or may have no partitions yet.
    print("No partitions")

if hour_partition == "True":
  partition_keys.append("hour")
print(partition_keys)

In [0]:
if outputExtract.lower() == 'sqltable':
  # For a SQL extract both formats are read straight from the input json
  # (a missing key raises KeyError here, unlike the defaults applied earlier).
  tsFormat = inputs["tsFormat"]   #timestamp format
  dateFormat = inputs["dateFormat"]
  sqlconnectionFile = "/dbfs/mnt/deltajobinputs/" + inputs["sqlconnectionFile"]

  # Optional SQL refresh settings.
  sqlrefresh = inputs.get("sqlrefresh", 'False')
  sqlrefresh_query = inputs.get("sqlrefresh_query", "")

  #Dec, 12, 2020: sqlSchema parameter indicates schema name of the table. If not provided or blank then default "dbo" schema will be used.
  sqlSchema = inputs.get("sqlSchema", "dbo")
  if sqlSchema == "":
    sqlSchema = "dbo"

  # Load the SQL Server connection details from the connection file.
  with open(sqlconnectionFile, 'r') as file:
    data = file.read()
  connObject = json.loads(data)

  conn: Dict[str, Dict[str, Any]] = connObject["SQLserver"]

  jdbcHostname = conn["jdbcHostname"]
  jdbcPort = conn["jdbcPort"]
  jdbcDatabase = conn["jdbcDatabase"]
  connectionProperties = conn["connectionProperties"]

  # Create the JDBC URL without passing in the user and password parameters.
  jdbcUrl = f"jdbc:sqlserver://{jdbcHostname}:{jdbcPort};database={jdbcDatabase}"

In [0]:
 # In case of upsert, table should be created first. If table does not exists then it will load same as overwrite. 
# `target` stays "" while the output table can be queried and becomes False
# when it cannot; in that case the load falls back to a plain overwrite.
target = ""
try:
  df_existing = spark.sql("select * from " + outputDB + "." + outputTable)
except:
  target, loadstrategy = False, "overwrite"

In [0]:
if outputExtract.lower() == 'sqltable':
    # Read the SQL Server table once, up front. sql_df is also needed later when
    # the data is loaded back into SQL Server.
    try:
      #Dec 01, 2020: when sqlrefresh is true and a partial refresh is wanted (e.g. only the last 7 days),
      # sqlrefresh_query supplies the query; otherwise the whole table is read.
      if sqlrefresh.lower() == 'true' and sqlrefresh_query != "":
        pushdown_query = "("+sqlrefresh_query+") sqltable"
      else:
        pushdown_query = "(select * from " + sqlSchema +"." + outputTable +") sqltable"
      sql_df = spark.read.jdbc(url=jdbcUrl, table=pushdown_query, properties=connectionProperties)
      sql_schema = sql_df.schema
    except:
      # Best effort: the message is printed and the run continues.
      # NOTE(review): later cells still require sql_df and will fail without it.
      print("SQL table should be created in SQL server")


    if target != False :
      # adls_refresh is the SQL data prepared for the ADLS table: date partition
      # columns are derived and missing non-date partition columns are filled in.
      try:
        adls_refresh = sql_df
      except:
        print("sqlrefresh is True but table does not exists in SQL server")

      adls_refresh = add_datepartition(adls_refresh, date_column, dateFormat) #add_datepartition function will add partition column based on date to dataframe before loading to ADLS as using date_column

      # Partition columns other than the date-derived ones. A filter is used
      # instead of list.remove(), which raised ValueError whenever the table was
      # not partitioned by all of date/month/year.
      datepartitions = ["date","month","year"]
      nondatepartition_keys = [pk for pk in partition_keys if pk not in datepartitions]
      print(nondatepartition_keys)

      # Partition columns that do not exist in the SQL table are filled with the literal 'sql'.
      for nk in nondatepartition_keys:
        if nk not in sql_df.columns:
          adls_refresh = adls_refresh.withColumn(nk, F.lit('sql'))

      # align sql table columns with existing ADLS table
      adls_refresh = align_columns(adls_refresh,df_existing,add_missing=True)
      # Dec 10 2020: change adls_refresh datatypes according to existing tables data types
      adls_refresh = convert_datatype(df_existing.schema,adls_refresh,tableAppend='N',tableName=outputTable,dateFormat=dateFormat, tsFormat=tsFormat,no_stage_table='N')
      adls_refresh = adls_refresh.distinct().drop(remove_keys)

      #Dec 2, 2020: If append and sqlrefresh true then merge SQL data with ADLS. business_key parameter is mandatory for the merge.
      if loadstrategy.lower() == 'append' and sqlrefresh.lower() == 'true':
        if isinstance(business_key, list) and len(business_key) != 0:
          # merge_inputs generates the merge condition and the update/insert column mappings from the business keys
          mergejoin, whenMatchedUpdateset, whenNotMatchedUpdateset = merge_inputs(df_existing,business_key)
          deltaTable = DeltaTable.forPath(spark, transform_path)
          deltaTable.alias("existing").merge(
              adls_refresh.alias("incoming"),
              mergejoin) \
            .whenMatchedUpdate(set = whenMatchedUpdateset) \
            .whenNotMatchedInsert(values = whenNotMatchedUpdateset) \
            .execute()
        else:
          # No business keys: the SQL data is simply appended to the ADLS table.
          if partition_keys != []:
            adls_refresh.write.format("delta").mode("append").partitionBy(partition_keys).option("path", transform_path).saveAsTable(outputDB+"."+ outputTable)
          else:
            adls_refresh.write.format("delta").mode("append").option("path",  transform_path).saveAsTable(outputDB+"."+ outputTable)

    #Dec2, 2020: if upsert then load SQL data into the internal temp delta table outputDB.outputTable_temp
    if loadstrategy.lower() == 'upsert':
      if partition_keys != []:
        adls_refresh.write.format("delta").mode("overwrite").partitionBy(partition_keys).saveAsTable(outputDB+"."+ outputTable + "_temp")
      else:
        adls_refresh.write.format("delta").mode("overwrite").saveAsTable(outputDB+"."+ outputTable + "_temp")
    

In [0]:
# Write mode for the SQL extract: 'append' for the append strategy, 'overwrite'
# for upsert and overwrite. An ADLS overwrite also uses mode overwrite, while an
# ADLS upsert is done with a merge statement.
# truncate is 'False' when appending and 'True' otherwise.
mode, truncate = ('append', 'False') if loadstrategy.lower() == 'append' else ('overwrite', 'True')

In [0]:
# Creates temporary views for sources like CSV, SQL table, query, dataframe and ADLS table.
# These views can then be referenced by name in the transformation query.
sources: Dict[str, Dict[str, Any]] = inputs["sources"]
if len(sources) != 0:
  # extract() returns one dataframe per source name.
  # The annotation uses DataFrame (brought in by `from pyspark.sql import *`)
  # rather than `pyspark.sql.DataFrame`: this notebook never imports a bare
  # `pyspark` name, and module-level annotations are evaluated at run time.
  input_s: Dict[str, DataFrame] = extract(sources)
  for key, source_df in input_s.items():
    # createOrReplaceTempView replaces the deprecated registerTempTable.
    source_df.createOrReplaceTempView(key)

In [0]:
# The "query" input parameter is the SQL statement that produces the incoming data.
_incoming_query = inputs["query"]
df_incoming = spark.sql(_incoming_query).alias("df_incoming")

# Schema of the incoming data as returned by the query.
incomingschema = df_incoming.schema

In [0]:
# unpivot_fields unpivots the incoming data. It takes 3 input parameters:
#   1. Dataframe which will be unpivoted
#   2. Column name for unpivoted column names
#   3. Column name for unpivoted column values
# This is for Mongodb-Form_data.

if unpivot == "True":
    df_incoming = unpivot_fields(df_incoming, inputs["unpivot_columns"][0], inputs["unpivot_columns"][1])
    # Expose the unpivoted data under the output table's name.
    # createOrReplaceTempView replaces the deprecated registerTempTable.
    df_incoming.createOrReplaceTempView(outputTable)
    incomingschema = df_incoming.schema

In [0]:
# add_datepartition adds the date-based partition columns to the dataframe before
# it is loaded to ADLS; if date_column is current_date then partitions are based
# on the current date.
df_incoming = add_datepartition(df_incoming, date_column, dateFormat)

# Optional "hour" partition: taken from the run timestamp when partitioning by
# current_date, otherwise from the date column itself.
if hour_partition == "True":
  df_incoming = df_incoming.withColumn(
      "hour",
      F.hour(F.lit(_timestamp) if date_column == "current_date" else F.col(date_column)))

In [0]:
# Adds DateInserted and DateUpdated to the incoming data, keeps only the latest
# row per business key, and converts datatypes to match the existing table.
df_incoming = (df_incoming
               .withColumn("DateInserted", F.lit(_timestamp))
               .withColumn("DateUpdated", F.lit(_timestamp)))

# De-duplication is best effort: on failure the data is kept as-is, but the
# underlying error is now reported instead of being hidden.
try:
  df_incoming = get_latest_row(df_incoming, business_key, sort_key)
except Exception as exc:
  print("business_key and sort_key are required for removing duplicates")
  print("get_latest_row failed: {}".format(exc))

try:
  df_incoming = convert_datatype(df_existing.schema,df_incoming,tableAppend='N',tableName=outputTable,dateFormat=dateFormat, tsFormat=tsFormat,no_stage_table='N')
except NameError:
  # df_existing is not defined when the target table could not be read.
  target = False
except Exception as exc:
  # Any other conversion failure keeps the previous behaviour (target = False)
  # but is no longer silent.
  print("convert_datatype failed, continuing without datatype conversion: {}".format(exc))
  target = False


In [0]:
# merge_inputs generates the pieces needed for a Delta merge:
#   1. mergejoin: merge join condition based on the business key.
#   2. whenMatchedUpdateset: when business key columns match, existing values are
#      replaced with incoming values. DateInserted remains the same as existing.
#   3. whenNotMatchedUpdateset: when business key columns do not match, incoming
#      data is inserted into the existing table.
is_upsert_load = loadstrategy.lower() == 'upsert'
if is_upsert_load:
  mergejoin, whenMatchedUpdateset, whenNotMatchedUpdateset = merge_inputs(df_existing, business_key)

In [0]:
# Upsert straight into the ADLS delta table, unless a SQL extract with
# sqlrefresh 'true' is running (in that case ADLS is updated at the end of the
# notebook from the refreshed temp table).
# The test is "not 'true'" rather than "== 'false'": every other sqlrefresh check
# in this notebook treats any non-'true' value as false, and with the old test a
# value such as "" or "no" skipped both merges, so ADLS never received the
# incoming rows.
if (loadstrategy.lower() == 'upsert') and ((outputExtract.lower() != 'sqltable') or (sqlrefresh.lower() != 'true')):
  deltaTable = DeltaTable.forPath(spark, transform_path)
  deltaTable.alias("existing").merge(
      df_incoming.alias("incoming"),
      mergejoin) \
    .whenMatchedUpdate(set = whenMatchedUpdateset) \
    .whenNotMatchedInsert(values = whenNotMatchedUpdateset) \
    .execute()

In [0]:
#Dec 2, 2020: for a SQL extract, the incoming data is upserted into the temp table
# that was created from the SQL table data.
if loadstrategy.lower() == 'upsert' and outputExtract.lower() == 'sqltable':
    # Location of the <outputTable>_temp managed table under the hive warehouse.
    temp_path = "dbfs:/user/hive/warehouse/{}.db/{}_temp".format(outputDB.lower(), outputTable.lower())

    deltaTable = DeltaTable.forPath(spark, temp_path)
    (deltaTable.alias("existing")
        .merge(df_incoming.alias("incoming"), mergejoin)
        .whenMatchedUpdate(set=whenMatchedUpdateset)
        .whenNotMatchedInsert(values=whenNotMatchedUpdateset)
        .execute())

In [0]:
# Append / overwrite: write the incoming data to the ADLS delta table using
# `mode`, partitioned only when partition keys are configured.
if loadstrategy.lower() != 'upsert':
  delta_writer = df_incoming.write.format("delta").mode(mode)
  if partition_keys != []:
    delta_writer = delta_writer.partitionBy(partition_keys)
  delta_writer.option("path",  transform_path).saveAsTable(outputDB+"."+ outputTable)


In [0]:
if outputExtract.lower() == 'xml':
  #Dec 16,2020: Added default values for rootTag and rowTag for xml load
  # "xmltags" may be absent, blank, or not a mapping; all of those fall back to the defaults.
  xmltags = inputs.get("xmltags")
  if not isinstance(xmltags, dict):
    xmltags = {}
  rootTag = xmltags.get("rootTag") or 'ROWS'
  # Bug fix: rowTag used to be read from the "rootTag" key, so a configured row
  # tag was ignored and the root tag was reused for every row.
  rowTag = xmltags.get("rowTag") or 'ROW'
  # write to xml file.
  dumpxml(df_incoming, outputTable, xmlpath, rootTag, rowTag, mode)
elif outputExtract.lower() == 'csv':
  # CSV_path is the folder location on ADLS where the CSV extract is written.
  CSV_path = "abfss://" + outputDB + "@" + adlsPath + "/" + adlsfolder + outputTable + "_CSV"
  # With repartition == "True" the data is collapsed to one Spark partition before writing.
  csv_source = df_incoming.repartition(1) if repartition == "True" else df_incoming
  csv_source.write.partitionBy(extractPartition).mode("overwrite").format("csv").option("header", True).option("inferschema", True ).option("delimiter", extractDelimiter).option("path", CSV_path).save()

In [0]:
if outputExtract.lower() == 'sqltable':
  # Dec 2, 2020: for append and overwrite the incoming data is what gets loaded to
  # SQL server; for upsert it is the merged temp table.
  if loadstrategy.lower() != 'upsert':
    sqldump_source = df_incoming
  else:
    df_temp_extract = spark.sql("select * from "+ outputDB+"."+ outputTable+ "_temp")
    sqldump_source = df_temp_extract
  final_schema = sqldump_source.schema

  #sql_df is dataframe created using SQL DB data
  sql_schema = sql_df.schema

  # Column names that differ from the SQL table only by upper/lower case are
  # re-created under the SQL table's spelling.
  # Bug fixes: the old loop restarted from the base dataframe on every match, so
  # only the last matching column survived, and df_sqldump stayed undefined when
  # the two schemas were already equal (NameError at align_columns below).
  sql_names_by_lower = {field.name.lower(): field.name for field in sql_schema.fields}
  df_sqldump = sqldump_source
  for fin in final_schema.fields:
    sql_name = sql_names_by_lower.get(fin.name.lower())
    if sql_name is not None and sql_name != fin.name:
      df_sqldump = df_sqldump.withColumn(sql_name, F.col(fin.name))

  # output dataframe will be aligned to sql table, add columns which are missing.
  # Convert data types of the output dataframe to match with SQL table.
  df_sqldump = align_columns(df_sqldump,sql_df,add_missing=True)
  df_sqldump = convert_datatype(sql_schema,df_sqldump,tableAppend='N',tableName=outputTable,dateFormat=dateFormat, tsFormat=tsFormat,no_stage_table='N')
  df_sqldump = df_sqldump.distinct().drop(remove_keys)

  # Dec 12, 2020: Get varchar and nvarchar column names and their maximum lengths.
  schema_query="(select COLUMN_NAME, CHARACTER_MAXIMUM_LENGTH from INFORMATION_SCHEMA.COLUMNS where TABLE_NAME='"+outputTable+"' and TABLE_SCHEMA='"+sqlSchema+"' and DATA_TYPE in ('nvarchar','varchar')) schematable"
  sql_stringSchema = spark.read.jdbc(url=jdbcUrl, table=schema_query, properties=connectionProperties)

  # Dec 12, 2020: trim the string columns for whitespaces and truncate the value if it
  # is greater than the maximum length in the SQL DB table. A length of -1 (what
  # SQL Server reports for varchar(max)/nvarchar(max)) or a null length is left untruncated.
  for sqlcol in sql_stringSchema.collect():
    column_name, max_length = sqlcol[0], sqlcol[1]
    if column_name in df_sqldump.columns and max_length is not None and int(max_length) != -1:
      df_sqldump = df_sqldump.withColumn(column_name, F.trim(F.col(column_name)).substr(1, int(max_length)))

  # get_latest_row keeps the latest record per business key; best effort.
  try:
      df_sqldump = get_latest_row(df_sqldump, business_key, sort_key)
  except Exception:
      print("WARNING: Duplicate records can be in final SQL data")

  # dumpsqlserver loads data to the sql server database. mode is overwrite for upsert and overwrite.
  #Dec, 12, 2020: sqlSchema parameter indicates schema name of the table. if not provided the default "dbo" schema will be used.

  #Jan, 29, 2020: If flag copytoTemp is true, then data to be dumped to SQL would be written to a temp location.
  if copytoTemp:
    temp_path = "abfss://"+outputDB+"@"+adlsPath+ "/"+ adlsfolder+ outputTable + '/Temp'
    spark.sql("drop table if exists "+ outputDB+ "." + "Temp_" + outputTable)
    df_sqldump.write.mode(mode).option("path", temp_path).saveAsTable(outputDB+ "." + "Temp_" + outputTable)
  else:
    dumpsqlserver(df_sqldump,jdbcUrl,sqlSchema +"." +outputTable,mode,truncate,connectionProperties)

  #Dec 2, 2020: if SQL refresh is 'True' and loadstrategy is 'Upsert' then final SQL data will be merged into the ADLS table
  if (target != False)  and ((sqlrefresh.lower() == 'true') and loadstrategy.lower() == 'upsert'):
    df_temp_extract = convert_datatype(df_existing.schema,df_temp_extract,tableAppend='N',tableName=outputTable,dateFormat=dateFormat, tsFormat=tsFormat,no_stage_table='N')
    #merge condition will be based on df_existing schema
    mergejoin, whenMatchedUpdateset, whenNotMatchedUpdateset = merge_inputs(df_existing,business_key)
    deltaTable = DeltaTable.forPath(spark, transform_path)
    deltaTable.alias("existing").merge(
        df_temp_extract.alias("incoming"),
        mergejoin) \
      .whenMatchedUpdate(set = whenMatchedUpdateset) \
      .whenNotMatchedInsert(values = whenNotMatchedUpdateset) \
      .execute()

  if loadstrategy.lower() == 'upsert':
    # drop temp table created from SQL table data.
    spark.sql("drop table "+ outputDB+"."+ outputTable+ "_temp")
    