In [0]:
from pyspark.sql.window import *
from pyspark.sql.functions import col, lead, lag, broadcast, collect_set, collect_list, size, array_contains

In [0]:
class Transformer:
    """Common base for transformation steps; subclasses override ``transform``."""

    def __init__(self):
        """The base class keeps no state, so there is nothing to set up."""

    def transform(self, inputDFs):
        """Do nothing with ``inputDFs`` and return ``None`` (placeholder)."""
        return None

class AirpodsAfterIphoneTransformer(Transformer):
    """Find customers whose purchase right after an iPhone was AirPods."""

    def transform(self, inputDFs):
        """
        Customers who bought AirPods immediately after buying an iPhone.

        Looks up ``"transactionInputDF"`` and ``"customerInpurDF"`` in
        ``inputDFs`` (the key spelling is kept exactly as this method has
        always read it). For every customer, transactions are ordered by
        ``transaction_date`` and each row is paired with the product of the
        following row; rows where an ``iPhone`` is directly followed by
        ``AirPods`` are kept and joined to the customer data on
        ``customer_id``.

        Intermediate DataFrames are printed with ``show()`` for inspection.

        Returns:
            A DataFrame with the columns ``customer_id``, ``customer_name``
            and ``location``, de-duplicated so that a customer who matched
            more than once is listed only once.
        """

        transactionInputDF = inputDFs.get("transactionInputDF")
        print("transactionInputDF in transform")

        transactionInputDF.show()

        # Per customer, in purchase order, expose the next purchased product.
        # NOTE(review): rows of one customer that share the same
        # transaction_date have no defined relative order here, so lead() may
        # pair them either way -- confirm whether a tiebreaker column exists.
        windowSpec = Window.partitionBy("customer_id").orderBy("transaction_date")

        transformedDF = transactionInputDF.withColumn(
            "next_product_name", lead("product_name").over(windowSpec)
        )

        print("AirPods after buying iPhone")

        transformedDF.orderBy("customer_id", "transaction_date", "product_name").show()
        filteredDF = transformedDF.filter(
            (col("product_name") == "iPhone") & (col("next_product_name") == "AirPods")
        )

        filteredDF.orderBy("customer_id", "transaction_date", "product_name").show()

        customerInpurDF = inputDFs.get("customerInpurDF")
        customerInpurDF.show()

        joinDF = customerInpurDF.join(
            broadcast(filteredDF),
            "customer_id"
        )

        print("Joined DF")
        joinDF.show()

        # filteredDF holds one row per matching iPhone -> AirPods pair, so the
        # inner join repeats a customer once per pair. distinct() collapses
        # those repeats so the result is a list of customers, not of matches.
        return joinDF.select(
            "customer_id",
            "customer_name",
            "location"
        ).distinct()


class OnlyAirpodsAndIphone(Transformer):
    """Select customers whose distinct purchased products are exactly iPhone and AirPods."""

    def transform(self, inputDFs):
        """
        Customers who bought only an iPhone and AirPods, nothing else.

        Reads ``"transactionInputDF"`` and ``"customerInpurDF"`` from
        ``inputDFs``. Distinct product names are collected per
        ``customer_id``; a customer qualifies when that set contains
        ``iPhone`` and ``AirPods`` and has exactly two entries. Qualifying
        customers are joined to the customer data on ``customer_id``.

        Intermediate DataFrames are printed with ``show()`` along the way.

        Returns:
            A DataFrame with the columns ``customer_id``, ``customer_name``
            and ``location``.
        """
        transactionsDF = inputDFs.get("transactionInputDF")
        print("transactionInputDF in transform")

        # One row per customer, carrying the set of distinct product names.
        byCustomer = transactionsDF.groupBy("customer_id")
        productsPerCustomerDF = byCustomer.agg(
            collect_set("product_name").alias("products")
        )

        print("Grouped DF")
        productsPerCustomerDF.show()

        # Both products present and a set size of two means nothing else was bought.
        products = col("products")
        hasIphone = array_contains(products, "iPhone")
        hasAirpods = array_contains(products, "AirPods")
        exactlyTwoProducts = size(products) == 2
        matchingDF = productsPerCustomerDF.filter(
            hasIphone & hasAirpods & exactlyTwoProducts
        )
        print("Only Airpods and iPhone")

        matchingDF.show()

        customersDF = inputDFs.get("customerInpurDF")

        customersDF.show()

        joinedDF = customersDF.join(broadcast(matchingDF), "customer_id")

        print("Joined DF")
        joinedDF.show()

        outputColumns = ("customer_id", "customer_name", "location")
        return joinedDF.select(*outputColumns)