# Spark Salting technique

In [1]:
# Create Spark Session

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Salting technique") \
    .master("local[*]") \
    .getOrCreate()

spark

In [2]:
# Lets create the example dataset of fact and dimesion we would use for demonstration
# Python program to generate random Fact table data
# [1, ,"ORD1001", "D102", 56]
import random


def generate_fact_data(counter=100):
    fact_records = []
    dim_keys = ["D100", "D101", "D102", "D103", "D104"]
    order_ids = ["ORD" + str(i) for i in range(1001, 1010)]
    qty_range = [i for i in range(10, 120)]
    for i in range(counter):
        _record = [i, random.choice(order_ids), random.choice(dim_keys), random.choice(qty_range)]
        fact_records.append(_record)
    return fact_records

# We will generate 200 records with random data in Fact to create skewness
fact_records = generate_fact_data(200)

dim_records = [
    ["D100", "Product A"],
    ["D101", "Product B"],
    ["D102", "Product C"],
    ["D103", "Product D"],
    ["D104", "Product E"]
]

_fact_cols = ["id", "order_id", "prod_id", "qty"]s
_dim_cols = ["prod_id", "prod_name"]

In [6]:
# Generate Fact Data Frame
fact_df = spark.createDataFrame(data = fact_records, schema=_fact_cols)

fact_df.printSchema()
fact_df.show(10, truncate = False)

root
 |-- id: long (nullable = true)
 |-- order_id: string (nullable = true)
 |-- prod_id: string (nullable = true)
 |-- qty: long (nullable = true)

+---+--------+-------+---+
|id |order_id|prod_id|qty|
+---+--------+-------+---+
|0  |ORD1009 |D101   |50 |
|1  |ORD1007 |D102   |55 |
|2  |ORD1007 |D103   |87 |
|3  |ORD1007 |D103   |23 |
|4  |ORD1007 |D103   |15 |
|5  |ORD1004 |D100   |20 |
|6  |ORD1009 |D103   |57 |
|7  |ORD1002 |D100   |51 |
|8  |ORD1008 |D101   |61 |
|9  |ORD1004 |D101   |74 |
+---+--------+-------+---+
only showing top 10 rows



In [7]:
# Generate Prod Dim Data Frame
dim_df = spark.createDataFrame(data = dim_records, schema=_dim_cols)

dim_df.printSchema()
dim_df.show(10, False)

root
 |-- prod_id: string (nullable = true)
 |-- prod_name: string (nullable = true)

+-------+---------+
|prod_id|prod_name|
+-------+---------+
|D100   |Product A|
|D101   |Product B|
|D102   |Product C|
|D103   |Product D|
|D104   |Product E|
+-------+---------+



In [42]:
# Set Spark parameters - We have to turn off AQL to demonstrate Salting
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.shuffle.partitions", 5)

# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

false
5


In [45]:
# Lets join the fact and dim without salting
joined_df = fact_df.join(dim_df, on="prod_id", how="leftouter")
joined_df.show(10, False)

+-------+---+--------+---+---------+
|prod_id|id |order_id|qty|prod_name|
+-------+---+--------+---+---------+
|D103   |2  |ORD1007 |87 |Product D|
|D103   |3  |ORD1007 |23 |Product D|
|D103   |4  |ORD1007 |15 |Product D|
|D103   |6  |ORD1009 |57 |Product D|
|D103   |10 |ORD1009 |99 |Product D|
|D103   |26 |ORD1008 |66 |Product D|
|D103   |34 |ORD1004 |63 |Product D|
|D103   |53 |ORD1006 |102|Product D|
|D103   |55 |ORD1002 |90 |Product D|
|D103   |59 |ORD1009 |19 |Product D|
+-------+---+--------+---+---------+
only showing top 10 rows



In [46]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count

partition_df = joined_df.withColumn("partition_num", spark_partition_id()).groupBy("partition_num").agg(count("id"))
partition_df.show()

+-------------+---------+
|partition_num|count(id)|
+-------------+---------+
|            4|      119|
|            2|       81|
+-------------+---------+



In [75]:
# Let prepare the salt
import random
from pyspark.sql.functions import udf

# UDF to return a random number every time
def rand(): return random.randint(0, 4) #Since we are distributing the data in 5 partitions
rand_udf = udf(rand)

# Salt Data Frame to add to dimension
salt_df = spark.range(0, 5)
salt_df.show()

+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
+---+



In [91]:
# Salted Fact
from pyspark.sql.functions import lit, expr, concat

salted_fact_df = fact_df.withColumn("salted_prod_id", concat("prod_id",lit("_"), lit(rand_udf())))
salted_fact_df.show(10, False)

+---+--------+-------+---+--------------+
|id |order_id|prod_id|qty|salted_prod_id|
+---+--------+-------+---+--------------+
|0  |ORD1009 |D101   |50 |D101_4        |
|1  |ORD1007 |D102   |55 |D102_4        |
|2  |ORD1007 |D103   |87 |D103_0        |
|3  |ORD1007 |D103   |23 |D103_2        |
|4  |ORD1007 |D103   |15 |D103_4        |
|5  |ORD1004 |D100   |20 |D100_0        |
|6  |ORD1009 |D103   |57 |D103_2        |
|7  |ORD1002 |D100   |51 |D100_3        |
|8  |ORD1008 |D101   |61 |D101_2        |
|9  |ORD1004 |D101   |74 |D101_1        |
+---+--------+-------+---+--------------+
only showing top 10 rows



In [90]:
# Salted DIM
salted_dim_df = dim_df.join(salt_df, how="cross").withColumn("salted_prod_id", concat("prod_id", lit("_"), "id")).drop("id")

salted_dim_df.show()

+-------+---------+--------------+
|prod_id|prod_name|salted_prod_id|
+-------+---------+--------------+
|   D100|Product A|        D100_0|
|   D100|Product A|        D100_1|
|   D100|Product A|        D100_2|
|   D100|Product A|        D100_3|
|   D100|Product A|        D100_4|
|   D101|Product B|        D101_0|
|   D101|Product B|        D101_1|
|   D101|Product B|        D101_2|
|   D101|Product B|        D101_3|
|   D101|Product B|        D101_4|
|   D102|Product C|        D102_0|
|   D102|Product C|        D102_1|
|   D102|Product C|        D102_2|
|   D102|Product C|        D102_3|
|   D102|Product C|        D102_4|
|   D103|Product D|        D103_0|
|   D103|Product D|        D103_1|
|   D103|Product D|        D103_2|
|   D103|Product D|        D103_3|
|   D103|Product D|        D103_4|
+-------+---------+--------------+
only showing top 20 rows



In [78]:
# Lets make the salted join now
salted_joined_df = salted_fact_df.join(salted_dim_df, on="salted_prod_id", how="leftouter")
salted_joined_df.show(10, False)

+--------------+---+--------+-------+---+-------+---------+
|salted_prod_id|id |order_id|prod_id|qty|prod_id|prod_name|
+--------------+---+--------+-------+---+-------+---------+
|D100_0        |126|ORD1006 |D100   |44 |D100   |Product A|
|D100_0        |185|ORD1004 |D100   |20 |D100   |Product A|
|D100_1        |106|ORD1008 |D100   |106|D100   |Product A|
|D100_1        |175|ORD1005 |D100   |54 |D100   |Product A|
|D104_1        |50 |ORD1002 |D104   |101|D104   |Product E|
|D104_1        |81 |ORD1006 |D104   |32 |D104   |Product E|
|D104_1        |177|ORD1002 |D104   |23 |D104   |Product E|
|D104_1        |178|ORD1007 |D104   |11 |D104   |Product E|
|D101_1        |103|ORD1007 |D101   |115|D101   |Product B|
|D103_1        |2  |ORD1007 |D103   |87 |D103   |Product D|
+--------------+---+--------+-------+---+-------+---------+
only showing top 10 rows



In [96]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count

partition_df = salted_joined_df.withColumn("partition_num", spark_partition_id()).groupBy("partition_num") \
    .agg(count(lit(1)).alias("count")).orderBy("partition_num")

partition_df.show()

+-------------+-----+
|partition_num|count|
+-------------+-----+
|            0|   18|
|            1|   64|
|            2|   24|
|            3|   61|
|            4|   33|
+-------------+-----+

