In [None]:
import zlib
from itertools import combinations
from typing import List, Tuple, Optional

from pyspark import SparkContext
from pyspark.ml import Estimator, Model
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.rdd import RDD
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.functions import collect_set

# PCYParams class for managing parameters
class PCYParams(Params, DefaultParamsReadable, DefaultParamsWritable):
    """Params shared by the PCY estimator and the PCY model.

    Every Param carries a type converter. A value of the wrong type (for
    example a non-integral ``numBuckets`` or a non-string ``itemsCol``) is
    therefore rejected with a ``TypeError`` when it is set. Without the
    converter it would only fail, or silently misbehave, inside a Spark job.
    """

    minSupport = Param(Params._dummy(), "minSupport", "Support threshold for frequent items",
                       typeConverter=TypeConverters.toInt)
    minConfidence = Param(Params._dummy(), "minConfidence", "Confidence threshold for association rules",
                          typeConverter=TypeConverters.toFloat)
    numBuckets = Param(Params._dummy(), "numBuckets", "Number of buckets for hashing",
                       typeConverter=TypeConverters.toInt)
    itemsCol = Param(Params._dummy(), "itemsCol", "Name of items column",
                     typeConverter=TypeConverters.toString)

    def __init__(self):
        super(PCYParams, self).__init__()
        # Defaults: support 10, confidence 0.2, 100 hash buckets, column "items".
        self._setDefault(minSupport=10, minConfidence=0.2, numBuckets=100, itemsCol="items")

    def setMinSupport(self, value: int):
        """Set the minimum count an item, bucket or pair needs to be kept; returns self."""
        return self._set(minSupport=value)

    def getMinSupport(self) -> int:
        """Return the support threshold (explicitly set value or default)."""
        return self.getOrDefault(self.minSupport)

    def setMinConfidence(self, value: float):
        """Set the minimum confidence an association rule needs to be kept; returns self."""
        return self._set(minConfidence=value)

    def getMinConfidence(self) -> float:
        """Return the confidence threshold (explicitly set value or default)."""
        return self.getOrDefault(self.minConfidence)

    def setNumBuckets(self, value: int):
        """Set how many hash buckets item pairs are spread over; returns self."""
        return self._set(numBuckets=value)

    def getNumBuckets(self) -> int:
        """Return the number of hash buckets (explicitly set value or default)."""
        return self.getOrDefault(self.numBuckets)

    def setItemsCol(self, value: str):
        """Set the name of the column holding each basket's items; returns self."""
        return self._set(itemsCol=value)

    def getItemsCol(self) -> str:
        """Return the items column name (explicitly set value or default)."""
        return self.getOrDefault(self.itemsCol)

# PCY (Estimator) class for training the model
class PCY(Estimator, PCYParams, DefaultParamsReadable, DefaultParamsWritable):
    """Spark ML estimator that mines frequent item pairs with the PCY algorithm.

    Fitting counts single items and hashed pair buckets. It then counts only
    candidate pairs, meaning pairs whose two items are both frequent and which
    hash into a frequent bucket, and keeps those with support >= ``minSupport``.
    A rule ``A -> B`` is emitted for a frequent pair when
    ``support(A, B) / support(A) >= minConfidence``.
    """

    def __init__(self, minSupport: Optional[int] = None,
                 minConfidence: Optional[float] = None,
                 numBuckets: Optional[int] = None,
                 itemsCol: Optional[str] = None):
        """Create the estimator; arguments left as ``None`` keep their defaults."""
        super(PCY, self).__init__()

        # Set parameters
        if minSupport is not None:
            self.setMinSupport(minSupport)
        if minConfidence is not None:
            self.setMinConfidence(minConfidence)
        if numBuckets is not None:
            self.setNumBuckets(numBuckets)
        if itemsCol is not None:
            self.setItemsCol(itemsCol)

    @staticmethod
    def _stable_bucket(pair, num_buckets: int) -> int:
        """Map an item pair (any iterable of two items) to a bucket in [0, num_buckets).

        CRC32 of the sorted pair's ``repr`` is used instead of ``hash()``.
        String hashing is seeded per process, so ``hash()`` can give different
        buckets on the driver and on the executors.
        """
        key = repr(tuple(sorted(pair))).encode("utf-8")
        return zlib.crc32(key) % num_buckets

    def _hash_pair(self, pair: frozenset, num_buckets: int) -> int:
        """Return the bucket index of ``pair``; the order of its items is irrelevant."""
        return self._stable_bucket(pair, num_buckets)

    @staticmethod
    def _normalize_basket(items) -> list:
        """Return the distinct, non-null items of a basket in sorted order.

        De-duplicating stops a repeated item from forming a one-element
        "pair" or from being counted more than once in the same basket.
        """
        if items is None:
            return []
        return sorted({item for item in items if item is not None})

    def _load_baskets(self, dataset: DataFrame, items_col: str) -> RDD:
        """Return an RDD with one normalized basket (list of items) per record."""
        if items_col in dataset.columns:
            # NOTE(review): assumes items_col holds a collection of items; a
            # plain string column would be iterated character by character.
            raw = dataset.select(items_col).rdd.map(lambda row: row[0])
        else:
            # Fallback for one-row-per-item input: items bought by the same
            # Member_number on the same Date form one basket.
            raw = (dataset.groupBy("Member_number", "Date")
                   .agg(collect_set("itemDescription").alias(items_col))
                   .rdd.map(lambda row: row[items_col]))
        # Static method fetched through self is a plain function: no `self` in the closure.
        return raw.map(self._normalize_basket)

    def _fit(self, dataset: DataFrame) -> "PCYModel":
        """Mine frequent pairs and association rules from ``dataset``.

        Raises:
            ValueError: if ``numBuckets`` is not positive.
        """
        s = self.getMinSupport()
        c = self.getMinConfidence()
        num_buckets = self.getNumBuckets()
        items_col = self.getItemsCol()
        if num_buckets <= 0:
            raise ValueError(f"numBuckets must be positive, got {num_buckets!r}")

        # Plain function reference so Spark closures don't serialize the estimator.
        bucket_of = self._stable_bucket

        # Baskets are scanned three times; cache them so reading/grouping runs once.
        baskets_rdd = self._load_baskets(dataset, items_col).cache()
        sc = baskets_rdd.context
        broadcasts = []
        try:
            # Pass 1a: support (number of baskets) of every frequent single item.
            frequent_items = (baskets_rdd
                              .flatMap(lambda items: [(item, 1) for item in items])
                              .reduceByKey(lambda a, b: a + b)
                              .filter(lambda kv: kv[1] >= s)
                              .collectAsMap())

            # Pass 1b: hash every pair of every basket and keep buckets whose
            # total count reaches the support threshold.
            frequent_buckets = (baskets_rdd
                                .flatMap(lambda items: [(bucket_of(pair, num_buckets), 1)
                                                        for pair in combinations(items, 2)])
                                .reduceByKey(lambda a, b: a + b)
                                .filter(lambda kv: kv[1] >= s)
                                .collectAsMap())

            items_bc = sc.broadcast(set(frequent_items))
            buckets_bc = sc.broadcast(set(frequent_buckets))
            broadcasts.extend([items_bc, buckets_bc])

            # Pass 2: count only candidate pairs. A pair can reach support s only
            # if both items and its bucket do, so skipping the rest loses nothing.
            def candidate_pairs(items):
                hot_items = items_bc.value
                hot_buckets = buckets_bc.value
                kept = [item for item in items if item in hot_items]
                return [(frozenset(pair), 1)
                        for pair in combinations(kept, 2)
                        if bucket_of(pair, num_buckets) in hot_buckets]

            frequent_pairs = (baskets_rdd
                              .flatMap(candidate_pairs)
                              .reduceByKey(lambda a, b: a + b)
                              .filter(lambda kv: kv[1] >= s)
                              .collectAsMap())
        finally:
            for bc in broadcasts:
                bc.unpersist()
            baskets_rdd.unpersist()

        # Both items of a frequent pair are frequent items, because an item's
        # support is at least that of any pair containing it. The lookups below
        # therefore never miss and never divide by zero.
        rules = []
        for pair, pair_support in frequent_pairs.items():
            item1, item2 = pair
            conf1 = pair_support / frequent_items[item1]
            conf2 = pair_support / frequent_items[item2]
            if conf1 >= c:
                rules.append((item1, item2, conf1, pair_support))
            if conf2 >= c:
                rules.append((item2, item1, conf2, pair_support))

        # Returns the trained model
        return PCYModel(s, c, num_buckets, items_col,
                        frequent_buckets, frequent_pairs, rules)

# PCYModel class to store training results
class PCYModel(Model, PCYParams, DefaultParamsReadable, DefaultParamsWritable):
    """Result of PCY training: frequent buckets, frequent pairs and association rules.

    ``frequent_buckets`` maps bucket id -> count. ``frequent_pairs`` maps
    frozenset pair -> support. Each rule is a tuple
    ``(antecedent, consequent, confidence, support)``.

    NOTE(review): the mined results are plain Python attributes, while the
    DefaultParamsWritable/Readable mixins persist Params. A saved and reloaded
    model presumably comes back with empty results; confirm before relying on
    model persistence.
    """

    def __init__(self, minSupport: Optional[int] = None,
                 minConfidence: Optional[float] = None,
                 numBuckets: Optional[int] = None,
                 itemsCol: Optional[str] = None,
                 frequent_buckets: Optional[dict] = None,
                 frequent_pairs: Optional[dict] = None,
                 rules: Optional[List[Tuple[str, str, float, int]]] = None):
        """Store the params and mined results; ``None`` results become empty containers."""
        super(PCYModel, self).__init__()

        # Set parameters
        if minSupport is not None:
            self.setMinSupport(minSupport)
        if minConfidence is not None:
            self.setMinConfidence(minConfidence)
        if numBuckets is not None:
            self.setNumBuckets(numBuckets)
        if itemsCol is not None:
            self.setItemsCol(itemsCol)

        self.frequent_buckets = frequent_buckets if frequent_buckets is not None else {}
        self.frequent_pairs = frequent_pairs if frequent_pairs is not None else {}
        self.rules = rules if rules is not None else []

        # Session used by the print helpers to build display DataFrames.
        self.spark = SparkSession.builder.getOrCreate()

    def _transform(self, dataset: DataFrame) -> DataFrame:
        """Return ``dataset`` unchanged; the model's output is its mined results."""
        return dataset

    def getFrequentBuckets(self) -> dict:
        """Return the mapping of frequent bucket id -> bucket count."""
        return self.frequent_buckets

    def getFrequentPairs(self) -> dict:
        """Return the mapping of frequent pair (frozenset) -> support."""
        return self.frequent_pairs

    def getRules(self) -> List[Tuple[str, str, float, int]]:
        """Return the rules as (antecedent, consequent, confidence, support) tuples."""
        return self.rules

    def printFrequentPairs(self):
        """Print the frequent pairs and their support as a Spark table.

        An explicit schema is used because Spark cannot infer one from an
        empty list, which would otherwise fail when nothing was mined.
        """
        pairs_data = [(str(pair), int(count)) for pair, count in self.frequent_pairs.items()]
        pairs_df = self.spark.createDataFrame(pairs_data, "Pair string, Support long")
        print("Frequent Pairs:")
        pairs_df.show(truncate=False)

    def printRules(self):
        """Print the association rules as a Spark table (works with zero rules too)."""
        rules_data = [(str(antecedent), str(consequent), float(confidence), int(support))
                      for antecedent, consequent, confidence, support in self.rules]
        rules_df = self.spark.createDataFrame(
            rules_data,
            "Antecedent string, Consequent string, Confidence double, Support long",
        )
        print("Association Rules:")
        rules_df.show(truncate=False)

if __name__ == "__main__":
    # Initialize Spark session
    spark = SparkSession.builder.appName("PCY Algorithm").getOrCreate()
    try:
        file_path = "baskets.csv"
        df = spark.read.csv(file_path, header=True)

        # NOTE(review): if the CSV has no "basket" column, PCY falls back to
        # grouping Member_number/Date/itemDescription rows into baskets.
        pcy = PCY(minSupport=10, minConfidence=0.2, numBuckets=1000, itemsCol="basket")

        model = pcy.fit(df)

        model.printFrequentPairs()
        model.printRules()
    finally:
        # Release the session even when reading or fitting fails.
        spark.stop()

Frequent Pairs:
+----------------------------------------------------+-------+
|Pair                                                |Support|
+----------------------------------------------------+-------+
|frozenset({'sausage', 'whole milk'})                |134    |
|frozenset({'whole milk', 'yogurt'})                 |167    |
|frozenset({'whole milk', 'semi-finished bread'})    |25     |
|frozenset({'sausage', 'yogurt'})                    |86     |
|frozenset({'yogurt', 'semi-finished bread'})        |13     |
|frozenset({'pastry', 'whole milk'})                 |97     |
|frozenset({'salty snack', 'pastry'})                |10     |
|frozenset({'salty snack', 'whole milk'})            |29     |
|frozenset({'sausage', 'hygiene articles'})          |13     |
|frozenset({'soda', 'pickled vegetables'})           |12     |
|frozenset({'whole milk', 'rolls/buns'})             |209    |
|frozenset({'sausage', 'rolls/buns'})                |80     |
|frozenset({'curd', 'frankfurter'})    