In [0]:
%run "./reader_factory"

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, lead, broadcast, collect_set, size, array_contains

class Transformer:
    """Common base for the transformer classes; does nothing on its own."""

    def __init__(self):
        """Create a transformer. No state is set up."""

    def transform(self, inputDFs):
        """Placeholder transform step.

        The base implementation ignores ``inputDFs`` and returns ``None``;
        subclasses supply the actual logic.
        """
        return None

class AirpodsAfterIphoneTransformer(Transformer):
    """Select customers whose purchase right after an iPhone was AirPods."""

    def transform(self, inputDFs):
        """Return the customers who bought AirPods immediately after an iPhone.

        Args:
            inputDFs: Mapping (anything with ``.get``) holding
                ``"transactionInputDF"`` -- read for ``customer_id``,
                ``transaction_date`` and ``product_name`` -- and
                ``"customerInputDF"``, which is joined on ``customer_id``.

        Returns:
            A DataFrame with the columns ``customer_id``, ``customer_name``
            and ``location``, holding each matching customer row once
            regardless of how many times the customer repeated the
            iPhone -> AirPods sequence.
        """
        transactionInputDF = inputDFs.get("transactionInputDF")

        # For every transaction, look up the product of the same customer's
        # next transaction in date order.
        # NOTE(review): if one customer has several rows with the same
        # transaction_date, their relative order (and so the "next" product)
        # is not deterministic -- consider adding a tie-breaker column.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")
        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowSpec)
        )

        filteredDF = transformedDF.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )

        # A customer can match more than once (repeat purchases). Keep a
        # single row per customer so the inner join below does not emit the
        # same customer multiple times.
        matchedCustomersDF = filteredDF.dropDuplicates(["customer_id"])

        customerInputDF = inputDFs.get("customerInputDF")
        joinDF = customerInputDF.join(
            broadcast(matchedCustomersDF),
            "customer_id"
        )

        return joinDF.select("customer_id", "customer_name", "location")

class OnlyAirpodsAndIphoneTransformer(Transformer):
    """Select customers whose distinct products are exactly iPhone and AirPods."""

    def transform(self, inputDFs):
        """Return the customers who bought iPhone and AirPods and nothing else.

        Args:
            inputDFs: Mapping (anything with ``.get``) holding
                ``"transactionInputDF"`` -- read for ``customer_id`` and
                ``product_name`` -- and ``"customerInputDF"``, which is
                joined on ``customer_id``.

        Returns:
            A DataFrame with the columns ``customer_id``, ``customer_name``
            and ``location``.
        """
        transactions = inputDFs.get("transactionInputDF")
        customers = inputDFs.get("customerInputDF")

        # One row per customer with the set of distinct product names bought.
        productsPerCustomer = transactions.groupBy("customer_id").agg(
            collect_set("product_name").alias("products")
        )

        # Both products present and a set size of two means nothing else
        # was purchased.
        products = col("products")
        hasIphone = array_contains(products, "iPhone")
        hasAirpods = array_contains(products, "AirPods")
        exactlyTwo = size(products) == 2
        qualifyingDF = productsPerCustomer.filter(hasIphone & hasAirpods & exactlyTwo)

        return customers.join(broadcast(qualifyingDF), "customer_id").select(
            "customer_id", "customer_name", "location"
        )
            