In [0]:
import pyspark.sql
from pyspark.sql.functions import udf
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import col, when, lit, coalesce,count 
from pyspark.sql.types import ArrayType, StringType, MapType, StructField, StructType, FloatType
from pyspark.sql.window import Window
import sys
from functools import reduce
from typing import Dict, List, Any
from datetime import datetime
import pkg_resources
import os, glob
import importlib
import logging
import pkgutil
import re

In [0]:
# Helper exposed to Spark SQL: appends a time value (presumably taken from a measure value) to a report date.
def reportdate_shifttime( reportdate, time):
  """Return ``"<reportdate> <time>"`` built from ``str()`` of each argument.

  No parsing is done: both values are stringified as-is (None becomes "None").
  """
  parts = (str(reportdate), str(time))
  return " ".join(parts)
# Expose reportdate_shifttime to Spark SQL; no returnType is given, so Spark's default (StringType) applies.
spark.udf.register(name="reportdate_shifttime", f=reportdate_shifttime)

In [0]:
# Look up a job parameter in the notebook-level `inputs` mapping, falling back to a default.
def default_inputs(jsonval, defaultval):
  """Return ``inputs[jsonval]``, or `defaultval` when that lookup is not possible.

  `inputs` is a notebook-level global (presumably the parsed JSON job config; it
  is not defined in this section). A missing key, an undefined `inputs`, or an
  `inputs` that cannot be indexed by `jsonval` all yield `defaultval`.

  Only lookup errors are caught. The previous bare ``except:`` also swallowed
  KeyboardInterrupt/SystemExit and any other unexpected failure.

  :param jsonval: key to look up in `inputs`
  :param defaultval: value returned when the key cannot be read
  :return: the configured value or `defaultval`
  """
  try:
    return inputs[jsonval]
  except (KeyError, IndexError, TypeError, NameError):
    return defaultval

In [0]:
# Stage: remove every occurrence of an unwanted character (or substring) from all string columns.
def replaceChar(character, df: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
  """Return `df` with every occurrence of `character` deleted from its StringType columns.

  Uses Spark's built-in regexp_replace with the search text passed through
  re.escape, so it is matched literally, including when it is several characters
  long or a regex metacharacter. This replaces a Python UDF, which shipped every
  string value through a Python worker. NULL values stay NULL, and non-string
  columns are untouched.

  :param character: text to remove (typically a single character)
  :param df: input DataFrame
  :return: DataFrame with the same columns
  """
  literal_pattern = re.escape(character)
  for field in df.schema.fields:
    if isinstance(field.dataType, StringType):
      # Backtick-quote so names containing dots/spaces refer to the column itself.
      column_ref = F.col("`" + field.name.replace("`", "``") + "`")
      df = df.withColumn(field.name, F.regexp_replace(column_ref, literal_pattern, ''))
  return df

In [0]:
def date_func(date):
  """Build a Hive-style partition path (with trailing '/') for region 'na' from a date string.

  Characters 0-3 are used as the year, 5-6 as the month and 8-9 as the day,
  e.g. "2020-10-22" -> "Region=na/year=2020/month=10/date=22/".
  """
  segments = [
      "Region=na",
      "year=" + date[0:4],
      "month=" + date[5:7],
      "date=" + date[8:10],
  ]
  return "/".join(segments) + "/"

In [0]:
def align_and_union(df1: pyspark.sql.DataFrame, df2: pyspark.sql.DataFrame,non_align_columns:list=None) -> pyspark.sql.DataFrame:
    """Union two DataFrames by column name rather than by position.

    DataFrame.union matches columns positionally, so each input is first projected
    onto its column names sorted alphabetically, followed by `non_align_columns`
    appended unsorted, in the given order.

    :param df1: first DataFrame
    :param df2: second DataFrame; must end up with the same column set as df1
    :param non_align_columns: columns kept out of the sort and placed last. Defaults
        to none. The parameter used to be required, which made two-argument calls
        raise TypeError.
    :return: df1.union(df2) on the aligned projections (duplicates are kept)
    """
    tail = list(non_align_columns) if non_align_columns is not None else []

    def _ordered(df):
        # Sorted remaining names first, then the fixed tail columns.
        return sorted(set(df.columns) - set(tail)) + tail

    return df1.select(_ordered(df1)).union(df2.select(_ordered(df2)))

In [0]:
def convert_to_list(list):
  """Split a comma-separated string into items after deleting every space.

  Note: the parameter shadows the built-in ``list`` inside this function.
  Example: "a, b ,c" -> ["a", "b", "c"]; "" -> [""].
  """
  without_spaces = list.replace(" ", "")
  return without_spaces.split(",")

In [0]:
def create_table_dataframe(table) -> pyspark.sql.DataFrame:
  """Load every row of a metastore table/view and alias it by its unqualified name.

  :param table: table name, optionally qualified ("db.table")
  :return: DataFrame aliased as the last dot-separated part of `table`

  The alias is now taken from the last name part. The previous ``split(".")[1]``
  raised IndexError for unqualified names and picked the database, not the table,
  for three-part names.
  """
  table_alias = table.split(".")[-1]
  return spark.sql("select * from " + table).alias(table_alias)

# def create_sql_dataframe(table) -> pyspark.sql.DataFrame:
#   tabledf = spark.read.jdbc(url=jdbcUrl, table=table, properties=connectionProperties)
#   return tabledf

def create_csv_dataframe(file_loc,delimiter ) -> pyspark.sql.DataFrame:
  """Read delimited text at `file_loc` into a DataFrame.

  The first line is used as the header, and column types are inferred by Spark.

  :param file_loc: path (or glob) understood by spark.read
  :param delimiter: field separator, e.g. "," or "|"
  :return: the loaded DataFrame
  """
  csv_options = {"header": True, "inferSchema": True, "delimiter": delimiter}
  reader = spark.read.format("csv").options(**csv_options)
  return reader.load(file_loc)


In [0]:
def create_sql_dataframe(table, source_sql_conn) -> pyspark.sql.DataFrame:
  """Read a SQL Server table over JDBC using a JSON connection file.

  The file at /dbfs/mnt/deltajobinputs/<source_sql_conn> must contain a
  "SQLserver" object with jdbcHostname, jdbcPort, jdbcDatabase and
  connectionProperties. connectionProperties is passed straight to spark.read.jdbc,
  so credentials stay out of the URL.

  :param table: table name (or anything spark.read.jdbc accepts as `table`)
  :param source_sql_conn: connection file name relative to the mount
  :return: DataFrame backed by the JDBC table
  """
  import json  # not among this notebook's top-level imports; json.loads raised NameError before

  with open("/dbfs/mnt/deltajobinputs/" + source_sql_conn, "r", encoding="utf-8") as conn_file:
    connObject = json.load(conn_file)

  conn: Dict[str, Any] = connObject["SQLserver"]

  jdbcHostname = conn["jdbcHostname"]
  jdbcPort = conn["jdbcPort"]
  jdbcDatabase = conn["jdbcDatabase"]
  connectionProperties = conn["connectionProperties"]

  # Create the JDBC URL without passing in the user and password parameters.
  jdbcUrl = "jdbc:sqlserver://{0}:{1};database={2}".format(jdbcHostname, jdbcPort, jdbcDatabase)
  return spark.read.jdbc(url=jdbcUrl, table=table, properties=connectionProperties)

In [0]:
def dumpsqlserver(df,url,table,mode, truncate,properties):
  """Write `df` to a JDBC table.

  :param df: DataFrame to write
  :param url: JDBC URL
  :param table: target table name
  :param mode: save mode (e.g. "append", "overwrite"); set on the writer and also passed to jdbc()
  :param truncate: value of the JDBC `truncate` option (with overwrite, Spark truncates
      the existing table instead of dropping and recreating it)
  :param properties: JDBC connection properties (user, password, driver, ...)
  """
  jdbc_writer = df.write.mode(mode)
  jdbc_writer = jdbc_writer.option("truncate", truncate)
  jdbc_writer.jdbc(url=url, table=table, mode=mode, properties=properties)

In [0]:
def extract(sources: Dict[str, Any]) -> Dict[str, pyspark.sql.DataFrame]:
  """Load every configured source into a DataFrame aliased by its key.

  Each value of `sources` is a dict with "type" and "source":
    table     -> create_table_dataframe(source)
    query     -> spark.sql(source)
    sqltable  -> create_sql_dataframe(source, sqlconnection); "sqlconnection" falls
                 back to the notebook global `sqlconnectionFile`
    csv       -> create_csv_dataframe(source, properties["delimiter"])
    dataframe -> source is already a DataFrame

  :param sources: alias -> source properties
  :return: alias -> DataFrame (each aliased with its key)
  :raises ValueError: for an unsupported "type". Previously such an entry silently
      reused the previous source's DataFrame, or raised UnboundLocalError if it
      came first.
  """
  inputs: Dict[str, pyspark.sql.DataFrame] = {}
  for alias, properties in sources.items():
    source_type = properties["type"]
    if source_type == "table":
      df_input = create_table_dataframe(properties["source"])
    elif source_type == "query":
      df_input = spark.sql(properties["source"])
    elif source_type == "sqltable":
      # The global default is only read when the source names no connection of its own.
      if "sqlconnection" in properties:
        sqlconnection = properties["sqlconnection"]
      else:
        sqlconnection = sqlconnectionFile
      df_input = create_sql_dataframe(properties["source"], sqlconnection)
    elif source_type == "csv":
      df_input = create_csv_dataframe(properties["source"], properties["delimiter"])
    elif source_type == "dataframe":
      df_input = properties["source"]
    else:
      raise ValueError("extract(): unsupported source type %r for alias %r" % (source_type, alias))
    inputs[alias] = df_input.alias(alias)
  return inputs

In [0]:
def transform( df_joined: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Project the joined DataFrame onto the target columns.

    Reads the notebook global `target_mappings`. Each mapping's "source" must be a
    Column, and it is renamed to mapping.get("target").

    :param df_joined: output of join()
    :return: DataFrame with one column per mapping, in mapping order
    """
    projected_columns = []
    for mapping in target_mappings:
        projected_columns.append(mapping["source"].alias(mapping.get("target")))
    return df_joined.select(projected_columns)


In [0]:
def input(sources, joins, target_mappings):
  """Run extract -> join -> distinct -> transform and return the target DataFrame.

  :param sources: alias -> source properties, passed to extract()
  :param joins: NOTE(review): not used here. join() reads the notebook-level
      `joins` global instead; confirm callers pass the same object.
  :param target_mappings: NOTE(review): likewise unused. transform() reads the
      global `target_mappings`.
  :return: transformed DataFrame built from the de-duplicated join result

  Note: this name shadows the built-in input() at notebook scope.
  """
  source_frames = extract(sources)
  return transform(join(source_frames).distinct())

In [0]:
def join(inputs: Dict[str, pyspark.sql.DataFrame]) -> pyspark.sql.DataFrame:
  """Chain the joins described by the notebook-level `joins` config.

  `joins` is a global list, not a parameter. joins[0]["source"] names the base
  DataFrame. Each later entry joins inputs[entry["source"]] using
  entry.get("conditions") and entry.get("type", "inner").

  :param inputs: alias -> DataFrame as returned by extract() (the annotation used
      to say DataFrameReader, which is not what is passed or used)
  :return: the joined DataFrame
  :raises ValueError: if `joins` is empty (previously an opaque IndexError)
  """
  if not joins:
    raise ValueError("join(): the `joins` configuration is empty; no base source to start from")
  df_joined: pyspark.sql.DataFrame = inputs[joins[0]["source"]]
  for join_op in joins[1:]:
    df_joined = df_joined.join(
        inputs[join_op["source"]],
        join_op.get("conditions"),
        how=join_op.get("type", "inner"),
    )
  return df_joined

In [0]:
# Updated 10/22/2020: source columns are referenced with backticks so that names containing spaces resolve.
def convert_datatype(schema,inputdf,tableAppend,tableName,dateFormat,tsFormat,no_stage_table='N')-> pyspark.sql.DataFrame:
    """Cast and rename the columns of `inputdf` to match the fields of `schema`.

    For each field of `schema` a source column (oldColumnName) is chosen. It is
    converted to the field's type under the field's name, and the result is
    projected to exactly the converted field names, in schema order.

    Source column selection:
      * tableAppend == 'Y' and no_stage_table == 'N': "<tableName>~<field>",
        falling back to "<field>".
      * no_stage_table == 'Y': the field name itself. The "<tableName>~" prefix
        is stripped from the *output* name by renaming the StructField in place,
        so the caller's `schema` object is modified.
      * otherwise: the field name.
      When no_stage_table == 'N', fields with no source column are skipped.

    Conversions: String/Integer/Long/Double/Float/Decimal are cast. Date uses
    to_date(<col as string>, dateFormat). Timestamp uses to_timestamp(<col>,
    tsFormat). Struct and Array fields go through the explode branches (see notes).

    :param schema: target StructType (existing stage table schema or incoming schema)
    :param inputdf: DataFrame to convert
    :param tableAppend: 'Y' when incoming columns are named "<tableName>~<column>"
    :param tableName: prefix used to build/strip "<tableName>~" names
    :param dateFormat: pattern passed to to_date
    :param tsFormat: pattern passed to to_timestamp
    :param no_stage_table: 'Y' when `schema` was taken from `inputdf` itself
    :return: DataFrame with one column per converted schema field

    NOTE(review): a field whose type is not handled below (e.g. BooleanType,
    ShortType, MapType) never gets a column named column.name. When
    column.name != oldColumnName the source is then dropped, and the final
    select would fail to resolve it. Confirm.
    NOTE(review): pyspark's ArrayType exposes `elementType`, not `fields`, so the
    ArrayType branch appears to raise AttributeError for any array column. Confirm.
    """
    from pyspark.sql.types import IntegerType,DateType,TimestampType,DoubleType,LongType
    # Names present in the incoming frame before any conversion.
    inputColumns=set(inputdf.columns)
    # Backtick-quoted output names, in schema order, for the final select.
    outputColumns=[]
    
    for column in schema.fields:
      # Choose the source column name for this schema field.
      if (tableAppend=='Y') & (no_stage_table=='N'):
        oldColumnName=tableName+'~'+column.name
      elif no_stage_table=='Y':
        oldColumnName=column.name
        # Mutates the StructField inside `schema` (see docstring).
        column.name=column.name.replace(tableName+'~','')
      else:
        oldColumnName=column.name
      
      # With an existing stage schema: fall back to the unprefixed name, and skip
      # fields with no source column at all (dropping an absent name is a no-op).
      if (no_stage_table=='N') & (oldColumnName not in inputColumns):
          oldColumnName=column.name
          if oldColumnName not in inputColumns:
              inputdf=inputdf.drop(oldColumnName)
              continue
      if isinstance(column.dataType, StringType) :
         inputdf=inputdf.withColumn(column.name, F.col("`"+oldColumnName+"`").cast("string"))
       
      if isinstance(column.dataType, IntegerType) :
         inputdf=inputdf.withColumn(column.name, F.col("`"+oldColumnName+"`").cast("integer"))
          
      if isinstance(column.dataType, LongType) :
         inputdf=inputdf.withColumn(column.name, F.col("`"+oldColumnName+"`").cast(T.LongType()))
         
      # Dates: the source is cast to string first, then parsed with dateFormat.
      if isinstance(column.dataType, DateType):
         inputdf =inputdf.withColumn(column.name, F.to_date(F.col("`"+oldColumnName+"`").cast("string"),dateFormat))

      # Timestamps: parsed directly with tsFormat (no string cast).
      if isinstance(column.dataType, TimestampType):
         inputdf =inputdf.withColumn(column.name, F.to_timestamp(F.col("`"+oldColumnName+"`"),tsFormat))
  
      if isinstance(column.dataType, DoubleType):
         inputdf =inputdf.withColumn(column.name, F.col("`"+oldColumnName+"`").cast("double"))

      if isinstance(column.dataType, FloatType):
         inputdf =inputdf.withColumn(column.name, F.col("`"+oldColumnName+"`").cast("float"))
      
      # If the stage column type is decimal with (precision, scale), e.g. decimal(10,2):
      # relies on str(DecimalType(p,s)) being "DecimalType(p,s)", turned into "Decimal(p,s)" for astype.
      if str(column.dataType).startswith('DecimalType'):
         inputdf =inputdf.withColumn(column.name, F.col("`"+oldColumnName+"`").astype(str(column.dataType).replace("Type","")))
          
      # NOTE(review): the backticks make "name.field" a single identifier (a top-level
      # column literally named that), not the nested field. explode(array(x)) yields x,
      # so after the loop the column holds the value of the last field only. Confirm intent.
      if isinstance(column.dataType, StructType):
        for arr in column.dataType.fields:
          inputdf = inputdf.withColumn(column.name, F.explode(F.array(F.col("`"+column.name+"."+arr.name +"`"))))
          
      # NOTE(review): ArrayType has no `.fields`; see docstring.
      if isinstance(column.dataType, ArrayType):
        for arr in column.dataType.fields:
          inputdf = inputdf.withColumn(column.name, F.explode(F.col("`" +column.name+"."+arr.name +"`")))
          
      # The converted value lives under column.name, so a differently named source is dropped.
      if (column.name!=oldColumnName):
              inputdf=inputdf.drop(oldColumnName) 
      
      outputColumns.append("`"+column.name+"`")
    # Keep only the converted schema fields, in schema order.
    inputdf=inputdf.select(outputColumns)
    return inputdf

In [0]:
# Sanitise column names: trim, turn inner spaces into '_', and delete , ; { } ( ) newline tab = .
def getalias(df_incoming):
  """Return `df_incoming` with every column renamed to a sanitised name.

  Each name is stripped of surrounding whitespace. Remaining spaces become '_'
  and the characters , ; { } ( ) newline tab = . are deleted (not replaced).
  Example: " a b " -> "a_b", "c(d)" -> "cd".

  NOTE(review): two columns that normalise to the same name are not de-duplicated.
  """
  deletions = str.maketrans("", "", ",;{}()\n\t=.")
  renamed = df_incoming
  for original_name in df_incoming.columns:
    clean_name = original_name.strip().replace(" ", "_").translate(deletions)
    renamed = renamed.withColumnRenamed(original_name, clean_name)
  return renamed

In [0]:
def add_partition(partition,df) -> tuple:
  """Add one literal string column per "name=value" segment of a partition path.

  :param partition: path such as "Region=na/year=2020/month=10/". Empty segments
      (leading, trailing or doubled '/') are ignored; previously a trailing '/'
      raised IndexError. Only the first '=' separates name from value, so values
      may themselves contain '='.
  :param df: DataFrame to extend
  :return: (df with the partition columns added, list of partition column names in path order)
  :raises ValueError: for a non-empty segment without '='
  """
  incoming_partition = []
  for segment in partition.split("/"):
    if not segment:
      continue
    name, separator, value = segment.partition("=")
    if not separator:
      raise ValueError("add_partition: segment %r is not of the form name=value" % segment)
    incoming_partition.append(name)
    df = df.withColumn(name, F.lit(value))
  return df, incoming_partition

In [0]:
def getpartition(partition):
  """Return the column names of a "name=value/name=value" partition path, in order.

  Empty segments (leading, trailing or doubled '/') are skipped. Previously a
  trailing '/' produced an empty-string partition name.
  Example: "Region=na/year=2020/" -> ["Region", "year"].
  """
  return [segment.split("=")[0] for segment in partition.split("/") if segment]

In [0]:
def align_columns(df_source_raw, df_target_raw, add_missing=False, ignore_columns=None, non_align_columns=None):
    """Project `df_source_raw` onto the columns it shares with `df_target_raw`.

    :param df_source_raw: DataFrame to project
    :param df_target_raw: DataFrame whose column set (and types, for add_missing) is the reference
    :param add_missing: when True, target columns absent from the source (and not in
        ignore_columns) are first added to the source as NULLs cast to the target type
    :param ignore_columns: source columns to keep even if the target lacks them
        (any iterable, e.g. dict keys)
    :param non_align_columns: columns appended at the end, in the given order
    :return: the projected DataFrame

    Changes from the earlier version:
    * shared columns keep the source's column order instead of an arbitrary set order,
      so the output is stable across runs;
    * a non_align column that is also shared is selected once, not twice;
    * list defaults are replaced by None, which avoids shared mutable defaults.
    """
    ignore_columns = list(ignore_columns) if ignore_columns is not None else []
    non_align_columns = list(non_align_columns) if non_align_columns is not None else []

    df_source = df_source_raw
    if add_missing:
        # add NULL columns to align the df
        for column in df_target_raw.columns:
            if column not in df_source.columns and column not in ignore_columns:
                df_source = df_source.withColumn(column, F.lit(None).cast(df_target_raw.schema[column].dataType))

    # return only shared columns (plus ignore_columns), de-duplicated, in source order
    keep = set(df_target_raw.columns) | set(ignore_columns)
    shared_existing_columns = list(dict.fromkeys(
        column for column in df_source.columns
        if column in keep and column not in non_align_columns
    ))
    return df_source.select(*(shared_existing_columns + non_align_columns))

In [0]:
def add_missing_columns(schema, inputdf):
  """Append a NULL column (cast to the schema's type) for each `schema` field missing from `inputdf`.

  Existing columns are left as they are. New columns are appended in schema order.
  Membership is re-checked against the growing frame on each step.

  NOTE(review): the membership test is case-sensitive, while Spark's column
  resolution is case-insensitive by default. A field differing only in case from
  an existing column would presumably replace that column with NULLs; confirm.
  """
  result = inputdf
  for field in schema:
    if field.name in result.columns:
      continue
    result = result.withColumn(field.name, F.lit(None).cast(field.dataType))
  return result

In [0]:
def readcsv(inferSchema, delimiter,quote_char,escape_char,multiline, raw):
  """Read CSV file(s) at `raw` (header row expected) with the requested parsing options.

  :param inferSchema: passed through to Spark's `inferSchema` option
  :param delimiter: field separator
  :param quote_char: quote character, or 'N' to keep Spark's default
  :param escape_char: escape character, or 'N' to keep Spark's default
  :param multiline: "true" (any case, or a truthy bool) to allow records spanning
      lines; any other value leaves multiline parsing off
  :param raw: path (or glob) to load
  :return: the loaded DataFrame

  Each option is applied independently, so every combination of inputs yields a
  reader. The previous if/elif ladder ignored `multiline` whenever no escape
  character was set, and raised UnboundLocalError for multiline values other than
  "true"/"false".
  """
  reader = (spark.read.format("csv")
            .option("header", True)
            .option("inferSchema", inferSchema)
            .option("delimiter", delimiter))
  if quote_char != 'N':
    reader = reader.option("quote", quote_char)
  if str(multiline).lower() == "true":
    reader = reader.option("multiline", True)
  if escape_char != 'N':
    reader = reader.option("escape", escape_char)
  return reader.load(raw)

In [0]:
def file_name(filepath):
  """Return the last '/'-separated component of `filepath` (empty if it ends with '/')."""
  *_, last_component = filepath.split("/")
  return last_component

In [0]:
def pre_stage(df_incoming, tableAppend,tableName, dateFormat, tsFormat):
  """Normalise an incoming DataFrame and overwrite its deltastage table.

  Steps:
  1. sanitise column names (getalias);
  2. strip `replace_char` from string columns unless it is 'N';
  3. add literal partition columns from `out_partition`;
  4. convert types to the existing deltastage.<tableName> schema, or keep the
     incoming schema when that table cannot be read;
  5. add NULL columns missing from that schema;
  6. optionally add a `filename` column;
  7. overwrite the stage table and refresh it.

  Reads notebook globals: replace_char, out_partition, filename, stage.

  :param df_incoming: raw incoming DataFrame
  :param tableAppend: forwarded to convert_datatype
  :param tableName: stage table name; spaces are removed for the table identifier
  :param dateFormat: forwarded to convert_datatype
  :param tsFormat: forwarded to convert_datatype
  :return: the DataFrame that was written
  """
  df_incoming = getalias(df_incoming)
  no_stage_table = 'N'
  # Spaces are removed from table names (mostly Excel sheet names). The existing
  # stage table is now looked up under the same name it is saved as; previously a
  # name with spaces never matched, so every run was treated as a first load.
  stage_table = "deltastage." + tableName.replace(" ", "")

  if replace_char != 'N':
    df_incoming = replaceChar(replace_char, df_incoming)

  if len(out_partition) != 0:
    df_incoming, stg_partition = add_partition(out_partition, df_incoming)
  else:
    stg_partition = []

  try:
    schema = spark.sql("SELECT * FROM " + stage_table).schema
  except Exception:
    # No readable stage table yet: use the incoming schema as-is.
    no_stage_table = 'Y'
    schema = df_incoming.schema

  df_prestage = convert_datatype(schema, df_incoming, tableAppend=tableAppend, tableName=tableName,
                                 dateFormat=dateFormat, tsFormat=tsFormat, no_stage_table=no_stage_table)
  df_prestage = add_missing_columns(schema, df_prestage)

  if filename.lower() == "true":
    get_file_name = udf(file_name, StringType())
    # Source file's base name, with URL-encoded spaces restored.
    df_prestage = df_prestage.withColumn("filename", F.regexp_replace(get_file_name(F.input_file_name()), "%20", " "))

  writer = df_prestage.write.format("delta").mode("overwrite")
  if len(stg_partition) != 0:
    writer = writer.partitionBy(stg_partition)
  writer.option("path", stage).saveAsTable(stage_table)

  spark.sql("refresh table deltastage.`" + tableName.replace(" ", "") + "`")

  return df_prestage

In [0]:
def upsert_dataframe(df_existing,df_insert,df_update,primary_key:Dict[str,Any]):
    """Combine insert rows and primary-key-joined update rows into one DataFrame.

    :param df_existing: existing rows; its columns define the full column set
    :param df_insert: rows to insert; columns missing relative to df_existing are
        added as typed NULLs
    :param df_update: rows carrying updates, matched to df_existing on every
        primary-key column
    :param primary_key: dict whose keys are the primary-key column names (values unused here)
    :return: union of the insert rows and the joined update rows, aligned by sorted column name

    NOTE(review): `missing_cols_update` holds columns of df_existing that df_update
    does NOT have, yet they are selected from the `new_update` side. When that set
    is non-empty the references cannot resolve. When it is empty, only existing
    values are kept, so no updated value reaches the result. The inline comment
    suggests the roles of the two sides are swapped; confirm intent before relying
    on this function.
    """
    # Left-join every existing row to its update (if any) on all primary-key columns.
    key_match = [
        F.col(f"existing_inactive.{key}") == F.col(f"new_update.{key}")
        for key in primary_key.keys()
    ]
    df_update_joined = df_existing.alias("existing_inactive").join(
        df_update.alias("new_update"), key_match, how="leftouter")

    # we want to update, so populate the missing columns with existing data
    missing_cols_update = set(df_existing.columns) - set(df_update.columns)
    df_update_joined = df_update_joined.select(
        "existing_inactive.*", *[F.col(f"new_update.{name}") for name in missing_cols_update])

    # add blank (typed NULL) columns for any columns missing in the insert
    df_insert_joined = df_insert
    for column in set(df_existing.columns) - set(df_insert.columns):
        df_insert_joined = df_insert_joined.withColumn(
            column, F.lit(None).cast(df_existing.schema[column].dataType))

    # add blank columns for anything the insert side has that the update side lacks
    for column in set(df_insert_joined.columns) - set(df_update_joined.columns):
        df_update_joined = df_update_joined.withColumn(
            column, F.lit(None).cast(df_insert_joined.schema[column].dataType))

    # Pass non_align_columns explicitly ([] = sort every column) rather than relying on a default.
    return align_and_union(df_insert_joined, df_update_joined, [])

In [0]:
def get_max_value(df, column, default=0):
    """Return MAX(column cast to integer) over `df`, or `default` when that maximum is NULL.

    The maximum is NULL for an empty frame, for an all-NULL column, or when no
    value casts to integer. A real maximum of 0 is returned as 0.
    """
    aggregated = df.select(F.max(F.col(column).cast("integer")).alias("MAX"))
    result_row = aggregated.limit(1).collect()[0]
    max_value = result_row.MAX
    return default if max_value is None else max_value

In [0]:
def df_zipwithindex(df, offset=1, col_name="rowId"):
    '''
        Number the rows of a DataFrame in their current order, like
        RDD.zipWithIndex(), while keeping a DataFrame schema.

        :param df: source dataframe
        :param offset: added to each 0-based index, so the first row gets `offset`
        :param col_name: name of the new index column (LongType, nullable), placed first
        :return: new DataFrame with `col_name` followed by all of df's columns
    '''
    index_field = StructField(col_name, T.LongType(), True)
    schema_with_index = StructType([index_field] + df.schema.fields)

    def _prepend_index(row_and_index):
        row, index = row_and_index
        return [index + offset] + list(row)

    return df.rdd.zipWithIndex().map(_prepend_index).toDF(schema_with_index)

In [0]:
def fill_auto_increment(
        df_existing: pyspark.sql.DataFrame,
        df_new: pyspark.sql.DataFrame,
        autoincrement_column: str):
    """Fill NULL values of `autoincrement_column` in `df_new` with fresh sequential ids.

    Rows are ordered by `autoincrement_column` and numbered by df_zipwithindex,
    starting at (max id in df_existing) + 1. Only rows whose id is NULL take the
    generated number; rows that already have an id keep it, so the generated
    sequence can have gaps.

    :param df_existing: DataFrame whose current maximum id seeds the sequence
    :param df_new: DataFrame to fill
    :param autoincrement_column: id column name (must exist in df_new)
    :return: df_new with `autoincrement_column` populated
    """
    try:
        max_id = get_max_value(df_existing, autoincrement_column, 0) + 1
    except Exception:
        # Best effort when the existing maximum cannot be read. Narrowed from a
        # bare except so interrupts still propagate.
        # NOTE(review): this path starts ids at 0, while an empty-but-readable
        # existing frame starts them at 1. Confirm which is intended.
        max_id = 0

    df_numbered = df_zipwithindex(df_new.orderBy(autoincrement_column), offset=max_id)
    # Keep an existing id, otherwise take the generated row number.
    return df_numbered.withColumn(
        autoincrement_column,
        F.coalesce(F.col(autoincrement_column), F.col("rowId")),
    ).drop("rowId")

In [0]:
def get_latest_row(df, business_key, sort_key):
  """Keep one row per `business_key` group: the first when ordered by `sort_key` descending.

  :param df: input DataFrame
  :param business_key: column name(s) defining a group
  :param sort_key: iterable of column names, compared in order, highest first
  :return: DataFrame with the same columns as `df`; ties are broken arbitrarily by row_number
  """
  newest_first = [F.desc(key) for key in sort_key]
  latest_window = Window.partitionBy(business_key).orderBy(newest_first)
  ranked = df.select(F.row_number().over(latest_window).alias("row_num"), "*")
  return ranked.where("row_num == 1").drop("row_num")

In [0]:
def find_delta(df_existing, df_incoming,business_key, primary_key,non_align_columns=[]) -> pyspark.sql.DataFrame:
  """Compare incoming rows with existing rows on `business_key` and split them for an upsert.

  Returns a 4-tuple (the DataFrame return annotation is inaccurate):
    new_insert        -- incoming rows with no existing match. When primary_key is
                         non-empty and its columns differ from business_key, every
                         primary-key column is set to NULL cast to its declared type.
    new_update        -- matched incoming rows where at least one compared column
                         differs. Key/non-aligned columns come from `existing`, the
                         others from `incoming`.
    existing_inactive -- the `existing` side of those changed matches
    existing_active   -- aligned existing rows minus existing_inactive (subtract())
  Every frame gets curr_row_flg = 'Y', including existing_inactive (presumably
  flipped by the caller; confirm).

  Business keys match when lower(coalesce(cast(key as string), '')) is equal on
  both sides, i.e. case-insensitively and with NULL equal to ''.

  :param df_existing: current target rows
  :param df_incoming: new rows; reduced to the latest row per business key with
      get_latest_row(..., sort_key) when that call succeeds
  :param business_key: list of business-key column names
  :param primary_key: dict of primary-key column -> type accepted by Column.cast
  :param non_align_columns: columns carried through align_columns unsorted and
      excluded from change detection (mutable default; not mutated here)

  Reads notebook globals: sort_key, partition_keys, hour_partition.
  """
  exist_key=""
  #getting latest record when duplicate found in incoming
  # NOTE(review): bare except. Any failure (e.g. `sort_key` undefined or a bad
  # column) silently skips de-duplication, and only this path adds curr_row_flg to incoming.
  try:
    df_incoming_raw = get_latest_row(df_incoming, business_key, sort_key)
  except:  
    df_incoming_raw = df_incoming.withColumn("curr_row_flg", F.lit('Y'))
  
  
  # Project existing onto incoming's columns (NULL-filled where existing lacks them),
  # keeping primary-key columns even when incoming does not have them.
  df_existing_raw = align_columns(df_existing,df_incoming_raw,ignore_columns=primary_key.keys(),non_align_columns=non_align_columns,add_missing=True)
  
  # Case-insensitive, NULL-as-'' string comparison per business-key column.
  join_condition = [
        F.lower(F.coalesce(df_incoming_raw[business_key_column].cast("string"), F.lit(''))) == F.lower(F.coalesce(df_existing_raw[business_key_column].cast("string"), F.lit('')))
        for business_key_column in business_key]
  # NOTE(review): with an empty business_key, df_merged is never assigned and the
  # next use raises UnboundLocalError.
  if len(join_condition) != 0:
    df_merged =df_incoming_raw.alias('incoming').join(df_existing_raw.alias('existing'), join_condition, 'leftouter').distinct()
  
  
    #New records in incremental
  # Inserts: rows whose existing business-key columns are all NULL, i.e. unmatched.
  # NOTE(review): an existing row with a NULL key still matches on '' and also lands
  # here (it can never be an update), so such rows would be inserted again.
  df_insert = df_merged
  for business_key_col in business_key:
      df_insert = df_insert.filter(col("existing."+business_key_col).isNull())
  new_insert = df_insert.select("incoming.*").withColumn("curr_row_flg", F.lit('Y')).distinct()
  if  len(primary_key) != 0:
    if set(business_key) != set(primary_key.keys()):
    # add empty column for primary key
      for primary_key_col, primary_key_type in primary_key.items():
        new_insert = new_insert.withColumn(
            primary_key_col,
            F.lit(None).cast(primary_key_type)).distinct() 
  
    
  #Incremental Matching records with Initial load
  df_update = df_merged
  # Columns compared for changes = existing's columns minus keys/non-aligned columns,
  # which are taken from `existing` (common_key); the rest come from `incoming` (uncommon_key).
  get_diff = df_existing_raw.columns
  common_key=[]
  uncommon_key=[]
  # NOTE(review): the first name that is not present raises ValueError and ends the
  # loop, so later names are neither removed nor added to common_key. Names shared by
  # business_key and primary_key trigger this on their second removal.
  try:
    for non in business_key+non_align_columns+list(primary_key):
      get_diff.remove(non)
      common_key.append("existing."+non)
  except:
    print(non + " key not present")
  
  for non_key in get_diff:
    uncommon_key.append("incoming."+non_key)
    
  
  # Build "(incoming.`c`!= existing.`c`) or ..." over the compared columns, skipping
  # the global partition_keys.
  # NOTE(review): `!=` is not NULL-safe, so a change to or from NULL is not detected.
  # delt_cols_diff stays unbound if no column qualifies.
  delt_cols=[]
  for i in get_diff:
    if i not in partition_keys:
      delt_cols.append("(incoming.`"+i+ "`!= existing.`"+i+"`)")
      delt_cols_diff = " or ".join(delt_cols)
  
  
  # Updates: matched rows with at least one changed column (the same where() is
  # re-applied once per key column, which is redundant).
  for business_key_col in business_key:
      df_update = df_update.filter(col("existing."+business_key_col).isNotNull())\
      .where(delt_cols_diff) 
  # NOTE(review): exist_key accumulates "existing.a,existing.b" and is appended only
  # from the second primary-key column on, as one comma-joined name that select()
  # presumably cannot resolve. Both branches below build the same select. Confirm intent.
  if  len(primary_key) != 0:    
    if set(business_key) != set(primary_key.keys()):
    # add empty column for primary key
      for primary_key_col, primary_key_type in primary_key.items():
        if exist_key =="":
          exist_key = "existing."+primary_key_col
        else:
          exist_key = exist_key + ","+"existing."+primary_key_col
          uncommon_key.append(exist_key)
    new_update = df_update.select(common_key+uncommon_key).withColumn("curr_row_flg", F.lit('Y')).distinct()
  else:
    new_update = df_update.select(common_key+uncommon_key).withColumn("curr_row_flg", F.lit('Y')).distinct()
          
  #Initial load records which are updated in incremental 
  existing_inactive = df_update.select("existing.*").withColumn("curr_row_flg", F.lit('Y')).distinct()
  
  # distinct() is skipped when the global hour_partition is the string 'True'.
  if hour_partition !='True':
  #Records which don't have any update in incremental
    df_existing_raw = df_existing_raw.withColumn("curr_row_flg", F.lit('Y')).distinct()
  else:
    df_existing_raw = df_existing_raw.withColumn("curr_row_flg", F.lit('Y'))
    
  # Columns are sorted on both sides so subtract() (a positional set difference) lines up by name.
  existing_active = df_existing_raw.select(sorted([colname for colname in df_existing_raw.columns]))\
  .subtract(existing_inactive.select(sorted([colname for colname in existing_inactive.columns]))\
           ).distinct()
  
  return new_insert,new_update,existing_inactive,existing_active

In [0]:
def get_param_value(input_param, index):
  """Return the value of the `index`-th "key=value" pair in a ';'-separated string.

  Only the first '=' splits key from value, so values may contain '='.
  Example: get_param_value("a=1;b=x=y", 1) -> "x=y".
  """
  pair = input_param.split(";")[index]
  value = pair.split("=", 1)[1]
  return value

In [0]:
def dumpxml(df_stage, tableName, path, rootTag, rowTag,mode):
  """Write `df_stage` as XML (spark-xml) to `path`, repartitioned to one partition first.

  Partition columns come from the notebook global `partition_keys`.
  NOTE(review): `tableName` is accepted but not used.
  """
  xml_writer = (df_stage.repartition(1).write
                .mode(mode)
                .format("com.databricks.spark.xml")
                .partitionBy(partition_keys)
                .option("rootTag", rootTag)
                .option("rowTag", rowTag))
  xml_writer.save(path)

In [0]:
def unpivot_fields(df, column0,column1 ):
  """Unpivot (melt) UUID-shaped columns of `df` into (column0, column1) rows.

  All columns are first cast to string so stack() receives a single value type.
  A column is unpivoted when its name contains a match for the loose pattern
  "........-....-....-....-............" (any characters in the 8-4-4-4-12 hyphen
  layout of a UUID). Every other column is kept as an identifier column. Each
  input row yields one row per unpivoted column: `column0` holds the column name
  and `column1` its value.

  Identifiers are backtick-quoted and name labels escaped, so column names with
  spaces, quotes or SQL-significant characters no longer break the generated
  selectExpr.

  :raises ValueError: when no column matches (the generated "stack(0,)" is not valid SQL)
  """
  uuid_like = re.compile("........-....-....-....-............")

  def _quote_ident(name):
    # Backtick-quote an identifier; a literal backtick is escaped by doubling it.
    return "`" + name.replace("`", "``") + "`"

  def _quote_literal(name):
    # Single-quoted SQL string literal with backslashes and quotes escaped.
    return "'" + name.replace("\\", "\\\\").replace("'", "\\'") + "'"

  # One select instead of a withColumn per column; same names and order.
  df = df.select([F.col(_quote_ident(c)).cast("string").alias(c) for c in df.columns])

  id_cols = []
  stack_args = []
  for name in df.columns:
    if uuid_like.search(name):
      stack_args.extend([_quote_literal(name), _quote_ident(name)])
    else:
      id_cols.append(_quote_ident(name))

  pivot_count = len(stack_args) // 2
  if pivot_count == 0:
    raise ValueError("unpivot_fields: no UUID-shaped column names to unpivot")

  id_cols.append("stack(" + str(pivot_count) + "," + ",".join(stack_args) + ")")
  return (df.selectExpr(id_cols)
            .withColumnRenamed("col0", column0)
            .withColumnRenamed("col1", column1))


In [0]:
def Loadtodatabricks(df, mode, partition_keys, transform_path, outputDB, outputTable,outputFileformat ):
  """Save `df` as the external table <outputDB>.<outputTable> at `transform_path`, then refresh it.

  partitionBy(partition_keys) is always applied; an empty list adds no partition
  columns. The format is set explicitly only for "orc" or "delta" (case-insensitive).
  Any other value leaves the writer's default source in place.
  Refreshes through the notebook global `sqlContext`.
  """
  file_format = outputFileformat.lower()
  qualified_table = outputDB + "." + outputTable

  writer = df.write.mode(mode).partitionBy(partition_keys)
  if file_format in ("orc", "delta"):
    writer = writer.format(file_format)
  writer.option("path", transform_path).saveAsTable(qualified_table)

  sqlContext.sql("refresh table " + qualified_table)

In [0]:
def add_datepartition(df, date_column, dateFormat):
  """Add integer year / month / date (day-of-month) partition columns to `df`.

  :param df: input DataFrame
  :param date_column: "current_date" or "DateUpdated" -> derive from the notebook
      global `_timestamp` (same value for every row); 'N' -> return df unchanged;
      anything else -> name of the column to derive them from
  :param dateFormat: when non-empty, `date_column` is cast to string and parsed with
      this pattern via to_date; when empty, the column is used as-is
  :return: df with year, month and date columns added or replaced

  Fixes: the empty-dateFormat path read F.col("date_column"), a helper column
  created only on the other path, so it could never resolve. It now uses the
  requested column. A `df.drop('year','month','date')` whose result was discarded
  (a no-op) is removed; withColumn replaces those columns anyway.
  """
  if date_column == 'N':
    return df

  if date_column in ["current_date", "DateUpdated"]:
    # Every row gets the partition of the job's run timestamp.
    run_ts = F.lit(_timestamp)
    return (df.withColumn("year", F.year(run_ts))
              .withColumn("month", F.month(run_ts))
              .withColumn("date", F.dayofmonth(run_ts)))

  if dateFormat != "":
    partition_date = F.to_date(F.col(date_column).cast("string"), dateFormat)
  else:
    partition_date = F.col(date_column)

  # Go through a temporary helper column so the derivation still reads the original
  # value when date_column is itself named "year", "month" or "date".
  df = df.withColumn("date_column", partition_date)
  df = df.withColumn("year", F.year(F.col("date_column")))
  df = df.withColumn("month", F.month(F.col("date_column")))
  df = df.withColumn("date", F.dayofmonth(F.col("date_column")))
  return df.drop("date_column")

In [0]:
""""this function generates required conditions for merge: 1. mergejoin: merge join condition based on business key 2. whenMatchedUpdateset: when business key columns matched then replace existing data values with incoming data values. DateInserted will remain same as existing. 3. when business key columns do not match then insert incoming data to existing table. """

def merge_inputs(df_existing,business_key ) :
  """Build the condition and column maps for a MERGE of `incoming` into `existing`.

  The returned pieces use the aliases "existing" (target) and "incoming" (source);
  they are presumably consumed by a Delta MERGE.

  :param df_existing: target DataFrame; only its column names are used
  :param business_key: list of column names that identify a record
  :return: (mergejoin, whenMatchedUpdateset, whenNotMatchedUpdateset)
      mergejoin -- "existing.`k1` = incoming.`k1` and ..." over business_key
      whenMatchedUpdateset -- {"existing.`c`": "incoming.`c`"} for every column except
          the business keys and DateInserted, so DateInserted keeps its existing value
      whenNotMatchedUpdateset -- {"`c`": "incoming.`c`"} for every column (full insert)
  :raises ValueError: if a business-key column is not a column of df_existing
  """
  all_columns = list(df_existing.columns)
  missing_keys = [bk for bk in business_key if bk not in all_columns]
  if missing_keys:
    raise ValueError("merge_inputs: business key column(s) not in df_existing: %s" % missing_keys)

  # join() replaces the old "is this the last key?" check, which dropped the " and "
  # separator whenever a key had the same name as the last key.
  mergejoin = " and ".join("existing.`" + bk + "` = incoming.`" + bk + "`" for bk in business_key)

  # DateInserted keeps its existing value on update; tables without it are now tolerated.
  preserved = set(business_key) | {"DateInserted"}
  whenMatchedUpdateset: Dict[str, Any] = {
      "existing.`" + c + "`": "incoming.`" + c + "`" for c in all_columns if c not in preserved
  }

  whenNotMatchedUpdateset: Dict[str, Any] = {"`" + c + "`": "incoming.`" + c + "`" for c in all_columns}

  return mergejoin, whenMatchedUpdateset, whenNotMatchedUpdateset

In [0]:
def report_Measure_calculation(stepdf, calc_level, reports_input):  
  """Evaluate report measures defined in a CSV and register them as temp table "tempreport".

  The CSV at `reports_input` (header row, comma-delimited, numeric NULLs filled with 0)
  supplies the columns referenced below: ID, PlantID, CalcName, Calculation, CalcValue,
  CalcDesc. For each report row, CalcValue is taken from the CSV unless it is 0; in that
  case the value from `stepdf` (left-joined on CalcName) is used.

  Then, `calc_level` times:
  1. the measures are pivoted per PlantID (one column per CalcName, holding avg(CalcValue))
     and joined back;
  2. every row's `Calculation` text is evaluated as a SQL expression over those columns,
     via a generated "CASE WHEN CalcName = '...' THEN coalesce(nvl(<Calculation>, CalcValue), 0) ..."
     expression.

  :param stepdf: DataFrame with CalcName and CalcValue (values computed upstream)
  :param calc_level: number of evaluation passes; must be >= 1
  :param reports_input: path to the report-definition CSV
  :return: True. The result is registered as temp table "tempreport" with PlantID,
      ReportDateID (yyyyMMdd of the global `ReportDate`), CalcCategory '', CalcFrequency
      'Daily', CalcName, CalcValue cast to string, UoM 'Null', CalcDesc, TextInput 'Null'.

  NOTE(review): every pass restarts from df_report_measured_val, so passes after the first
  recompute the same result. Measures whose Calculation references another computed
  measure are presumably meant to be resolved by feeding each pass's output into the next.
  Confirm.
  NOTE(review): calc_level == 0 leaves df_report_measured unbound (UnboundLocalError).
  NOTE(review): `Calculation` strings from the CSV are spliced into SQL unescaped, so the
  CSV must be trusted.
  """
  df_report = spark.read.format("csv").option("header", True).option("inferSchema",True).option("delimiter",",").load(reports_input).na.fill(0).alias("df_report")  
  
  # A CalcValue of 0 in the report means "take the upstream value from stepdf".
  df_report_measured_val = df_report.alias("reports").join(stepdf.alias("source"), [F.col("reports.CalcName") == F.col("source.CalcName")], 'left' )
  df_report_measured_val = df_report_measured_val.selectExpr("ID", "PlantID","reports.CalcName",  "Calculation", "case when reports.CalcValue = 0 then source.CalcValue else reports.CalcValue end as CalcValue", "CalcDesc" ).na.fill(0)
  
  # NOTE(review): finalstr is not reset between passes, so WHEN clauses are appended again
  # on every pass (the first matching WHEN wins, so the repeats are only redundant).
  casestr =""
  finalstr=""
  for i in range(0, calc_level):
    # One column per CalcName so a Calculation expression can reference other measures by name.
    df_report_pivot = df_report_measured_val\
        .groupby(F.col("PlantID"))\
        .pivot("CalcName")\
        .agg(F.avg("CalcValue"))
    df_report_measured = df_report_measured_val.join(df_report_pivot, ["PlantID"] ).orderBy("id") 
    # Collected to the driver to generate the CASE expression text.
    caserdd = df_report_measured.rdd.collect()

    # Note: this inner loop reuses the name `i`; the outer loop is unaffected because
    # its own range iterator drives it.
    for i in range(0,  len(caserdd)):
        casestr =   " when CalcName = '" + str(caserdd[i]["CalcName"]) + "' then coalesce(nvl(" + str(caserdd[i]["Calculation"]) + " , CalcValue),0)" 
        finalstr = finalstr + casestr
    # "None" (Python's rendering of a missing Calculation) becomes SQL Null; note this
    # also rewrites any "None" substring inside names or expressions.
    df_report_measured = df_report_measured.selectExpr("*","""case """+ finalstr.replace("None", "Null")+"""  else `CalcValue` end as `finaValue`""") 
    df_report_measured = df_report_measured.fillna(0).withColumn("CalcValue", F.col("finaValue")).select(df_report_measured_val.columns)
    df_report_measured = df_report_measured.withColumn("ReportDateID", F.date_format(F.lit(ReportDate).cast("string"),'yyyyMMdd'))
  df_report_measured = df_report_measured.where("CalcName is not null").selectExpr("PlantID", "ReportDateID","'' as CalcCategory", "'Daily' as CalcFrequency","CalcName", "cast(CalcValue as string)", "'Null' as UoM", "CalcDesc", "'Null' as TextInput")
  # registerTempTable is the deprecated name of createOrReplaceTempView.
  df_report_measured.registerTempTable("tempreport")
  return True


In [0]:
def parse_array_from_string(x):
    """Decode JSON text (expected: an array of objects) for use inside a Spark UDF.

    :param x: JSON string, or None
    :return: the decoded Python value; None when `x` is None or a blank string.
        Spark passes NULL cells as None, and json.loads(None) would raise and fail the task.
    """
    import json  # not among this notebook's top-level imports; the UDF raised NameError before

    if x is None or (isinstance(x, str) and not x.strip()):
        return None
    return json.loads(x)

# Spark UDF: JSON text -> array<map<string,string>> (via parse_array_from_string).
retrieve_array = F.udf(
    f=parse_array_from_string,
    returnType=T.ArrayType(T.MapType(T.StringType(), T.StringType())),
)

In [0]:
def pre_stage_formconfig(df_incoming, tableAppend,tableName, dateFormat, tsFormat, filename="false"):
  """Staging step for form-config extracts whose `columns` field holds a JSON array.

  Steps:
  1. sanitise column names (getalias);
  2. strip `replace_char` unless it is 'N';
  3. add literal partition columns from `out_partition`;
  4. choose the schema (existing deltastage table, else the incoming schema);
  5. explode the JSON array in `columns` into one row per element, adding
     name/index/Coluuid (rows whose array is NULL or empty are dropped by explode);
  6. convert types (convert_datatype) and add missing schema columns;
  7. overwrite the stage table and refresh it.

  Reads notebook globals: replace_char, out_partition, stage, retrieve_array.

  Fixes: the explode used a bare `explode`, which is not imported in this notebook
  (NameError); it now uses F.explode. The stage table is looked up under the same
  space-free name it is saved as.

  NOTE(review): the schema is captured before the explode, so convert_datatype keeps only
  the pre-explode columns. The exploded name/index/Coluuid columns are therefore dropped,
  and the explode only multiplies rows. Confirm intended behaviour.
  NOTE(review): `filename` is accepted but unused.

  :return: the DataFrame that was written
  """
  df_incoming = getalias(df_incoming)
  no_stage_table = 'N'
  # This is to remove spaces from the table name, mostly for Excel sheet names.
  stage_table = "deltastage." + tableName.replace(" ", "")

  if replace_char != 'N':
    df_incoming = replaceChar(replace_char, df_incoming)

  if len(out_partition) != 0:
    df_incoming, stg_partition = add_partition(out_partition, df_incoming)
  else:
    stg_partition = []

  try:
    schema = spark.sql("SELECT * FROM " + stage_table).schema
  except Exception:
    # No readable stage table yet: use the incoming schema as-is.
    no_stage_table = 'Y'
    schema = df_incoming.schema

  df_incoming = (df_incoming
                 .withColumn("column1", retrieve_array(F.col("columns")))
                 .select("*", F.explode("column1").alias("col")))
  df_incoming = (df_incoming
                 .withColumn("name", F.col("col.name"))
                 .withColumn("index", F.col("col.index"))
                 .withColumn("Coluuid", F.col("col.uuid")))

  df_prestage = convert_datatype(schema, df_incoming, tableAppend=tableAppend, tableName=tableName,
                                 dateFormat=dateFormat, tsFormat=tsFormat, no_stage_table=no_stage_table)
  df_prestage = add_missing_columns(schema, df_prestage)

  writer = df_prestage.write.format("delta").mode("overwrite")
  if len(stg_partition) != 0:
    # As before, only the partitioned write allows the table schema to be replaced.
    writer = writer.partitionBy(stg_partition).option("overwriteSchema", True)
  writer.option("path", stage).saveAsTable(stage_table)

  spark.sql("refresh table deltastage.`" + tableName.replace(" ", "") + "`")

  return df_prestage