In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, col

spark = SparkSession.builder.getOrCreate()

# Left table: 100 rows, of which the first 90 (90%) share the hot key 'A'
# and the remaining 10 each get a unique key 'K_<i>' -- i.e. a skewed key
# distribution. (Generating only 10 rows would make every row satisfy
# `i < 90`, leaving no non-skewed keys at all.)
left_data = [
    (f"id_{i}", "A") if i < 90 else (f"id_{i}", f"K_{i}")
    for i in range(100)
]
left_df = spark.createDataFrame(left_data, ["id", "key"])
left_df.display()

# Right table: a single row for the hot key 'A'. The 'K_<i>' keys have no
# counterpart here.
right_data = [("A", "hot_value")]
right_df = spark.createDataFrame(right_data, ["key", "value"])
right_df.display()

In [0]:
from pyspark.sql.functions import when, concat, floor, rand, col, lit, explode, array

# Salting parameters. Both sides of the join are derived from these, so the
# salted keys produced on the left always have a partner on the right.
SKEWED_KEY = "A"   # the hot key whose rows are spread over several salts
NUM_SALTS = 5      # N-way salt: suffixes _0 .. _(NUM_SALTS - 1)
SALT_SEED = 42     # seed for rand() so the salt assignment is repeatable
                   # across reruns (for a given partitioning of left_df)

# Add salt only to the skewed key; every other key is passed through as-is.
salt_suffix = floor(rand(SALT_SEED) * NUM_SALTS).cast("string")
left_salted = left_df.withColumn(
    "salted_key",
    when(col("key") == SKEWED_KEY, concat(col("key"), lit("_"), salt_suffix))
    .otherwise(col("key")),
)
display(left_salted)

# Explode the right side: the hot-key row is replicated once per salt value
# so that each salted left row finds its match.
salted_key_values = [lit(f"{SKEWED_KEY}_{i}") for i in range(NUM_SALTS)]
right_exploded = right_df.filter(col("key") == SKEWED_KEY).withColumn(
    "salted_key",
    explode(array(*salted_key_values)),
)
# Non-skewed keys are not salted; their salted_key is just the key itself.
right_non_exploded = right_df.filter(col("key") != SKEWED_KEY).withColumn(
    "salted_key",
    col("key"),
)

# Combine by column name rather than position so the union stays correct
# even if one branch later gains or reorders columns.
right_final = right_exploded.unionByName(right_non_exploded)
display(right_final)

In [0]:
# Inner join on the salted key. The right-hand `key` column is dropped first:
# both inputs carry a `key` column, and joining only on `salted_key` would
# otherwise leave two columns named `key` in the result, making any later
# reference to `key` ambiguous.
joined_df = left_salted.join(
    right_final.drop("key"),
    on="salted_key",
    how="inner",
)
joined_df.display()