In [0]:
from pyspark.sql.functions import explode, expr, sum, split, trim, lower, col, input_file_name

In [0]:
class Bronze:
  """Bronze layer: stream raw invoice JSON files into the ``invoices_bz`` table.

  Reads JSON files from ``<base_data_dir>/data/invoices`` as a stream, tags
  every row with the path of the file it came from, and appends the rows to
  the Delta table ``invoices_bz``.

  NOTE: relies on a SparkSession named ``spark`` in global scope; it is not
  defined here (a Databricks notebook normally provides one).
  """

  def __init__(self, base_data_dir="/FileStore/test"):
    """Create the bronze stage.

    Args:
      base_data_dir: Root directory that holds ``data/invoices`` (input files)
        and ``checkpoint/`` (streaming checkpoints). Defaults to the path that
        was previously hard-coded, so existing ``Bronze()`` calls are unchanged.
    """
    self.base_data_dir = base_data_dir

  def getSchema(self):
    """Return the DDL schema string describing one invoice JSON record."""
    # Supplied explicitly because streaming file sources require a schema
    # unless streaming schema inference is enabled.
    return '''InvoiceNumber string, CreatedTime bigint, StoreID string, PosID string, CashierID string,
            CustomerType string, CustomerCardNo string, TotalAmount double, NumberOfItems bigint, 
            PaymentMethod string, TaxableAmount double, CGST double, SGST double, CESS double, 
            DeliveryType string,
            DeliveryAddress struct<AddressLine string, City string, ContactNumber string, PinCode string, State string>,
            InvoiceLineItems array<struct<ItemCode string, ItemDescription string, ItemPrice double, ItemQty bigint, TotalValue double>>'''

  def readInvoices(self):
    """Return a streaming DataFrame over the raw invoice JSON files.

    Adds an ``InputFile`` column holding the source file path of each row so
    bronze records can be traced back to the file they were loaded from.
    """
    return ( spark.readStream
            .format("json")
            .schema(self.getSchema())
            .load(f"{self.base_data_dir}/data/invoices")
            .withColumn("InputFile", input_file_name())
          )

  def process(self):
    """Start the bronze ingestion stream and return its StreamingQuery.

    The query is named ``bronze-ingestion``, appends into the Delta table
    ``invoices_bz`` and checkpoints under
    ``<base_data_dir>/checkpoint/invoices_bz``. ``toTable`` starts the query
    and returns immediately; "Done" means "started", not "finished".
    """
    print("Starting Bronze Stream ...")
    invoicesDF = self.readInvoices()
    sQuery = ( invoicesDF.writeStream
              .queryName("bronze-ingestion")
              .format("delta")
              .option("checkpointLocation", f"{self.base_data_dir}/checkpoint/invoices_bz")
              .outputMode("append")
              .toTable("invoices_bz")
            )
    print("Done\n")
    return sQuery



In [0]:
class Gold:
  """Gold layer: maintain per-customer reward totals from the bronze invoices.

  Streams the ``invoices_bz`` table, aggregates ``TotalAmount`` per
  ``CustomerCardNo`` and writes the full result to ``customer_rewards``.

  NOTE: relies on a SparkSession named ``spark`` in global scope; it is not
  defined here (a Databricks notebook normally provides one).
  """

  def __init__(self, base_data_dir="/FileStore/test"):
    """Create the gold stage.

    Args:
      base_data_dir: Root directory under which ``checkpoint/`` (streaming
        checkpoints) lives. Defaults to the path that was previously
        hard-coded, so existing ``Gold()`` calls are unchanged.
    """
    self.base_data_dir = base_data_dir

  def readBronze(self):
    """Return a streaming DataFrame over the ``invoices_bz`` table."""
    return spark.readStream.table("invoices_bz")

  def getAggregates(self, temp_df):
    """Aggregate invoices per customer card.

    Args:
      temp_df: Streaming DataFrame with ``CustomerCardNo`` and ``TotalAmount``
        columns.

    Returns:
      DataFrame with one row per ``CustomerCardNo`` holding the summed
      ``TotalAmount`` and ``TotalPoints`` (2% of each invoice's TotalAmount,
      summed).
    """
    # `sum` here is pyspark.sql.functions.sum (imported at file top, shadowing
    # the builtin), so it builds a Spark aggregate column.
    # NOTE(review): invoices with a null CustomerCardNo all land in a single
    # null-key group; confirm whether they should be filtered out first.
    return ( temp_df.groupBy("CustomerCardNo")
            .agg(sum("TotalAmount").alias("TotalAmount"),
                 sum(expr("TotalAmount*0.02")).alias("TotalPoints"))
            )

  def saveResults(self, results_df):
    """Start writing the aggregate stream to ``customer_rewards``.

    Uses ``complete`` output mode, so every trigger rewrites the whole
    aggregate table rather than appending. Checkpoints under
    ``<base_data_dir>/checkpoint/customer_rewards``.

    Returns:
      The started StreamingQuery (named ``gold-processing``).
    """
    return ( results_df.writeStream
            .queryName("gold-processing")
            .format("delta")
            .option("checkpointLocation", f"{self.base_data_dir}/checkpoint/customer_rewards")
            .outputMode("complete")
            .toTable("customer_rewards")
          )

  def process(self):
    """Wire bronze -> aggregates -> sink, and return the StreamingQuery."""
    print("Starting Gold Stream ...")
    invoices_df = self.readBronze()
    aggregate_df = self.getAggregates(invoices_df)
    sQuery = self.saveResults(aggregate_df)
    return sQuery

