In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

In [None]:
# Reuse the active Spark session if there is one; otherwise start a new one
# registered under the application name "salting".
spark = (
    SparkSession.builder
    .appName("salting")
    .getOrCreate()
)

# Adaptive query execution is switched off and the number of shuffle
# partitions is pinned to 3.
# NOTE(review): presumably this keeps the per-partition counts in the later
# cells easy to read and stops Spark from re-balancing the skew at runtime --
# confirm the intent.
spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.shuffle.partitions", "3")

# Uniform dataset

In [None]:
# One million consecutive integers (0 .. 999,999) in a single IntegerType
# column; later cells join on that column under the name "value".
df_uniform = spark.createDataFrame(list(range(1_000_000)), schema=IntegerType())
df_uniform.show(n=3, truncate=False)

In [None]:
# Number of rows held by each partition of df_uniform, listed by partition id.
(
    df_uniform
    .select(F.spark_partition_id().alias("partition"))
    .groupBy(F.col("partition"))
    .count()
    .orderBy(F.col("partition"))
    .show()
)

# Skewed dataset  

In [None]:
# One single-partition frame per distinct key. Key 0 carries 999,990 of the
# 1,000,020 rows; keys 1, 2 and 3 carry 15, 10 and 5 rows respectively.
df0 = spark.createDataFrame([0] * 999_990, schema=IntegerType()).repartition(1)
df1 = spark.createDataFrame([1] * 15, schema=IntegerType()).repartition(1)
df2 = spark.createDataFrame([2] * 10, schema=IntegerType()).repartition(1)
df3 = spark.createDataFrame([3] * 5, schema=IntegerType()).repartition(1)

# Stack the four frames into the skewed dataset.
df_skew = (
    df0
    .union(df1)
    .union(df2)
    .union(df3)
)
df_skew.show(n=3, truncate=False)

In [None]:
# Number of rows held by each partition of df_skew, listed by partition id.
(
    df_skew
    .select(F.spark_partition_id().alias("partition"))
    .groupBy(F.col("partition"))
    .count()
    .orderBy(F.col("partition"))
    .show()
)

# Joining the skewed dataset with the uniform one

In [None]:
# Plain (unsalted) inner join of the skewed frame with the uniform frame on
# the shared "value" column.
df_joined_c1 = df_skew.join(df_uniform, on="value", how="inner")

In [None]:
# Number of rows held by each partition of the unsalted join result.
# The rows are sorted by partition id, as in the other per-partition
# summaries of this notebook; without the sort the display order of the
# groups is not fixed.
(
    df_joined_c1
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show()
)

# Simulating uniform distribution through salting

In [None]:
# Number of distinct salt values: taken from the configured number of shuffle
# partitions (the setting is stored as a string, hence the int conversion).
SALT_NUMBER = int(
    spark.conf.get("spark.sql.shuffle.partitions")
)
SALT_NUMBER

In [None]:
# Tag every row of the skewed frame with a pseudo-random integer salt in
# [0, SALT_NUMBER): rand() is scaled by SALT_NUMBER and truncated by the int
# cast.
# The generator is seeded so the salts do not change on every evaluation of
# the frame; Spark documents rand() as non-deterministic in the general case,
# so repeatability still assumes the same partitioning of the input.
SALT_SEED = 42
df_skew = df_skew.withColumn(
    "salt",
    (F.rand(seed=SALT_SEED) * SALT_NUMBER).cast("int"),
)
df_skew.show()

In [None]:
# Replicate every row of the uniform frame once per possible salt value
# (0 .. SALT_NUMBER - 1): "salt_values" holds the full list of salts and
# "salt" is that list exploded into one row per element.
# NOTE: df_uniform is overwritten with its own exploded copy, so running this
# cell again without first re-creating df_uniform explodes the already
# replicated rows a second time.
df_uniform = (
    df_uniform
    .withColumn(
        "salt_values",
        F.array(*(F.lit(salt) for salt in range(SALT_NUMBER))),
    )
    .withColumn("salt", F.explode("salt_values"))
)
df_uniform.show()

In [None]:
# Salted inner join: rows match only when both the original key ("value") and
# the salt agree.
df_joined = df_skew.join(df_uniform, on=["value", "salt"], how="inner")

In [None]:
# Number of rows per (value, partition) pair in the salted join result,
# listed by value and then by partition id.
(
    df_joined
    .select("value", F.spark_partition_id().alias("partition"))
    .groupBy(F.col("value"), F.col("partition"))
    .count()
    .orderBy(F.col("value"), F.col("partition"))
    .show()
)

# Salting in aggregations

In [None]:
# Two-stage count over the salted skewed frame:
#   1. count the rows of every (value, salt) bucket;
#   2. add the bucket counts up per value to obtain the final "count".
(
    df_skew
    .groupBy(F.col("value"), F.col("salt"))
    .agg(F.count(F.col("value")).alias("partial_count"))
    .groupBy(F.col("value"))
    .agg(F.sum(F.col("partial_count")).alias("count"))
    .show()
)

In [None]:
# Stop the Spark session held in `spark`; cells that use it cannot be re-run
# afterwards until the session cell is executed again.
spark.stop()