In [1]:
# Spark Session
from pyspark.sql import SparkSession

# Configure the session: application name plus a local master ("local[*]").
session_builder = SparkSession.builder.appName("Understand Plans and DAG").master("local[*]")

# Reuse an already-running session if one exists, otherwise start a new one.
spark = session_builder.getOrCreate()

# Bare expression so the notebook renders the session summary.
spark

In [2]:
# Disable AQE and broadcast joins so the physical plans shown below are not
# rewritten at runtime and the join is not planned as a broadcast join.

plan_settings = {
    "spark.sql.adaptive.enabled": False,
    "spark.sql.adaptive.coalescePartitions.enabled": False,
    "spark.sql.autoBroadcastJoinThreshold": -1,  # -1 turns auto-broadcast off
}

# Apply the settings in the order listed above.
for conf_key, conf_value in plan_settings.items():
    spark.conf.set(conf_key, conf_value)

In [3]:
# Check default parallelism (bare expression so the notebook displays the value)

spark.sparkContext.defaultParallelism

8

In [4]:
# Create two DataFrames of evenly spaced integers (single "id" column each)

df_1 = spark.range(start=4, end=200, step=2)  # 4, 6, 8, ...
df_2 = spark.range(start=2, end=200, step=4)  # 2, 6, 10, ...

In [6]:
# Partition count of df_2's underlying RDD before any explicit repartitioning
df_2.rdd.getNumPartitions()

8

In [7]:
# Re-partition data: give each DataFrame an explicit, different partition count

df_3 = df_1.repartition(numPartitions=5)
df_4 = df_2.repartition(numPartitions=7)

In [9]:
# Confirm the partition count of df_4 after repartition(7)
df_4.rdd.getNumPartitions()

7

In [10]:
# Join the two repartitioned DataFrames on their shared "id" column.
# "inner" is the join type used when `how` is omitted; it is spelled out here.

df_joined = df_3.join(other=df_4, on="id", how="inner")

In [11]:
# Collapse the joined ids into a single-row total

total_sum_expr = "sum(id) as total_sum"
df_sum = df_joined.selectExpr(total_sum_expr)

In [12]:
# View data: this action triggers execution of the join and aggregation
df_sum.show(n=20, truncate=True)

+---------+
|total_sum|
+---------+
|     4998|
+---------+



In [13]:
# Print the physical plan only (no parsed/analyzed/optimized logical plans)

df_sum.explain(extended=False)

== Physical Plan ==
*(6) HashAggregate(keys=[], functions=[sum(id#0L)])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#182]
   +- *(5) HashAggregate(keys=[], functions=[partial_sum(id#0L)])
      +- *(5) Project [id#0L]
         +- *(5) SortMergeJoin [id#0L], [id#2L], Inner
            :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(id#0L, 200), ENSURE_REQUIREMENTS, [id=#166]
            :     +- Exchange RoundRobinPartitioning(5), REPARTITION_BY_NUM, [id=#165]
            :        +- *(1) Range (4, 200, step=2, splits=8)
            +- *(4) Sort [id#2L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#2L, 200), ENSURE_REQUIREMENTS, [id=#173]
                  +- Exchange RoundRobinPartitioning(7), REPARTITION_BY_NUM, [id=#172]
                     +- *(3) Range (2, 200, step=4, splits=8)




In [14]:
# Union the aggregate with df_4 again to see the skipped stages.
# union() matches columns by position, so df_4's "id" values are appended
# under the "total_sum" column of df_sum.

df_union = df_sum.union(other=df_4)

In [15]:
# Display the first 20 rows of the unioned result
df_union.show(n=20, truncate=True)

+---------+
|total_sum|
+---------+
|     4998|
|       14|
|       38|
|       50|
|       74|
|      110|
|      130|
|      154|
|      186|
|       10|
|       30|
|       54|
|       98|
|      118|
|      138|
|      158|
|      178|
|        6|
|       42|
|       70|
+---------+
only showing top 20 rows



In [16]:
# Print the physical plan of the union (physical plan only)

df_union.explain(extended=False)

== Physical Plan ==
Union
:- *(6) HashAggregate(keys=[], functions=[sum(id#0L)])
:  +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#420]
:     +- *(5) HashAggregate(keys=[], functions=[partial_sum(id#0L)])
:        +- *(5) Project [id#0L]
:           +- *(5) SortMergeJoin [id#0L], [id#2L], Inner
:              :- *(2) Sort [id#0L ASC NULLS FIRST], false, 0
:              :  +- Exchange hashpartitioning(id#0L, 200), ENSURE_REQUIREMENTS, [id=#404]
:              :     +- Exchange RoundRobinPartitioning(5), REPARTITION_BY_NUM, [id=#403]
:              :        +- *(1) Range (4, 200, step=2, splits=8)
:              +- *(4) Sort [id#2L ASC NULLS FIRST], false, 0
:                 +- Exchange hashpartitioning(id#2L, 200), ENSURE_REQUIREMENTS, [id=#411]
:                    +- Exchange RoundRobinPartitioning(7), REPARTITION_BY_NUM, [id=#410]
:                       +- *(3) Range (2, 200, step=4, splits=8)
+- ReusedExchange [id#20L], Exchange RoundRobinPartitioning(7), REPARTITION_BY_N

In [17]:
# DataFrame to RDD: access df_1's underlying RDD
# (bare expression so the notebook displays the RDD's repr)

df_1.rdd

MapPartitionsRDD[5] at javaToPython at NativeMethodAccessorImpl.java:0