In [14]:
from pyspark.sql import SparkSession

# Build (or reuse) the notebook's Spark session against a local master.
# The builder options are independent of each other; they only take effect
# when getOrCreate() is called.
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", "4")
    .appName("mindbox_product_category_pairs")
    .getOrCreate()
)

# Last expression of the cell: renders the session summary.
spark


In [15]:
from pyspark.sql.types import (
    StructType, StructField, LongType, StringType
)

# Products: non-nullable long id plus non-nullable name.
product_schema = (
    StructType()
    .add("id", LongType(), nullable=False)
    .add("name", StringType(), nullable=False)
)

# Categories have the same two-column shape as products.
category_schema = (
    StructType()
    .add("id", LongType(), nullable=False)
    .add("name", StringType(), nullable=False)
)

# Link table for the many-to-many relation between products and categories.
product_category_schema = (
    StructType()
    .add("product_id", LongType(), nullable=False)
    .add("category_id", LongType(), nullable=False)
)


In [16]:
# Sample products.
products_data = [
    {"id": product_id, "name": product_name}
    for product_id, product_name in (
        (1, "Phone"),
        (2, "Laptop"),
        (3, "Book"),
        (4, "Table"),
    )
]

# Sample categories.
categories_data = [
    {"id": category_id, "name": category_name}
    for category_id, category_name in (
        (10, "Electronics"),
        (20, "Furniture"),
    )
]

# Product-to-category links, including deliberate edge cases:
#   * (1, 10) appears twice            -> duplicate link
#   * (3, 777): 777 is not a category  -> dangling category reference
#   * (999, 10): 999 is not a product  -> dangling product reference
product_categories_data = [
    {"product_id": product_id, "category_id": category_id}
    for product_id, category_id in (
        (1, 10),
        (2, 10),
        (4, 20),
        (1, 10),
        (3, 777),
        (999, 10),
    )
]


In [17]:
# Turn each sample list into a DataFrame with its explicit schema
# (same order as before: products, categories, link table).
products, categories, product_categories = (
    spark.createDataFrame(rows, schema=schema)
    for rows, schema in (
        (products_data, product_schema),
        (categories_data, category_schema),
        (product_categories_data, product_category_schema),
    )
)


In [None]:
# Print each input table under its name, without truncating cell values.
for label, frame in (
    ("products", products),
    ("categories", categories),
    ("product_categories", product_categories),
):
    print(label)
    frame.show(truncate=False)


products
+---+------+
|id |name  |
+---+------+
|1  |Phone |
|2  |Laptop|
|3  |Book  |
|4  |Table |
+---+------+

categories
+---+-----------+
|id |name       |
+---+-----------+
|10 |Electronics|
|20 |Furniture  |
+---+-----------+

product_categories
+----------+-----------+
|product_id|category_id|
+----------+-----------+
|1         |10         |
|2         |10         |
|4         |20         |
|1         |10         |
|3         |777        |
|999       |10         |
+----------+-----------+



In [19]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

def product_category_pairs(
    products_df: DataFrame,
    categories_df: DataFrame,
    product_categories_df: DataFrame
) -> DataFrame:
    """Return every (product, category) pair plus products without a category.

    Each product from ``products_df`` appears at least once in the result:
    once per distinct category it is linked to, or exactly once with NULL
    category columns when it has no link to an existing category.

    Args:
        products_df: DataFrame with columns ``id`` and ``name``.
        categories_df: DataFrame with columns ``id`` and ``name``.
        product_categories_df: Link table with columns ``product_id`` and
            ``category_id``. May contain duplicate links and links that
            reference products or categories missing from the other inputs.

    Returns:
        DataFrame with columns ``product_id``, ``product_name``,
        ``category_id``, ``category_name``. Row order is not defined.
    """
    # Collapse repeated links so a pair is never emitted twice.
    links = (
        product_categories_df
        .select("product_id", "category_id")
        .dropDuplicates(["product_id", "category_id"])
    )

    # Resolve links against categories BEFORE attaching them to products.
    # Doing the product join first lets links to non-existent categories
    # survive as (product, NULL, NULL) rows: a product with one valid and one
    # dangling link would also be reported as uncategorised, and a product
    # with several dangling links would produce duplicate NULL rows.
    valid_links = (
        links.alias("pc")
        .join(categories_df.alias("c"), col("pc.category_id") == col("c.id"), how="inner")
        .select(
            col("pc.product_id").alias("link_product_id"),
            col("c.id").alias("category_id"),
            col("c.name").alias("category_name"),
        )
    )

    # Left join keeps products that have no valid link (category columns NULL);
    # links pointing at unknown products are dropped because they match nothing.
    result = (
        products_df.alias("p")
        .join(valid_links.alias("vl"), col("p.id") == col("vl.link_product_id"), how="left")
        .select(
            col("p.id").alias("product_id"),
            col("p.name").alias("product_name"),
            col("vl.category_id").alias("category_id"),
            col("vl.category_name").alias("category_name"),
        )
    )

    return result


In [20]:
# Build the pairs from the sample inputs.
res = product_category_pairs(
    products_df=products,
    categories_df=categories,
    product_categories_df=product_categories,
)

# Sort only for display; `res` itself stays unordered.
res_ordered = res.orderBy(["product_id", "category_id"])
res_ordered.show(truncate=False)


+----------+------------+-----------+-------------+
|product_id|product_name|category_id|category_name|
+----------+------------+-----------+-------------+
|1         |Phone       |10         |Electronics  |
|2         |Laptop      |10         |Electronics  |
|3         |Book        |NULL       |NULL         |
|4         |Table       |20         |Furniture    |
+----------+------------+-----------+-------------+



In [None]:

# For the sample data every product yields exactly one row.
assert res.count() == products.count() == 4

# Product 3 ("Book") is linked only to category 777, which does not exist,
# so it must appear once with both category columns NULL.
book_rows = (
    res
    .filter(col("product_id") == 3)
    .filter(col("category_id").isNull())
    .filter(col("category_name").isNull())
)
assert book_rows.count() == 1

# No (product_id, category_id) pair is repeated.
unique_pairs = res.select("product_id", "category_id").dropDuplicates()
assert unique_pairs.count() == res.count()

# Full content check, independent of row order.
expected = {
    (1, "Phone", 10, "Electronics"),
    (2, "Laptop", 10, "Electronics"),
    (3, "Book", None, None),
    (4, "Table", 20, "Furniture"),
}
actual = {
    (row["product_id"], row["product_name"], row["category_id"], row["category_name"])
    for row in res.collect()
}
assert actual == expected

print("Все проверки пройдены")


Все проверки пройдены


In [22]:
# Expose the three DataFrames to Spark SQL under the same names.
for view_name, frame in {
    "products": products,
    "categories": categories,
    "product_categories": product_categories,
}.items():
    frame.createOrReplaceTempView(view_name)


In [None]:
from pyspark.sql.functions import col

# Spark SQL version of the product/category pairing.
#
# Links are de-duplicated and resolved against `categories` (INNER JOIN)
# before being LEFT JOINed to `products`. Joining products to raw links first
# would let links to non-existent categories through as (product, NULL, NULL)
# rows, so a product with a valid category could also be listed as
# uncategorised, and several dangling links would yield duplicate NULL rows.
sql_res = spark.sql("""
WITH pc_dedup AS (
    SELECT DISTINCT product_id, category_id
    FROM product_categories
),
valid_links AS (
    SELECT
        pc.product_id,
        c.id   AS category_id,
        c.name AS category_name
    FROM pc_dedup pc
    INNER JOIN categories c
        ON pc.category_id = c.id
)
SELECT
    p.id   AS product_id,
    p.name AS product_name,
    vl.category_id   AS category_id,
    vl.category_name AS category_name
FROM products p
LEFT JOIN valid_links vl
    ON p.id = vl.product_id
ORDER BY product_id, category_id
""")

sql_res.show(truncate=False)

# Cross-check: the DataFrame API result, sorted the same way, must match the
# SQL result row for row.
api_res = product_category_pairs(products, categories, product_categories).orderBy("product_id","category_id")
assert api_res.collect() == sql_res.collect()
print("SQL вывод совпадает с DataFrame API")


+----------+------------+-----------+-------------+
|product_id|product_name|category_id|category_name|
+----------+------------+-----------+-------------+
|1         |Phone       |10         |Electronics  |
|2         |Laptop      |10         |Electronics  |
|3         |Book        |NULL       |NULL         |
|4         |Table       |20         |Furniture    |
+----------+------------+-----------+-------------+

SQL вывод совпадает с DataFrame API
