A `skewed dataset` is a dataset whose values are unevenly distributed (a class imbalance). Skew leads to slow or failing Spark jobs, which often end in an `OOM` (out of memory) error.

When performing a `join` on a `skewed dataset`, the imbalance is usually in the `key`(s) on which the join is performed. As a result, the majority of the data falls onto a single partition, which takes longer to complete than the other partitions.

Some hints that a dataset is skewed:
1. The `key`(s) consist mainly of `null` values, which all fall onto a single partition.
2. A small subset of values for the `key`(s) makes up a high percentage of all keys, and the rows for each such value fall onto a single partition.
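
To make these hints concrete, here is a plain-Python sketch (not Spark code) of how hash partitioning spreads rows across partitions. The function name `partition_sizes`, the 8 partitions, and the choice to send `None` keys to bucket 0 are assumptions for illustration only. Spark uses its own hash function, but the effect is the same: rows with equal keys, including null keys, always land on the same partition.

```python
from collections import Counter


def partition_sizes(keys, num_partitions):
    """Count how many rows land in each bucket under simple hash partitioning."""
    sizes = Counter()
    for key in keys:
        # Assumption: all null keys go to one fixed bucket (0 is an arbitrary choice).
        bucket = 0 if key is None else hash(key) % num_partitions
        sizes[bucket] += 1
    return sizes


# 90 rows with a null key and 10 rows with distinct integer keys.
keys = [None] * 90 + list(range(1, 11))
sizes = partition_sizes(keys, num_partitions=8)
largest = max(sizes.values())
print(f"largest partition holds {largest} of {len(keys)} rows")
```

One bucket ends up with almost every row while the others hold one or two each, which is the pattern that makes a single task run far longer than the rest.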

We go through both of these cases and see how to combat them.

### Library Imports

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

### Template

In [2]:
spark = (
    SparkSession.builder
    .master("local")
    .appName("Exploring Joins")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)

sc = spark.sparkContext

### Situation 1: Null Keys

Initial Datasets

In [3]:
customers = spark.createDataFrame([
    (1, None), 
    (2, None), 
    (3, 1),
], ["id", "card_id"])

customers.show()

+---+-------+
| id|card_id|
+---+-------+
|  1|   null|
|  2|   null|
|  3|      1|
+---+-------+



In [4]:
cards = spark.createDataFrame([
    (1, "john", "doe", 21), 
    (2, "rick", "roll", 10), 
    (3, "bob", "brown", 2)
], ["card_id", "first_name", "last_name", "age"])

cards.show()

+-------+----------+---------+---+
|card_id|first_name|last_name|age|
+-------+----------+---------+---+
|      1|      john|      doe| 21|
|      2|      rick|     roll| 10|
|      3|       bob|    brown|  2|
+-------+----------+---------+---+



### Option #1: Join Regularly

In [5]:
df = customers.join(cards, "card_id", "left")
df.show()

+-------+---+----------+---------+----+
|card_id| id|first_name|last_name| age|
+-------+---+----------+---------+----+
|   null|  1|      null|     null|null|
|   null|  2|      null|     null|null|
|      1|  3|      john|      doe|  21|
+-------+---+----------+---------+----+



In [6]:
df.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [card_id#1L, id#0L, first_name#14, last_name#15, age#16L]
   +- SortMergeJoin [card_id#1L], [card_id#13L], LeftOuter
      :- Sort [card_id#1L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#1L, 200), ENSURE_REQUIREMENTS, [plan_id=164]
      :     +- Scan ExistingRDD[id#0L,card_id#1L]
      +- Sort [card_id#13L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(card_id#13L, 200), ENSURE_REQUIREMENTS, [plan_id=165]
            +- Filter isnotnull(card_id#13L)
               +- Scan ExistingRDD[card_id#13L,first_name#14,last_name#15,age#16L]




In [7]:
df = customers.join(cards, "card_id")
df.show()

+-------+---+----------+---------+---+
|card_id| id|first_name|last_name|age|
+-------+---+----------+---------+---+
|      1|  3|      john|      doe| 21|
+-------+---+----------+---------+---+



In [8]:
df.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [card_id#1L, id#0L, first_name#14, last_name#15, age#16L]
   +- SortMergeJoin [card_id#1L], [card_id#13L], Inner
      :- Sort [card_id#1L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(card_id#1L, 200), ENSURE_REQUIREMENTS, [plan_id=305]
      :     +- Filter isnotnull(card_id#1L)
      :        +- Scan ExistingRDD[id#0L,card_id#1L]
      +- Sort [card_id#13L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(card_id#13L, 200), ENSURE_REQUIREMENTS, [plan_id=306]
            +- Filter isnotnull(card_id#13L)
               +- Scan ExistingRDD[card_id#13L,first_name#14,last_name#15,age#16L]




**What Happened**:
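* Rows that could never match (the ones with a `null` key) were brought to the `left join`.

* For a `left join`, they get `null` values for the right-side columns, so what's the point of bringing them in?
* For an `inner join`, these rows get dropped, so again there is no point in bringing them in. The second plan shows that Spark already adds a `Filter isnotnull` on both sides in this case.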

**Results**:
* We brought more rows to the join than we had to. These rows normally get put onto a single partition.
* If the data is large enough and the percentage of keys that are null is high, the program could run out of memory (OOM).

### Option #2: Filter Null Keys First, then Join, then Union

In [9]:
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Filter out the null rows.
        2. Create the columns you would get from the join.
        3. Join the tables.
        4. Union the null rows to joined table.
    """
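    # Rows with a null key never match, so build their join output by hand.
    lr_null_rows = left.where(F.col(key).isNull())
    for f in right.schema.fields:
        # Add each right-side column as a typed null (the key column is already null).
        lr_null_rows = lr_null_rows.withColumn(f.name, F.lit(None).cast(f.dataType))
    # Only rows with a non-null key go through the shuffle and the join.
    left_non_null_rows = left.where(F.col(key).isNotNull())
    lr_non_null_rows = left_non_null_rows.join(right, key, "left")
    # union matches columns by position, so align the column order first.
    return lr_null_rows.union(lr_non_null_rows.select(lr_null_rows.columns))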
    

df = null_skew_helper(customers, cards, "card_id")
df.show()

+---+-------+----------+---------+----+
| id|card_id|first_name|last_name| age|
+---+-------+----------+---------+----+
|  1|   null|      null|     null|null|
|  2|   null|      null|     null|null|
|  3|      1|      john|      doe|  21|
+---+-------+----------+---------+----+



In [10]:
df.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Union
   :- Project [id#0L, null AS card_id#90L, null AS first_name#93, null AS last_name#97, null AS age#102L]
   :  +- Filter isnull(card_id#1L)
   :     +- Scan ExistingRDD[id#0L,card_id#1L]
   +- Project [id#118L, card_id#119L, first_name#14, last_name#15, age#16L]
      +- SortMergeJoin [card_id#119L], [card_id#13L], LeftOuter
         :- Sort [card_id#119L ASC NULLS FIRST], false, 0
         :  +- Exchange hashpartitioning(card_id#119L, 200), ENSURE_REQUIREMENTS, [plan_id=581]
         :     +- Filter isnotnull(card_id#119L)
         :        +- Scan ExistingRDD[id#118L,card_id#119L]
         +- Sort [card_id#13L ASC NULLS FIRST], false, 0
            +- Exchange hashpartitioning(card_id#13L, 200), ENSURE_REQUIREMENTS, [plan_id=582]
               +- Filter isnotnull(card_id#13L)
                  +- Scan ExistingRDD[card_id#13L,first_name#14,last_name#15,age#16L]




**What Happened**:
* We separated the data into 2 sets:
  * one where the `key`s are not `null`.
  * one where the `key`s are `null`.
* We perform the join on the set where the keys are not null, then union it back with the set where the keys are null. (This step is not necessary when doing an inner join).

**Results**:
* We brought less data to the join.
* We read the data twice; more time was spent on reading data from disk.

### Option #3: Cache the Table, Filter Null Keys First, then Join, then Union

**Helper Function**

In [11]:
def null_skew_helper(left, right, key):
    """
    Steps:
        1. Cache table.
        2. Filter out the null rows.
        3. Create the columns you would get from the join.
        4. Join the tables.
        5. Union the null rows to joined table.
    """
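    # Cache so the two filters below reuse one materialized copy of `left`.
    left = left.cache()
    df1 = left.where(F.col(key).isNull())
    for f in right.schema.fields:
        df1 = df1.withColumn(f.name, F.lit(None).cast(f.dataType))
    df2 = left.where(F.col(key).isNotNull())
    df2 = df2.join(right, key, "left")
    # union matches columns by position, so align the column order first.
    return df1.union(df2.select(df1.columns))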

In [12]:
df = null_skew_helper(customers, cards, "card_id")
df.show()

+---+-------+----------+---------+----+
| id|card_id|first_name|last_name| age|
+---+-------+----------+---------+----+
|  1|   null|      null|     null|null|
|  2|   null|      null|     null|null|
|  3|      1|      john|      doe|  21|
+---+-------+----------+---------+----+



In [13]:
df.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Union
   :- Project [id#0L, null AS card_id#161L, null AS first_name#164, null AS last_name#168, null AS age#173L]
   :  +- Filter isnull(card_id#1L)
   :     +- InMemoryTableScan [card_id#1L, id#0L], [isnull(card_id#1L)]
   :           +- InMemoryRelation [id#0L, card_id#1L], StorageLevel(disk, memory, deserialized, 1 replicas)
   :                 +- *(1) Scan ExistingRDD[id#0L,card_id#1L]
   +- Project [id#189L, card_id#190L, first_name#14, last_name#15, age#16L]
      +- SortMergeJoin [card_id#190L], [card_id#13L], LeftOuter
         :- Sort [card_id#190L ASC NULLS FIRST], false, 0
         :  +- Exchange hashpartitioning(card_id#190L, 200), ENSURE_REQUIREMENTS, [plan_id=849]
         :     +- Filter isnotnull(card_id#190L)
         :        +- InMemoryTableScan [id#189L, card_id#190L], [isnotnull(card_id#190L)]
         :              +- InMemoryRelation [id#189L, card_id#190L], StorageLevel(disk, memory, deserialized, 1 r

**What Happened**:
* Similar to option #2, but we did an `InMemoryTableScan` instead of two reads of the data.

**Results**:
* We brought less data to the join.
* We did one fewer read, but we used more memory.

### Summary

All to say:
* It's definitely better to bring less data to a join, so performing a filter for `null keys` before the join is definitely suggested.
* For `left join`s:
    * Doing the union results in either an extra read of the data or extra memory usage.
    * Decide which you can afford: the extra read, or the memory needed to `cache` the table before the `filter`.

Always check the spread of values for the `join key` to detect any skew and any pre-filters that can be applied.