###common to all stages

Logging table shape:

process_id string not null ,
batchID string not null ,
file_name string not null ,
file_table_size string,
file_table_import_status string,
file_table_records bigint,
file_error_records bigint,
ingestion_timestamp timestamp,
load_comment string,
current_dttm timestamp

In [0]:
import os
import pathlib
import fnmatch
import pandas as pd
import boto3
import delta
import pyspark.sql.functions as F
from pyspark.sql.functions import input_file_name
import pyspark.sql.types as T
from pyspark.sql.types import StructType,StructField, StringType, IntegerType , DateType
from datetime import datetime
from delta.tables import DeltaTable


"""The class is instantiated once per batch. It reads the batch data and, for each batch:
A) calculates the record count, B) gets the file name, C) writes a row to the logger table."""

class deltawriter:
    """Write one batch DataFrame through bronze / silver / gold Delta tables.

    The instance holds the batch DataFrame plus the data-quality (DQ)
    configuration read from ``dq_table`` for the given ``processid``.  All
    Spark access goes through the notebook-global ``spark`` session.

    Attributes set by ``__init__``:
        validation_seq: ordered transformation keys applied by
            ``apply_dq_checks``.
        method_map: maps a transformation key to the name of the method on
            this class that implements it.
        batchdf, batchID, processid, sink_delta_table, dq_table, groupbykey,
            count_col: stored as passed in.
        prim_keys: ``col_name`` values whose ``is_primary_key`` is ``'y'``.
        non_nullable_keys: ``col_name`` values whose ``is_required`` is ``'y'``.
        dq_config: dict built by ``parse_config_rules``.
    """

    # Class-level defaults; every instance replaces both in __init__.
    validation_seq = list()
    method_map = dict()

    def __init__(self, batchdf, batchID, processid, sink_delta_table, dq_table,
                 validation_seq, method_map, groupbykey, count_col):
        """Store the batch context and load the active DQ configuration.

        Args:
            batchdf: DataFrame holding the current batch.
            batchID: identifier of the batch (written to the log table).
            processid: value matched against ``process_id`` in ``dq_table``.
            sink_delta_table: name of the bronze sink table; stage table
                names are derived from it by replacing ``_bronze``.
            dq_table: name of the DQ configuration table.
            validation_seq: ordered transformation keys for
                ``apply_dq_checks``.
            method_map: transformation key -> method name on this class.
            groupbykey: column used for the gold aggregation.
            count_col: stored on the instance; not used by any method in
                this class at present.
        """
        self.validation_seq = validation_seq
        self.method_map = method_map
        self.batchdf = batchdf
        self.batchID = batchID
        self.processid = processid
        self.sink_delta_table = sink_delta_table
        self.dq_table = dq_table
        # These two were accepted but never stored, which made
        # deduplicate_merge_gold fail with AttributeError.
        self.groupbykey = groupbykey
        self.count_col = count_col

        # Query the configuration once and reuse it for the key lists and
        # the rule parsing.
        config_df = self._load_active_config()
        self.prim_keys = config_df.loc[config_df["is_primary_key"] == 'y'].col_name.tolist()
        self.non_nullable_keys = config_df.loc[config_df["is_required"] == 'y'].col_name.tolist()
        self.parse_config_rules(config_df)

    def _load_active_config(self):
        """Return the active DQ config rows for this process as a pandas DataFrame."""
        return (
            spark.sql(f"select * from {self.dq_table}")
            .filter(f"Active = 'y' and process_id = '{self.processid}'")
            .toPandas()
        )

    def _stage_table_name(self, stage: str) -> str:
        """Return the unqualified table name for ``stage``.

        Takes the last dot-separated component of ``sink_delta_table`` (so an
        unqualified name works too) and replaces ``_bronze`` with
        ``_{stage}``.
        """
        base_name = self.sink_delta_table.rsplit('.', 1)[-1]
        return base_name.replace('_bronze', f"_{stage}")

    def create_register_delta(self, stage):
        """Create and register an empty Delta table for ``stage`` if missing.

        The table gets the schema of the current batch.  Nothing happens when
        the target location already holds a Delta table.

        Args:
            stage: stage suffix that replaces ``bronze`` in the sink table
                name (e.g. ``"silver"``).
        """
        table_name = self._stage_table_name(stage)
        table_location = '/user/hive/warehouse/' + table_name
        if DeltaTable.isDeltaTable(spark, table_location):
            return

        # The existence check, the write and the registration all target the
        # same stage table (previously the write/registration were hard-coded
        # to "_silver" regardless of ``stage``).
        empty_df = spark.createDataFrame(spark.sparkContext.emptyRDD(), self.batchdf.schema)
        empty_df.write.format('delta').mode('overwrite').save(table_location)
        spark.sql(f"CREATE TABLE {table_name} USING DELTA LOCATION '{table_location}'")

    def append_logging_delta(self):
        """Append the distinct log rows for this batch to ``streaming_log``.

        Reads ``_filename`` and ``_execute_timestamp`` from the batch, so both
        columns must exist on ``batchdf``.
        """
        df_count = self.batchdf.count()
        # NOTE(review): these aliases (processid, filename, filesize,
        # filecount, comment, current_timestamp) differ from the logging
        # table shape noted at the top of this file (process_id, file_name,
        # file_table_size, file_table_records, load_comment, current_dttm),
        # and "filesize" carries the sink table name -- confirm against the
        # actual streaming_log schema before renaming anything.
        log_df = self.batchdf.select(
            F.lit(self.processid).alias("processid"),
            F.lit(self.batchID).alias("batchID"),
            F.col('_filename').alias("filename"),
            F.lit(self.sink_delta_table).alias("filesize"),
            F.lit('processed').alias("file_table_import_status"),
            F.lit(df_count).alias("filecount"),
            F.lit(0).alias("file_error_records"),  # to be updated in silver stage
            F.col('_execute_timestamp').alias("ingestion_timestamp"),
            F.lit('NA').alias('comment'),
            F.lit(datetime.now()).cast(T.TimestampType()).alias("current_timestamp"),
        )
        (
            log_df.dropDuplicates().write.format("delta")
            .mode("append")
            .option("mergeSchema", "true")  # evolve schema on write
            .saveAsTable("streaming_log")
        )

    def parse_config_rules(self, config_df=None):
        """Build ``self.dq_config``: transformation key -> list of column names.

        Each ``formatting_rules`` cell is evaluated to a dict; every *value*
        of those dicts is a transformation key, and ``dq_config`` maps each
        key to the columns whose rules contain it.

        Args:
            config_df: optional pandas DataFrame with ``col_name`` and
                ``formatting_rules`` columns.  When omitted, the active
                configuration is read from ``dq_table``.
        """
        if config_df is None:
            config_df = self._load_active_config()

        # SECURITY NOTE: eval() executes whatever text is stored in
        # formatting_rules.  Only safe while the config table is trusted;
        # ast.literal_eval would be safer if the rules are plain literals.
        rules_by_col = {
            col: eval(rules)
            for col, rules in zip(config_df.col_name, config_df.formatting_rules)
        }

        transformations = {
            trans for rules in rules_by_col.values() for trans in rules.values()
        }

        self.dq_config = {
            trans: [col for col, rules in rules_by_col.items() if trans in rules.values()]
            for trans in transformations
        }

    def deduplicate_merge_silver(self, df):
        """De-duplicate ``df`` on the primary keys and upsert it into silver.

        Args:
            df: DataFrame to merge into the silver Delta table.

        Raises:
            ValueError: if no primary keys are configured, because the merge
                condition would be empty.
        """
        if not self.prim_keys:
            raise ValueError(
                f"No primary keys configured for process '{self.processid}'; cannot merge into silver"
            )

        silver_path = 'dbfs:/user/hive/warehouse/' + self._stage_table_name("silver")
        silver_table = DeltaTable.forPath(spark, silver_path)
        join_condition = ' AND '.join(
            f"silver.{key} = updates.{key}" for key in self.prim_keys
        )
        deduped = df.dropDuplicates(self.prim_keys)
        (
            silver_table.alias('silver')
            .merge(deduped.alias('updates'), join_condition)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )

    def writeDeltaBronze(self):
        """Append the batch to the bronze sink table, allowing schema evolution."""
        (
            self.batchdf.write.format('delta')
            .mode('append')
            .option("mergeSchema", "true")
            .saveAsTable(self.sink_delta_table)
        )

    def deduplicate_merge_gold(self, df):
        """Aggregate ``df`` to a per-date row count and upsert it into gold.

        Args:
            df: DataFrame containing the ``groupbykey`` column.
        """
        groupbykey = self.groupbykey

        daily_counts = (
            df.select(F.expr(f"date({groupbykey})").alias(groupbykey))
            .groupby(groupbykey)
            .count()
            .withColumn(groupbykey, F.col(groupbykey).cast(DateType()))
            .withColumn("count", F.col("count").cast(IntegerType()))
        )

        gold_table_name = self._stage_table_name("gold")
        # USING DELTA is stated explicitly because the table is opened with
        # DeltaTable.forPath below.
        spark.sql(f"""CREATE TABLE IF NOT EXISTS {gold_table_name}
    ({groupbykey}  date , count int  ) USING DELTA""")

        join_condition = f"gold.{groupbykey} == aggregates.{groupbykey}"
        gold_table = DeltaTable.forPath(spark, 'dbfs:/user/hive/warehouse/' + gold_table_name)
        # NOTE(review): a matched date has its count replaced by this call's
        # count, not added to it -- confirm that callers pass the full data
        # set rather than a single micro-batch.
        (
            gold_table.alias('gold')
            .merge(daily_counts.alias('aggregates'), join_condition)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )

    def pk_check(self):
        """Return the batch, flagged "invalid" if any primary-key column is missing.

        When every primary-key column is present the batch is returned
        unchanged (no ``_record_type`` column is added).
        """
        df = self.batchdf
        if any(key not in df.columns for key in self.prim_keys):
            df = df.withColumn("_record_type", F.lit("invalid"))
        return df

    def null_check(self, df):
        """Set ``_record_type`` to "invalid" for rows with a null required field.

        A row is "invalid" when ANY required field is null, or when it was
        already marked "invalid" (e.g. by ``pk_check``); otherwise "silver".
        Previously each field overwrote the result of the one before it, so
        only the last required field was effectively checked.

        Args:
            df: DataFrame to check.

        Returns:
            ``df`` with ``_record_type`` set, or ``df`` unchanged when no
            required fields are configured.
        """
        required_fields = self.non_nullable_keys
        if not required_fields:
            return df

        is_invalid = None
        for field in required_fields:
            field_is_null = F.col(field).isNull()
            is_invalid = field_is_null if is_invalid is None else (is_invalid | field_is_null)

        if "_record_type" in df.columns:
            # Keep an earlier "invalid" verdict instead of resetting it.
            is_invalid = is_invalid | (F.col("_record_type") == "invalid")

        return df.withColumn(
            "_record_type", F.when(is_invalid, "invalid").otherwise("silver")
        )

    def modify_execute_timestamp(self, df):
        """Return ``df`` with ``_execute_timestamp`` set to the driver's current time."""
        return df.withColumn(
            "_execute_timestamp", F.lit(datetime.now()).cast(T.TimestampType())
        )

    def apply_dq_checks(self, df):
        """Apply every configured transformation, in ``validation_seq`` order.

        For each transformation key the method named in ``method_map`` is
        called once per column listed for that key in ``dq_config``.

        Raises:
            KeyError: if a key in ``validation_seq`` is missing from
                ``dq_config`` or ``method_map``.
        """
        for trans in self.validation_seq:
            for colm in self.dq_config[trans]:
                df = getattr(self, self.method_map[trans])(df, colm)
        return df

    def to_upper(self, df, col_name):
        """Return ``df`` with ``col_name`` upper-cased in place."""
        # withColumn already names the result, so no alias is needed.
        return df.withColumn(col_name, F.upper(F.col(col_name)))

    def to_convert_dttm_format(self, df, col_name):
        """Return ``df`` with ``<col_name>_cnv`` holding the formatted timestamp.

        The output pattern is ``dd-MMM-yyyy HH:mm:ss``.  The previous pattern
        ``HH:MM:SS`` emitted month and fraction-of-second in place of minutes
        and seconds.
        """
        return df.withColumn(
            col_name + "_cnv",
            F.date_format(F.col(col_name).cast(T.StringType()), 'dd-MMM-yyyy HH:mm:ss'),
        )