In [None]:
import time

data = sc.textFile("hdfs://vm1:9000/user/azureuser/data/mushroom.dat", minPartitions=16)

# Create the rdd which stores all the transactions: one whitespace-separated
# line per transaction. Cached because several Spark actions below reuse it;
# without caching each action would re-read and re-split the HDFS file.
transactions = data.map(lambda line: line.strip().split()).cache()

# Minimum support as an integer percentage of all transactions.
MIN_SUPPORT_PERCENT = 20

# Absolute support threshold: an itemset is frequent when its count is at
# least MIN_SUPPORT_PERCENT% of the transactions, i.e.
# count >= ceil(N * pct / 100). Truncating (int(0.2 * N)) would round down and
# admit itemsets just below the percentage; -(-a // b) is exact integer
# ceiling division, so no float rounding is involved.
n_transactions = transactions.count()
SUPPORT = -(-n_transactions * MIN_SUPPORT_PERCENT // 100)
print(SUPPORT)

# compute frequent-1 itemset

# Merge all the transactions together into one RDD of items
items = transactions.flatMap(lambda x: x)

# count the frequency of every distinct item
item_counts = items.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)

# get the frequent-1 itemset. Cached because it is counted and collected here
# and collected again when the level-wise search is seeded.
L1 = item_counts.filter(lambda x: x[1] >= SUPPORT).map(lambda x: x[0]).cache()
print(L1.count())

# Total item occurrences before infrequent items are purged. Printed because
# a bare expression in the middle of a cell is evaluated and then discarded.
print(items.count())

# Broadcast the frequent single items so executors can test membership locally.
freq = sc.broadcast(set(L1.collect()))

# eliminate all the infrequent items
def purge_itemset(itemsets):
    """Drop the items of one transaction that are not frequent 1-itemsets.

    Args:
        itemsets: the list of items of a single transaction.

    Returns:
        A new list containing only the items present in the broadcast set
        ``freq``, in their original order.
    """
    frequent = freq.value  # read the broadcast value once per call
    return [item for item in itemsets if item in frequent]


transactions = transactions.map(purge_itemset)

# Total item occurrences after purging. Printed because a bare expression in
# the middle of a cell is evaluated and then discarded.
print(transactions.flatMap(lambda x: x).count())

# Size of the executor memory-status map reported by the JVM SparkContext.
# The "- 1" discounts the entry assumed to belong to the driver's own block
# manager. NOTE(review): confirm this for the deployment mode in use.
executor_memory_status = sc._jsc.sc().getExecutorMemoryStatus().size()
worker_count = executor_memory_status - 1
print(f"Current number of workers: {worker_count}")

# Small data algorithm: candidate generation runs on the driver, with the
# Apriori prune step applied to the joined candidates.
def join_and_pruning(L_k, *, prune=True):
    """Generate candidate (k+1)-itemsets from the frequent k-itemsets ``L_k``.

    Join step: every pair of k-itemsets whose union has exactly k+1 items
    contributes that union as a candidate.

    Prune step (when ``prune`` is true): a candidate is kept only if every one
    of its k-item subsets appears in ``L_k``. A candidate with an infrequent
    subset cannot itself be frequent, so pruning only removes candidates that
    support counting would reject anyway, while saving that counting work.

    Args:
        L_k: list of k-itemsets, each a list of items. k is taken from the
            first itemset; all are expected to have that length.
        prune: apply the prune step. Pass ``False`` for the raw join output.

    Returns:
        A list of candidate itemsets, each a list of k+1 distinct items. The
        order of candidates and of items within a candidate is unspecified.
        Returns ``[]`` when ``L_k`` is empty.
    """
    from itertools import combinations

    if not L_k:
        return []

    k = len(L_k[0])

    # Convert each itemset once, instead of rebuilding two sets for every
    # pair inside the quadratic pair loop.
    frozen = [frozenset(itemset) for itemset in L_k]

    # Join: unions of two k-itemsets that differ in exactly one item.
    candidate_set = set()
    for set_i, set_j in combinations(frozen, 2):
        union_set = set_i | set_j
        if len(union_set) == k + 1:
            candidate_set.add(union_set)

    if prune:
        frequent = set(frozen)
        # Removing one item from a (k+1)-candidate yields each k-subset.
        candidate_set = {
            candidate
            for candidate in candidate_set
            if all(candidate - {item} in frequent for item in candidate)
        }

    # Convert frozen sets back to lists
    return [list(candidate) for candidate in candidate_set]

# Turn every purged transaction into a Python set, collect them all on the
# driver and broadcast the resulting list so support_count can scan it
# locally. Assumes the purged dataset fits in driver/executor memory.
itemset_broadcast = sc.broadcast(transactions.map(set).collect())
def support_count(itemset, *, baskets=None):
    """Count the transactions that contain every item of ``itemset``.

    Args:
        itemset: iterable of items forming one candidate itemset.
        baskets: iterable of transactions, each a set of items. Defaults to
            the broadcast list ``itemset_broadcast.value``.

    Returns:
        An ``(itemset, count)`` tuple; ``itemset`` is passed through unchanged
        so callers can filter on the count and map back to the itemset.
    """
    if baskets is None:
        baskets = itemset_broadcast.value
    # Build the candidate set once rather than once per transaction.
    candidate = set(itemset)
    count = sum(1 for transaction in baskets if candidate.issubset(transaction))
    return (itemset, count)
def get_frequent_set(C_k):
    """Return the candidates in ``C_k`` whose support count is at least ``SUPPORT``.

    The candidate list is distributed with ``sc.parallelize``. Each
    candidate is scored by ``support_count``, and the surviving itemsets are
    collected back to the driver as a list.
    """
    candidates_rdd = sc.parallelize(C_k)
    scored = candidates_rdd.map(support_count)
    survivors = scored.filter(lambda pair: pair[1] >= SUPPORT)
    return survivors.map(lambda pair: pair[0]).collect()

# Seed the level-wise search with the frequent 1-itemsets, each wrapped as a
# one-element list.
L_k = [[item] for item in L1.collect()]
k = 2

# Every frequent itemset found, across all levels (starting with level 1).
all_frequent_itemsets = list(L_k)

# perf_counter is a monotonic clock intended for measuring intervals;
# time.time() can jump if the wall clock is adjusted.
overall_start_time = time.perf_counter()  # Start timing for the whole loop

try:
    # Each iteration builds candidate k-itemsets from the frequent
    # (k-1)-itemsets and keeps those meeting SUPPORT; stops when none survive.
    while L_k:
        print(f'k={k}')
        C_k = join_and_pruning(L_k)
        L_k = get_frequent_set(C_k)
        print(f'frequent {k}-itemsets: {len(L_k)}')
        all_frequent_itemsets.extend(L_k)
        k += 1

    overall_end_time = time.perf_counter()  # End timing for the whole loop
    print(f'Total time for the while loop: {overall_end_time - overall_start_time:.2f} seconds')
    print(f'Total frequent itemsets: {len(all_frequent_itemsets)}')
finally:
    # Release the Spark session even if one of the levels fails.
    spark.stop()