In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import spark_partition_id, count

# Attach to the already-running Spark session, or start one if none exists.
spark = SparkSession.builder.getOrCreate()

# Dominant category: ids 1..80000 all belong to "Electronics" (the skewed key).
skewed = [
    (item_id, f"Item_{item_id}", "Electronics")
    for item_id in range(1, 80001)
]

# Remaining ids 80001..100000 cycle through three smaller categories.
others = [
    (item_id, f"Item_{item_id}", ("Toys", "Sports", "Clothing")[item_id % 3])
    for item_id in range(80001, 100001)
]

# Skewed rows first, then the rest: 100000 rows in total.
data = skewed + others

df = spark.createDataFrame(data, schema=["id", "name", "category"])

print("Total rows =", df.count())
df.show(5)


Total rows = 100000
+---+------+-----------+
| id|  name|   category|
+---+------+-----------+
|  1|Item_1|Electronics|
|  2|Item_2|Electronics|
|  3|Item_3|Electronics|
|  4|Item_4|Electronics|
|  5|Item_5|Electronics|
+---+------+-----------+
only showing top 5 rows


In [0]:
# Count rows per physical partition before any salting is applied.
# spark_partition_id() tags every row with the id of the partition holding it.
dist_before = (
    df.select(spark_partition_id().alias("pid"))
      .groupBy("pid")
      .agg(count("*").alias("rows"))
      # groupBy/agg gives no ordering guarantee; sort by partition id so the
      # table is stable across runs and easy to compare with later results.
      .orderBy("pid")
)

print("Partition distribution BEFORE salting:")
dist_before.show()


Partition distribution BEFORE salting:
+---+------+
|pid|  rows|
+---+------+
|  0|100000|
+---+------+



In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def classify(cat):
    """Label a category by its expected load.

    Args:
        cat: Category value to classify.

    Returns:
        "High_Load" when ``cat`` equals "Electronics"; "Normal_Load" for
        every other value (including ``None``).
    """
    return "High_Load" if cat == "Electronics" else "Normal_Load"

# Wrap the plain Python function as a Spark UDF producing a string column.
# NOTE(review): this rule is a single equality check, so it could presumably be
# expressed with built-in column functions instead of a Python UDF — confirm
# whether the UDF is intentional (e.g. for demonstration purposes).
classify_udf = udf(classify, returnType=StringType())

# Derive the load label from each row's category value.
df_udf = df.withColumn("cat_class", classify_udf(df["category"]))

df_udf.show(5)


+---+------+-----------+---------+
| id|  name|   category|cat_class|
+---+------+-----------+---------+
|  1|Item_1|Electronics|High_Load|
|  2|Item_2|Electronics|High_Load|
|  3|Item_3|Electronics|High_Load|
|  4|Item_4|Electronics|High_Load|
|  5|Item_5|Electronics|High_Load|
+---+------+-----------+---------+
only showing top 5 rows


In [0]:
from pyspark.sql.functions import col, concat_ws

# Number of salt buckets each category is split into (the author sized this
# for a single-node serverless runtime).
SALTS = 8

# Build the salted key "<category>_<id % SALTS>", which splits every category
# into up to SALTS sub-keys.
df_salted = df_udf.withColumn(
    "category_salted",
    concat_ws(
        "_",
        col("category"),
        # concat_ws operates on strings: convert the numeric bucket explicitly
        # instead of relying on implicit type coercion.
        (col("id") % SALTS).cast("string"),
    ),
)


In [0]:
# Shuffle into SALTS partitions keyed on the salted category. Rows sharing a
# salted key end up together; distinct keys may still hash to the same
# partition, so the resulting partition sizes are not guaranteed to be equal.
df_fixed = df_salted.repartition(SALTS, "category_salted")


In [0]:
# Count rows per physical partition after repartitioning on the salted key.
# spark_partition_id() tags every row with the id of the partition holding it.
dist_after = (
    df_fixed.select(spark_partition_id().alias("pid"))
            .groupBy("pid")
            .agg(count("*").alias("rows"))
            # groupBy/agg gives no ordering guarantee; sort by partition id so
            # the table is stable across runs and easy to read.
            .orderBy("pid")
)

print("Partition distribution AFTER salting:")
dist_after.show()


Partition distribution AFTER salting:
+---+-----+
|pid| rows|
+---+-----+
|  0|15832|
|  1|12500|
|  2|14169|
|  3|10833|
|  4|  833|
|  5|11666|
|  6|32500|
|  7| 1667|
+---+-----+

