In [1]:
from collections import Counter

from pyspark.sql import DataFrame
from pyspark.sql.functions import col

def join_products_categories(
    df_products: DataFrame,
    df_categories: DataFrame,
    df_product_category: DataFrame,
    product_id_col: str = "product_id",
    product_name_col: str = "product_name",
    category_id_col: str = "category_id",
    category_name_col: str = "category_name",
) -> DataFrame:
    """Return every (product name, category name) pair, keeping uncategorised products.

    Products are left-joined to the product/category link table and then to
    the categories table, so:

    * a product linked to N distinct categories yields N rows;
    * a product with no link row yields one row with a NULL category name;
    * a category that no product links to does not appear in the result.

    Duplicate link rows (same product id and category id) are collapsed before
    joining so they cannot multiply output rows.

    Args:
        df_products: Products; must contain ``product_id_col`` and
            ``product_name_col``.
        df_categories: Categories; must contain ``category_id_col`` and
            ``category_name_col``.
        df_product_category: Link table; must contain ``product_id_col`` and
            ``category_id_col``.
        product_id_col: Name of the product key column in ``df_products`` and
            ``df_product_category``.
        product_name_col: Name of the product name column in ``df_products``.
        category_id_col: Name of the category key column in ``df_categories``
            and ``df_product_category``.
        category_name_col: Name of the category name column in
            ``df_categories``.

    Returns:
        A DataFrame with exactly two columns, always named ``product_name``
        and ``category_name`` regardless of the input column names.
    """
    products = df_products.alias("p")
    categories = df_categories.alias("c")

    # De-duplicate on the configured key columns rather than the literal
    # default names, so callers that pass custom key names are honoured.
    links = df_product_category.dropDuplicates(
        [product_id_col, category_id_col]
    ).alias("pc")

    product_matches_link = col(f"p.{product_id_col}") == col(f"pc.{product_id_col}")
    link_matches_category = col(f"pc.{category_id_col}") == col(f"c.{category_id_col}")

    return (
        products.join(links, product_matches_link, "left")
        .join(categories, link_matches_category, "left")
        .select(
            col(f"p.{product_name_col}").alias("product_name"),
            col(f"c.{category_name_col}").alias("category_name"),
        )
    )

In [2]:
# Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Reuse the active Spark session if there is one, otherwise start a new one.
spark = SparkSession.builder.getOrCreate()

# --- Sample data ---------------------------------------------------------
# Products 4 and 6 are intentionally left without any category link.
products_cols = ["product_id", "product_name"]
products_data = [
    (1, "iPhone"),
    (2, "MacBook"),
    (3, "AirPods"),
    (4, "Sticker"),
    (5, "iPad"),
    (6, "Trash Can"),
]

# Category 4 ("Cars") is intentionally not linked to any product.
categories_cols = ["category_id", "category_name"]
categories_data = [
    (1, "Electronics"),
    (2, "Computers"),
    (3, "Accessories"),
    (4, "Cars"),
]

# (product_id, category_id) links; the last pair repeats the first one on
# purpose to exercise duplicate handling.
product_category_cols = ["product_id", "category_id"]
product_category_data = [
    (1, 1),
    (1, 3),
    (2, 2),
    (3, 3),
    (5, 1),
    (1, 1),
]

# --- Build the input DataFrames --------------------------------------------
df_products = spark.createDataFrame(products_data, products_cols)
df_categories = spark.createDataFrame(categories_data, categories_cols)
df_product_category = spark.createDataFrame(
    product_category_data, product_category_cols
)

# --- Join and display ------------------------------------------------------
df_result = join_products_categories(df_products, df_categories, df_product_category)

# Sort by category (NULLs last), then by product name, for a stable display.
df_result.orderBy(
    col("category_name").asc_nulls_last(),
    col("product_name").asc(),
).show()


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/09/08 03:10:56 WARN Utils: Your hostname, DESKTOP-EL02S3D, resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/09/08 03:10:56 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/09/08 03:10:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

+------------+-------------+
|product_name|category_name|
+------------+-------------+
|     AirPods|  Accessories|
|      iPhone|  Accessories|
|     MacBook|    Computers|
|        iPad|  Electronics|
|      iPhone|  Electronics|
|     Sticker|         NULL|
|   Trash Can|         NULL|
+------------+-------------+



In [3]:
# Test setup: import the test tooling and configure ipytest for this notebook.
import ipytest, pytest
# NOTE(review): presumably gives ipytest/pytest a notebook file name to work
# with; confirm it is still required by the installed ipytest version.
__file__ = "main.ipynb"
# Apply ipytest's default configuration (see the ipytest docs for what
# autoconfig enables).
ipytest.autoconfig()

In [4]:
# Fixtures
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """Provide one SparkSession for the whole test session and stop it afterwards.

    ``getOrCreate`` returns an already-active session when one exists, so the
    session yielded here may be the one created earlier in the notebook.
    """
    builder = SparkSession.builder.master("local[*]").appName("pytest-pyspark")
    session = builder.getOrCreate()
    yield session
    # Teardown: runs once, after the last test that used the fixture.
    session.stop()

In [5]:
# Tests
def test_pairs_and_nulls(spark):
    """Each product/category pair appears exactly once; uncategorised products get NULL."""

    df_products = spark.createDataFrame(
        [(1, "iPhone"), (2, "MacBook"), (3, "AirPods"), (4, "Sticker"), (5, "iPad"), (6, "Trash Can")],
        ["product_id", "product_name"],
    )
    df_categories = spark.createDataFrame(
        [(1, "Electronics"), (2, "Computers"), (3, "Accessories"), (4, "Cars")],
        ["category_id", "category_name"],
    )
    df_pc = spark.createDataFrame(
        [(1, 1), (1, 3), (2, 2), (3, 3), (5, 1)],
        ["product_id", "category_id"],
    )

    out = join_products_categories(df_products, df_categories, df_pc)

    # "Cars" has no products and must not appear; "Sticker" and "Trash Can"
    # have no category and must appear with a NULL category name.
    expected = [
        ("AirPods", "Accessories"),
        ("iPhone", "Accessories"),
        ("MacBook", "Computers"),
        ("iPad", "Electronics"),
        ("iPhone", "Electronics"),
        ("Sticker", None),
        ("Trash Can", None),
    ]

    assert out.columns == ["product_name", "category_name"]

    # Compare as multisets rather than sets: row order is not guaranteed, but
    # a join that fans out duplicate rows must still fail the test.
    assert Counter(map(tuple, out.collect())) == Counter(expected)

def test_duplicate_relations_drop(spark):
    """Duplicate product/category link rows produce a single output row."""

    df_products = spark.createDataFrame([(1, "USB-C Cable")], ["product_id", "product_name"])
    df_categories = spark.createDataFrame([(3, "Accessories")], ["category_id", "category_name"])

    # The same (product_id, category_id) link is supplied twice.
    df_pc = spark.createDataFrame([(1, 3), (1, 3)], ["product_id", "category_id"])

    out = join_products_categories(df_products, df_categories, df_pc)

    # Check the content as well as the count, so a single but wrong row
    # (for example a NULL category) cannot pass.
    rows = [tuple(row) for row in out.collect()]
    assert rows == [("USB-C Cable", "Accessories")]

# Run the tests defined in this notebook; "-v" is passed through to pytest
# as a command-line flag (verbose output).
ipytest.run("-v")


platform linux -- Python 3.11.2, pytest-8.4.2, pluggy-1.6.0
rootdir: /home/light/testovie/pyspark-products
configfile: pyproject.toml
collected 2 items

t_70e2b9a3b4194b95ad4e0a5339f0e726.py 

25/09/08 03:11:08 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
                                                                                

[32m.[0m[32m.[0m[32m                                                     [100%][0m



<ExitCode.OK: 0>