In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import functions as F


Spark = SparkSession.builder.appName("DataSkew").getOrCreate()
sc=Spark.sparkContext

##Check the Spark Version

In [0]:
spark.version

Out[2]: '3.3.0'

## Since the Spark version is 3.3.0, We will disable AQE and AutoBroadcastJoin to test the Salting Technique

In [0]:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold","-1")
spark.conf.set("spark.sql.adaptive.enabled","false")
spark.conf.set("spark.sql.shuffle.partitions","3") #Setting the sql shuffle partiton from default value 200 to 3

In [0]:
display(spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))
display(spark.conf.get("spark.sql.adaptive.enabled"))
display(spark.conf.get("spark.sql.shuffle.partitions"))
display(spark.conf.get("spark.executor.memory")) #Spark Executor Memory is 4GB

'-1''false''3''8278m'

## Create highly skewed DataFrame with 100M rows, where 99% of the rows have the same key value

In [0]:
from pyspark.sql.functions import rand,when
df0=spark.createDataFrame([0] * 99999998,IntegerType()).withColumn("key",when(rand()<0.01,0).otherwise(1)).repartition(1)
df1 = spark.createDataFrame([1], IntegerType()).withColumn("key",when(rand()<0.01,0).otherwise(1)).repartition(1)
df2 = spark.createDataFrame([2], IntegerType()).withColumn("key",when(rand()<0.01,0).otherwise(1)).repartition(1)
df_skew = df0.union(df1).union(df2)
display(df_skew)

value,key
0,1
0,1
0,1
0,1
0,1
0,1
0,1
0,1
0,1
0,1


In [0]:
df_skew.groupBy("key").count().orderBy("key").show()

+---+--------+
|key|   count|
+---+--------+
|  0| 1001899|
|  1|98998101|
+---+--------+



In [0]:
#Check the partition
df_skew.rdd.getNumPartitions()

Out[7]: 3

In [0]:
#Check the data skew using spark_partition_id
df_skewed = df_skew.withColumn("PartitonID",F.spark_partition_id())
df_skewed.groupby("PartitonID").count().show()

+----------+--------+
|PartitonID|   count|
+----------+--------+
|         0|99999998|
|         2|       1|
|         1|       1|
+----------+--------+



In [0]:
df_skewed.printSchema()

root
 |-- value: integer (nullable = true)
 |-- key: integer (nullable = false)
 |-- PartitonID: integer (nullable = false)



In [0]:
#Create a second dataframe - small size

df_small = spark.createDataFrame([i for i in range(10000)],IntegerType())\
            .withColumn("size",when(rand() < 0.1, "small").otherwise("Large")).repartition(3)

display(df_small)

value,size
137,small
809,Large
473,Large
889,Large
303,Large
1020,Large
529,Large
9,Large
468,Large
1018,Large


In [0]:
df_small.groupBy("size").count().orderBy("size").show()

+-----+-----+
| size|count|
+-----+-----+
|Large| 8988|
|small| 1012|
+-----+-----+



In [0]:
df_small.rdd.getNumPartitions()

Out[14]: 3

In [0]:
#Data Evenly Partitioned
df_even = df_small.withColumn("PartitionID",F.spark_partition_id())
df_even.groupby("PartitionID").count().show()

+-----------+-----+
|PartitionID|count|
+-----------+-----+
|          0| 3334|
|          2| 3333|
|          1| 3333|
+-----------+-----+



## Join both the DataFrame

In [0]:
df_join = df_skew.join(df_small,"value","inner").select(df_skew.value,df_skew.key,df_small.size)
display(df_join)



value,key,size
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large


In [0]:
#Data Unevenly Distributed 
df_join.withColumn("PartitionID",F.spark_partition_id()).groupby("PartitionId").count().show()

+-----------+--------+
|PartitionId|   count|
+-----------+--------+
|          0|99999999|
|          1|       1|
+-----------+--------+



In [0]:
df_left = df_skew.withColumn("salt", (F.rand() * spark.conf.get("spark.sql.shuffle.partitions")).cast("int"))

In [0]:
df_right = df_small.withColumn("salt_temp", F.array([F.lit(i) for i in range(int(spark.conf.get("spark.sql.shuffle.partitions")))]))\
                  .withColumn("salt",F.explode(F.col("salt_temp")))

In [0]:
df_joined=df_left.join(df_right,["value","salt"],"inner").select(df_left.value,df_left.key,df_right.size)
display(df_joined)

value,key,size
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large
0,1,Large


In [0]:
#Data Evenly distributed after SALTING
df_joined.withColumn("PartitionID",F.spark_partition_id()).groupby("PartitionId").count().show()

+-----------+--------+
|PartitionId|   count|
+-----------+--------+
|          0|33332520|
|          2|33334326|
|          1|33333154|
+-----------+--------+

