In [48]:
from pyspark.sql import SparkSession

# Build (or reuse) the Hive-enabled session for this demo on YARN with a
# fixed executor footprint: 2 executors, 512MB and 4 cores each, and
# dynamic allocation switched off.
spark = (
    SparkSession.builder
    .config("spark.ui.port", "0")
    .enableHiveSupport()
    .appName("Spark DataSkew -Salting")
    .master("yarn")
    .config("spark.executor.instances", "2")
    .config("spark.executor.memory", "512MB")
    .config("spark.executor.cores", "4")
    .config("spark.dynamicAllocation.enabled", "False")
    .getOrCreate()
)

In [106]:
# Stop the Spark session.
# NOTE(review): this cell carries execution count 106, i.e. it was run after
# every other cell in this notebook. Executed in file order it would stop the
# session before the cells below use `spark` -- run it last (or move it to the
# end of the notebook).
spark.stop()

In [49]:
# Display the application id of the running session (handy for locating this
# run in the cluster / Spark UI).
spark.sparkContext.applicationId

'application_1745651200635_11844'

In [50]:
 #Disable AQE and Broadcast join

# Switch off adaptive query execution, its partition coalescing, and automatic
# broadcast joins (threshold -1) for the rest of the session.
for conf_key, conf_value in (
    ("spark.sql.adaptive.enabled", False),
    ("spark.sql.adaptive.coalescePartitions.enabled", False),
    ("spark.sql.autoBroadcastJoinThreshold", -1),
):
    spark.conf.set(conf_key, conf_value)


In [51]:
# Load the skewed employee records with an explicit schema (no inference pass).
schema = (
    "first_name string, last_name string, job_title string, "
    "dob string, email string, phone string, "
    "salary double, department_id int"
)

# The first line of the CSV is a header row.
emp = spark.read.csv(
    "Datasets/employee_records_skew.csv",
    schema=schema,
    header=True,
)

In [28]:
# Preview the first rows of the employee data to sanity-check the schema.
emp.show()

+-----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+
| first_name| last_name|           job_title|       dob|               email|               phone|  salary|department_id|
+-----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+
|      Jacob|     Stark|         Fine artist|1976-04-25|jasonortiz@exampl...|  224-695-9516x02171|358889.0|            1|
|    Marissa|     Crane|Intelligence analyst|2000-06-24|johnsontroy@examp...|        277-928-0029|786608.0|            3|
|     Andrea|     Davis|     Physiotherapist|1999-06-17| ihowell@example.org|          9503082950|428991.0|            3|
|       John|     Tapia|Lecturer, further...|2001-09-23|russobarbara@exam...|    001-679-487-9525|241574.0|            9|
|      Colin|    Holmes|     Psychotherapist|1965-06-29|fsimmons@example.org|          3232202899|320260.0|            4|
|       Eric|      Beck|

In [52]:
# Load the department lookup data with an explicit schema.
dept_schema = (
    "department_id int, department_name string, description string, "
    "city string, state string, country string"
)

# The first line of the CSV is a header row.
dept = spark.read.load(
    "Datasets/department_data.csv",
    format="csv",
    schema=dept_schema,
    header=True,
)


In [53]:
# Alias only: no rows are added here, so emp_df and emp refer to the same
# data (the "increase dataset size" step this cell was named for is not
# performed).
emp_df = emp

In [31]:
# Row count of the employee data used for the join (1,200,365 in the recorded run).
emp_df.count()

1200365

In [54]:
# Join employees to their department on the raw (unsalted) department_id.
# Left outer: every employee row is kept even without a matching department.

df_joined = emp_df.join(
    dept,
    emp_df.department_id == dept.department_id,
    "left_outer",
)


In [55]:
# Action: run the full join through the "noop" sink so the job executes
# without producing an output dataset.
(
    df_joined.write
    .format("noop")
    .mode("overwrite")
    .save()
)

In [56]:
# Print the physical plan of the unsalted join. In the recorded run it is a
# SortMergeJoin with both sides hash-partitioned on department_id into 200
# partitions.
df_joined.explain()


== Physical Plan ==
SortMergeJoin [department_id#839], [department_id#848], LeftOuter
:- *(1) Sort [department_id#839 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(department_id#839, 200), ENSURE_REQUIREMENTS, [id=#836]
:     +- FileScan csv [first_name#832,last_name#833,job_title#834,dob#835,email#836,phone#837,salary#838,department_id#839] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex[hdfs://m01.itversity.com:9000/user/itv018960/Datasets/employee_records_skew.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<first_name:string,last_name:string,job_title:string,dob:string,email:string,phone:string,s...
+- *(3) Sort [department_id#848 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(department_id#848, 200), ENSURE_REQUIREMENTS, [id=#847]
      +- *(2) Filter isnotnull(department_id#848)
         +- FileScan csv [department_id#848,department_name#849,description#850,city#851,state#852,country#853] Batched: false, DataFilt

In [58]:
# Check how the joined rows are spread across partitions
from pyspark.sql.functions import spark_partition_id, count, lit, desc

# Tag every joined row with the id of the partition it sits in ...
joined_with_partition = df_joined.withColumn("partition_num", spark_partition_id())

# ... then count the rows per partition.
part_df = (
    joined_with_partition
    .groupBy("partition_num")
    .agg(count(lit(1)).alias("count"))
)

part_df.show()

+-------------+------+
|partition_num| count|
+-------------+------+
|          103|100417|
|          122| 99780|
|           43| 99451|
|          107| 99805|
|           49| 99706|
|           51|100248|
|          102|100214|
|           66|200420|
|          174|200310|
|           89|100014|
+-------------+------+



In [59]:
# Verify the skew at the source: employee rows per department_id
from pyspark.sql.functions import count, lit, desc, col

rows_per_department = emp_df.groupBy("department_id").agg(count(lit(1)))
rows_per_department.show()


+-------------+--------+
|department_id|count(1)|
+-------------+--------+
|            1|   99451|
|            6|   99706|
|            3|  100248|
|            5|  200420|
|            9|  100014|
|            4|  100214|
|            8|  100417|
|            7|   99805|
|           10|   99780|
|            2|  200310|
+-------------+--------+



In [93]:
# Use 48 shuffle partitions for the steps that follow, then read the setting
# back to confirm it took effect.
shuffle_partitions_key = "spark.sql.shuffle.partitions"
spark.conf.set(shuffle_partitions_key, 48)
spark.conf.get(shuffle_partitions_key)

'48'

In [94]:
#salting
import random
from pyspark.sql.functions import udf

# Number of distinct salt values per department key. The UDF and salt_df below
# must agree on this: every salt the UDF can emit needs a row in salt_df,
# otherwise the salted employee key has no partner on the department side.
SALT_BUCKETS = 48

# UDF to return a random number every time and add to Employee as salt
@udf
def salt_udf():
    """Return a random salt in the half-open range [0, SALT_BUCKETS)."""
    # randrange excludes the upper bound. random.randint(0, 48) is inclusive
    # and would also emit 48, which has no row in salt_df, so those employees
    # would come out of the left outer join with NULL department columns.
    return random.randrange(SALT_BUCKETS)

# The UDF returns a different value on every call; declare it
# non-deterministic so Spark does not treat repeated invocations as
# interchangeable when optimizing the plan.
salt_udf = salt_udf.asNondeterministic()

# Salt Data Frame to add to department: one row per salt value 0..SALT_BUCKETS-1
salt_df = spark.range(0, SALT_BUCKETS)
salt_df.show()


+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
|  5|
|  6|
|  7|
|  8|
|  9|
| 10|
| 11|
| 12|
| 13|
| 14|
| 15|
| 16|
| 17|
| 18|
| 19|
+---+
only showing top 20 rows



In [95]:
# Add the salted join key to the employee data
from pyspark.sql.functions import lit, concat

# Key layout: "<department_id>_<salt>", with the salt drawn by salt_udf().
salted_key = concat("department_id", lit("_"), salt_udf())
salted_emp = emp.withColumn("salted_dept_id", salted_key)

salted_emp.show()


+-----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+--------------+
| first_name| last_name|           job_title|       dob|               email|               phone|  salary|department_id|salted_dept_id|
+-----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+--------------+
|      Jacob|     Stark|         Fine artist|1976-04-25|jasonortiz@exampl...|  224-695-9516x02171|358889.0|            1|          1_30|
|    Marissa|     Crane|Intelligence analyst|2000-06-24|johnsontroy@examp...|        277-928-0029|786608.0|            3|          3_46|
|     Andrea|     Davis|     Physiotherapist|1999-06-17| ihowell@example.org|          9503082950|428991.0|            3|          3_31|
|       John|     Tapia|Lecturer, further...|2001-09-23|russobarbara@exam...|    001-679-487-9525|241574.0|            9|          9_15|
|      Colin|    Holmes|     Psychotherap

In [97]:
# Baseline row count before salting (1,200,365 in the recorded run).
emp.count()

1200365

In [98]:
# Salting only adds a column, so this should equal emp.count()
# (1,200,365 in the recorded run).
salted_emp.count()

1200365

In [99]:
# Build the salted department side

# Pair every department row with every salt id ...
dept_with_salts = dept.join(salt_df, how="cross")

# ... and derive the matching "<department_id>_<id>" key.
salted_dept = dept_with_salts.withColumn(
    "salted_dept_id",
    concat("department_id", lit("_"), "id"),
)

# Spot-check the replicated rows for one department.
salted_dept.where("department_id = 5").show()

+-------------+---------------+--------------------+---------+-----+-------+---+--------------+
|department_id|department_name|         description|     city|state|country| id|salted_dept_id|
+-------------+---------------+--------------------+---------+-----+-------+---+--------------+
|            5|     Hardin Inc|Re-contextualized...|Hayestown|   WA|   Fiji|  0|           5_0|
|            5|     Hardin Inc|Re-contextualized...|Hayestown|   WA|   Fiji|  1|           5_1|
|            5|     Hardin Inc|Re-contextualized...|Hayestown|   WA|   Fiji|  2|           5_2|
|            5|     Hardin Inc|Re-contextualized...|Hayestown|   WA|   Fiji|  3|           5_3|
|            5|     Hardin Inc|Re-contextualized...|Hayestown|   WA|   Fiji|  4|           5_4|
|            5|     Hardin Inc|Re-contextualized...|Hayestown|   WA|   Fiji|  5|           5_5|
|            5|     Hardin Inc|Re-contextualized...|Hayestown|   WA|   Fiji|  6|           5_6|
|            5|     Hardin Inc|Re-contex

In [100]:
# Number of departments before replication (10 in the recorded run).
dept.count()

10

In [101]:
# Each department is replicated once per salt id: 10 departments x 48 salts
# = 480 rows in the recorded run.
salted_dept.count()

480

In [102]:
# Join on the salted key instead of the raw department_id.
# Left outer: every salted employee row is kept.

salted_joined_df = salted_emp.join(
    salted_dept,
    salted_emp.salted_dept_id == salted_dept.salted_dept_id,
    "left_outer",
)

In [103]:
# Action: run the salted join through the "noop" sink so the job executes
# without producing an output dataset.

(
    salted_joined_df.write
    .format("noop")
    .mode("overwrite")
    .save()
)

In [105]:
# Check the row count of each partition after salting
from pyspark.sql.functions import spark_partition_id, count

# Tag every salted-join row with its partition id ...
salted_with_partition = salted_joined_df.withColumn("partition_num", spark_partition_id())

# ... then count the rows per partition.
part_df = (
    salted_with_partition
    .groupBy("partition_num")
    .agg(count(lit(1)).alias("count"))
)

part_df.show()


+-------------+-----+
|partition_num|count|
+-------------+-----+
|           18|36727|
|           38|32467|
|           31|24775|
|            3|24379|
|           27|20414|
|           37|24421|
|            2|14304|
|           22|20493|
|            9|36516|
|           17|22497|
|           11|38636|
|           25|28739|
|           19|20210|
|           30|34866|
|           40|22485|
|           42|16314|
|           28|14317|
|           36|29006|
|            0|26354|
|            8|28690|
+-------------+-----+
only showing top 20 rows



Conclusion:
1. Understand data skew.
2. Understand how data spills to memory and disk.
3. Learn how to prepare a salted key.
4. Learn how to reduce spills using the salting approach.
