### Salting: The basic principle is to bucket the data in a group using a column with a random number -- the salt
### This random key ensures that records with the same key are grouped together and spread evenly across partitions, preventing data skew and ensuring that each partition has a balanced workload.

In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder \
        .appName("SaltingExample1") \
        .master("local[*]") \
        .config("spark.driver.bindAddress", "127.0.0.1") \
        .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/01 20:29:04 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
data = [("key1", "value1"),
        ("key2", "value2"),
        ("key3", "value3"),
        ("key4", "value4"),
        ("key5", "value5"),
        ("key6", "value6"),
        ("key7", "value7"),
        ("key8", "value8"),
        ("key9", "value9"),
        ("key10", "value10")]
rdd = spark.sparkContext.parallelize(data)
num_partitions = 5

In [4]:
def salted_key(key):
    return key + "_" + str(hash(key)%num_partitions)

In [5]:
salted_rdd = rdd.map(lambda x: (salted_key(x[0]), x[1]))

In [6]:
salted_rdd.collect()

[('key1_1', 'value1'),
 ('key2_0', 'value2'),
 ('key3_4', 'value3'),
 ('key4_1', 'value4'),
 ('key5_0', 'value5'),
 ('key6_1', 'value6'),
 ('key7_0', 'value7'),
 ('key8_1', 'value8'),
 ('key9_4', 'value9'),
 ('key10_4', 'value10')]

In [7]:
def show_partitions(rdd, number_of_partitions):
    partitions = rdd.partitionBy(number_of_partitions)
    counter = 1
    for partition in partitions.glom().collect():
        print(f"Partition{counter}:")
        counter+=1
        for record in partition:
            print(record)

In [8]:
print("=======Before salting=======")
show_partitions(rdd,5)
print("=======After salting=======")
show_partitions(salted_rdd,5)

Partition1:
('key2', 'value2')
('key5', 'value5')
('key7', 'value7')
Partition2:
('key1', 'value1')
('key4', 'value4')
('key6', 'value6')
('key8', 'value8')
Partition3:
Partition4:
Partition5:
('key3', 'value3')
('key9', 'value9')
('key10', 'value10')
Partition1:
Partition2:
('key2_0', 'value2')
('key4_1', 'value4')
('key5_0', 'value5')
Partition3:
('key1_1', 'value1')
('key6_1', 'value6')
Partition4:
('key3_4', 'value3')
('key7_0', 'value7')
('key8_1', 'value8')
('key9_4', 'value9')
Partition5:
('key10_4', 'value10')


In [9]:
spark.stop()

In [10]:
import random
from pyspark.sql import SparkSession
from pyspark.sql.functions import spark_partition_id, count
from pyspark.sql.functions import udf, concat, lit, explode

In [11]:
# generate example healthcare insurace claims dataset of fact and dimension
def generate_fact_table(counter):
    claims_records = []
    drug_code = ["D100","D250", "D300", "D104", "D204","D304"]
    claims_ids = ["carrierXYZ" + str(i) for i in range(1100, 1110)]
    dayssupply_range = [i for i in range(5,100)] 
    for i in range(counter):
        record = [i, random.choice(claims_ids), random.choice(drug_code), random.choice(dayssupply_range)] 
        claims_records.append(record) 
    return claims_records

In [12]:
fact_records = generate_fact_table(200)
dimension_records = [["D100","BrandA"],
                     ["D250","BrandB"],
                     ["D300","BrandC"],
                     ["D104","GenericA"],
                     ["D204","GenericB"],
                     ["D304","GenericC"]] 
dimension_cols = ["drug_code","drug_name"]
fact_cols = ["id", "claim_id","drug_code", "dayssupply"]

In [13]:
spark = SparkSession.builder \
        .appName("saltingExample2") \
        .master("local[*]") \
        .config("spark.driver.bindAddress", "127.0.0.1") \
        .getOrCreate()

In [14]:
fact_df = spark.createDataFrame(data = fact_records, schema = fact_cols) 
dim_df = spark.createDataFrame(data = dimension_records, schema = dimension_cols) 

In [15]:
fact_df.show(10)

+---+--------------+---------+----------+
| id|      claim_id|drug_code|dayssupply|
+---+--------------+---------+----------+
|  0|carrierXYZ1102|     D100|        89|
|  1|carrierXYZ1107|     D300|        66|
|  2|carrierXYZ1101|     D304|         8|
|  3|carrierXYZ1101|     D104|        43|
|  4|carrierXYZ1103|     D300|        10|
|  5|carrierXYZ1102|     D100|        69|
|  6|carrierXYZ1103|     D250|        34|
|  7|carrierXYZ1101|     D100|        34|
|  8|carrierXYZ1108|     D250|        96|
|  9|carrierXYZ1100|     D204|        55|
+---+--------------+---------+----------+
only showing top 10 rows



In [16]:
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.shuffle.partitions", 5)
# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

false
5


In [17]:
joined_df = fact_df.join(dim_df, on="drug_code", how="leftouter")

In [18]:
def rand():
    return random.randint(0,4) 
udfrand = udf(rand)

In [19]:
salted_fact_df = fact_df.withColumn("skew_key", lit(udfrand()))
salted_dim_df = dim_df.withColumn("skew_key", explode(lit([0,1,2,3,4])))
salted_join_df = salted_fact_df.join(salted_dim_df, on=["drug_code","skew_key"], how="leftouter")

In [20]:
def count_in_each_partition(dataFrame):
    partition_df = dataFrame.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(count("*")).orderBy("partition_id", ascending=True)
    partition_df.show()

In [21]:
count_in_each_partition(joined_df)
count_in_each_partition(salted_join_df)

+------------+--------+
|partition_id|count(1)|
+------------+--------+
|           0|      24|
|           1|      34|
|           2|      37|
|           3|      34|
|           4|      71|
+------------+--------+

+------------+--------+
|partition_id|count(1)|
+------------+--------+
|           0|      71|
|           1|      32|
|           2|      36|
|           3|      39|
|           4|      22|
+------------+--------+

