In [12]:
# Create Spark Session
from pyspark.sql import SparkSession,DataFrame
spark: SparkSession = SparkSession.builder.appName("yt_pyspark_revision").getOrCreate() # type: ignore

In [39]:
# Customer Data
customer_data = [
    (1, 101),
    (1, 102),
    (1, 104),
    (2, 101),
    (3, 102),
    (3, 101),
    (3, 103),
    (4, 101),
    (4, 102),
    (4, 104),
    (4, 103),
    (4, 103)
]
# Product Data
product_data = [
    (101,),
    (102,),
    (103,)
]

# Schema
customer_schema = "customer_id int, product_key int"
product_schema = "product_key int"

# Product df
product_df : DataFrame = spark.createDataFrame(data = product_data, schema= product_schema)
customer_df : DataFrame = spark.createDataFrame(data = customer_data,schema=customer_schema)

In [49]:
print("**********Product DataFrame**********")
product_df.show()
print("**********Customer DataFrame**********")
customer_df.show()

**********Product DataFrame**********
+-----------+
|product_key|
+-----------+
|        101|
|        102|
|        103|
+-----------+

**********Customer DataFrame**********
+-----------+-----------+
|customer_id|product_key|
+-----------+-----------+
|          1|        101|
|          1|        102|
|          1|        104|
|          2|        101|
|          3|        102|
|          3|        101|
|          3|        103|
|          4|        101|
|          4|        102|
|          4|        104|
|          4|        103|
|          4|        103|
+-----------+-----------+



In [40]:
product_count = product_df.count()

In [41]:
from pyspark.sql.functions import col, countDistinct, collect_set, concat_ws, count
joined_df = customer_df.join(product_df, on=customer_df.product_key == product_df.product_key, how='inner') \
    .select(customer_df.customer_id, customer_df.product_key)

result = joined_df.groupBy("customer_id").agg(count('product_key').alias("count"), concat_ws(", ", collect_set("product_key")).alias("products")) \
    .filter(f'count == {product_count}')



In [42]:
result.show()

+-----------+-----+-------------+
|customer_id|count|     products|
+-----------+-----+-------------+
|          3|    3|102, 103, 101|
+-----------+-----+-------------+



In [48]:
from pyspark.sql.functions import sort_array,collect_list,asc,countDistinct,array_distinct

s = joined_df.groupBy("customer_id") \
    .agg(countDistinct('product_key').alias("count"), sort_array(array_distinct(collect_list("product_key")), asc = False).alias("products")) \
    .filter(f'count == {product_count}')
s.show()

+-----------+-----+---------------+
|customer_id|count|       products|
+-----------+-----+---------------+
|          3|    3|[103, 102, 101]|
|          4|    3|[103, 102, 101]|
+-----------+-----+---------------+



In [51]:
spark.stop()