In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [0]:
# Open the staging VAT table as a streaming source; every downstream cell builds on this.
stream_vat_df = (
    spark.readStream
    .table("invoice_db.invoice_vat_data_stg")
)

In [0]:
# Preview the raw VAT stream read from invoice_db.invoice_vat_data_stg (debug aid).
# NOTE(review): displaying a streaming DataFrame presumably starts its own streaming
# query on Databricks; consider removing this cell before running the notebook as a job.
stream_vat_df.display()

invoice_number,unique_invoice_identifier,account_number,customer_name,vat_amount,vat_breakdown_amount,tax_exemption_code,tax_exemption_amount,invoice_status,target_committime,load_month,load_date,batch_id
123462837-SS,f237165c-6dc2-4695-9635-6245f8913c54,123-456-658,Joseph Christ,250.0,100.0,ES,50.0,ADD,2024-07-21T07:58:04.651+0000,2024-07-01,2024-07-21,BATCH-2024-07-21-07_58_04
123462838-SS,f237165c-6dc2-4695-9635-6245f8913c54,123-456-659,Jane Christ,251.0,400.0,ES,55.0,ADD,2024-07-21T08:19:21.900+0000,2024-07-01,2024-07-21,BATCH-2024-07-21-08_19_21
123462835-SS,f237165c-6dc2-4695-9635-6245f8913c65,123-456-680,Alpha,351.0,400.0,ES,55.0,ADD,2024-07-21T08:30:27.614+0000,2024-07-01,2024-07-21,BATCH-2024-07-21-08_30_27
123462869-SS,l237165c-6dc2-4695-9635-6245f8913g57,123-456-659,Jane Christ,251.0,450.0,ES,55.0,EDIT,2024-07-21T08:30:27.614+0000,2024-07-01,2024-07-21,BATCH-2024-07-21-08_30_27


In [0]:
# Running VAT totals per (account_number, customer_name) over the whole stream,
# stamped with the current processing timestamp as target_committime.
aggregated_df = (
    stream_vat_df
    .groupBy("account_number", "customer_name")
    .agg(
        expr("sum(vat_amount)").alias("total_vat_amount"),
        expr("sum(vat_breakdown_amount)").alias("total_vat_brkdown_amt"),
        expr("sum(tax_exemption_amount)").alias("total_tax_exmpt_amt"),
    )
    .withColumn("target_committime", current_timestamp())
)

In [0]:
# Preview the running per-account VAT totals (debug aid).
# NOTE(review): like any display of a streaming DataFrame, this presumably launches a
# separate streaming query; consider removing it before scheduling the notebook.
aggregated_df.display()

account_number,customer_name,total_vat_amount,total_vat_brkdown_amt,total_tax_exmpt_amt,target_committime
123-456-680,Alpha,351.0,400.0,55.0,2024-07-21T10:59:34.628+0000
123-456-658,Joseph Christ,250.0,100.0,50.0,2024-07-21T10:59:34.628+0000
123-456-659,Jane Christ,502.0,850.0,110.0,2024-07-21T10:59:34.628+0000


In [0]:
def update_table(vat_total_df, batch_id):
    """Upsert one micro-batch of per-account VAT totals into the Delta target table.

    Intended as the ``foreachBatch`` sink of the streaming aggregation ``aggregated_df``.
    That query runs in ``outputMode("update")`` over a stateful ``groupBy().agg()``.
    Each micro-batch therefore carries the *cumulative* totals for every
    (account_number, customer_name) group that changed, not only the contribution
    of the newly arrived invoices.

    Matched target rows are overwritten with the incoming totals. Adding them to
    the existing values would count every earlier invoice again on each batch.

    Args:
        vat_total_df: Micro-batch DataFrame, expected to carry the columns produced by
            the aggregation: account_number, customer_name, total_vat_amount,
            total_vat_brkdown_amt, total_tax_exmpt_amt, target_committime.
        batch_id: Micro-batch id passed by Structured Streaming (unused here).
    """
    # Expose the batch under a view name so the MERGE SQL can reference it.
    vat_total_df.createOrReplaceTempView("vat_total_temp")
    merge_statement = """merge into invoice_db.invoice_total_vat_data_stg t using vat_total_temp s
    on t.account_number == s.account_number and t.customer_name == s.customer_name
    when matched then
    update set
    t.total_vat_amount = s.total_vat_amount,
    t.total_vat_brkdown_amt = s.total_vat_brkdown_amt,
    t.total_tax_exmpt_amt = s.total_tax_exmpt_amt,
    t.target_committime = s.target_committime
    when not matched then
    insert *
    """
    # Run the MERGE through the micro-batch DataFrame's own session, where the temp view
    # was just registered, rather than the notebook-level `spark` session.
    # On PySpark >= 3.3 the public `vat_total_df.sparkSession.sql(...)` is the equivalent.
    vat_total_df._jdf.sparkSession().sql(merge_statement)

In [0]:
# Continuously push the running per-account totals through `update_table`
# (a MERGE into the target table), checkpointing progress so restarts resume cleanly.
query = (
    aggregated_df.writeStream
    .outputMode("update")
    .option("checkpointLocation", "/FileStore/tables/invoices/invoice_total_vat_data_chkpt")
    .foreachBatch(update_table)
    .start()
)

# Block this cell until the streaming query stops or fails.
query.awaitTermination()