In [43]:
import random
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

In [44]:
spark = (
    SparkSession.builder.appName("Broadcast Example")
    .getOrCreate()
)

# Surprass warnings
spark.sparkContext.setLogLevel("ERROR")

# Disable Spark auto broadcast
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Disable Sort Merge Join
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

# Disabling WholeStageCodegen to see the actual physical operators in the execution plan
spark.conf.set("spark.sql.codegen.wholeStage", "false")

In [45]:
size = 1_000_000
fact_table_data = [
    (i % 1000, round(random.uniform(10.0, 1000.0), 2)) for i in range(1, size)
]
fact_table_df = spark.createDataFrame(fact_table_data, ["id", "value"])

In [46]:
size = 1_000
dimension_table_data = [
    (i, random.choice(["A", "B", "C", "D", "E", "F"])) for i in range(1, size)
]
dimension_table_df = spark.createDataFrame(dimension_table_data, ["id", "category"])

In [47]:
joined_df_without_broadcast = fact_table_df.hint("SHUFFLE_HASH").join(
    dimension_table_df, on="id", how="inner"
)

In [48]:
joined_df_with_broadcast = fact_table_df.join(
    broadcast(dimension_table_df), on="id", how="inner"
)

In [49]:
start_time = time.time()
joined_df_without_broadcast.count()
execution_time = time.time() - start_time
print(f"\tExecution time: {execution_time:.2f} seconds")

[Stage 0:==>                (1 + 7) / 8][Stage 1:>                  (0 + 1) / 8]

	Execution time: 3.10 seconds


                                                                                

In [50]:
start_time = time.time()
joined_df_with_broadcast.count()
execution_time = time.time() - start_time
print(f"\tExecution time: {execution_time:.2f} seconds")

	Execution time: 0.59 seconds


In [51]:
joined_df_with_broadcast.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [id#137L, value#138, category#142]
   +- BroadcastHashJoin [id#137L], [id#141L], Inner, BuildRight, false
      :- Filter isnotnull(id#137L)
      :  +- Scan ExistingRDD[id#137L,value#138]
      +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [plan_id=1285]
         +- Filter isnotnull(id#141L)
            +- Scan ExistingRDD[id#141L,category#142]




In [52]:
joined_df_without_broadcast.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [id#137L, value#138, category#142]
   +- ShuffledHashJoin [id#137L], [id#141L], Inner, BuildLeft
      :- Exchange hashpartitioning(id#137L, 200), ENSURE_REQUIREMENTS, [plan_id=1313]
      :  +- Filter isnotnull(id#137L)
      :     +- Scan ExistingRDD[id#137L,value#138]
      +- Exchange hashpartitioning(id#141L, 200), ENSURE_REQUIREMENTS, [plan_id=1314]
         +- Filter isnotnull(id#141L)
            +- Scan ExistingRDD[id#141L,category#142]


