In [1]:
# Locate the local Spark installation and add pyspark to sys.path so that
# the `pyspark` imports in the next cell succeed.
import findspark

findspark.init(spark_home=None)

In [2]:
# Instantiate a SparkSession locally with 8 worker threads ("local[8]"),
# with Hive support enabled.
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

appName = "Join Strategies - broadcast, Shuffle hash and sort-merge join"
master = "local[8]"

# SparkContext log verbosity: e.g. "INFO", "WARN", "DEBUG".
# Set to "WARN" to turn off the INFO logs, see:
# https://kontext.tech/column/spark/457/tutorial-turn-off-info-logs-in-spark
LOG_LEVEL = "INFO"

conf = SparkConf().setMaster(master).setAppName(appName)
# NOTE(review): getOrCreate() returns an already-running session if one exists,
# in which case master/appName from `conf` may not take effect.
spark = (
    SparkSession.builder.config(conf=conf)
    .enableHiveSupport()
    .getOrCreate()
)
spark.sparkContext.setLogLevel(LOG_LEVEL)

### Broadcast Variables
In PySpark, broadcast variables are read-only shared variables that are cached and available on all nodes in
a cluster to be used by tasks. **Instead of sending this data along with every task, PySpark caches the broadcast
variable (the lookup data) on each node/machine.** Tasks use this cached copy while executing transformations. Each
node/executor may run many tasks, depending on its number of cores (Spark recommends 2–3 tasks per CPU core).


**PySpark RDD Broadcast variable example**
Below is a very simple example of how to use broadcast variables on an RDD. It defines commonly used data (states) in a Python dictionary, distributes it with SparkContext.broadcast(), and then uses the broadcast value inside an RDD map() transformation.

ref: https://sparkbyexamples.com/pyspark/pyspark-broadcast-variables/

In [4]:
# State code -> full state name. Broadcast once so each node caches a
# read-only copy instead of receiving it with every task.
states = {"NY": "New York", "CA": "California", "FL": "Florida"}
broadcastStates = spark.sparkContext.broadcast(states)

# (first name, last name, country, state code)
data = [
    ("James", "Smith", "USA", "CA"),
    ("Michael", "Rose", "USA", "NY"),
    ("Robert", "Williams", "USA", "CA"),
    ("Maria", "Jones", "USA", "FL"),
]

rdd = spark.sparkContext.parallelize(data)


def state_convert(code):
    """Return the full state name for ``code`` from the broadcast lookup.

    Raises KeyError if ``code`` is not present in ``states``.
    """
    lookup = broadcastStates.value
    return lookup[code]


# Keep the first three fields and replace the state code with its full name.
result = rdd.map(lambda row: (*row[:3], state_convert(row[3]))).collect()
print(result)


[('James', 'Smith', 'USA', 'California'), ('Michael', 'Rose', 'USA', 'New York'), ('Robert', 'Williams', 'USA', 'California'), ('Maria', 'Jones', 'USA', 'Florida')]


In [5]:
# Print the tables/views registered in the session catalog.
catalog_tables = spark.catalog.listTables()
print(catalog_tables)

[]


In [5]:
# Size threshold below which Spark broadcasts a join side automatically.
# NOTE: reports the *current* value; it only equals the default if this cell
# runs before the threshold is changed further down.
current_threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
print(f'default value for spark.sql.autoBroadcastJoinThreshold: {current_threshold}')

default value for spark.sql.autoBroadcastJoinThreshold: 10485760b


In [10]:
# Two identical single-column (`id`) DataFrames with ids 0..99 to join on.
df1, df2 = spark.range(100), spark.range(100)

First of all, `spark.sql.autoBroadcastJoinThreshold` and the broadcast hint are separate mechanisms: even if autoBroadcastJoinThreshold is disabled, an explicit broadcast hint still takes precedence. With default settings:

In [12]:
# Inspect df1: schema, first 20 rows, then the row count (shown as the cell result).
df1.printSchema()
df1.show(n=20)
df1.count()

root
 |-- id: long (nullable = false)

+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
|  5|
|  6|
|  7|
|  8|
|  9|
| 10|
| 11|
| 12|
| 13|
| 14|
| 15|
| 16|
| 17|
| 18|
| 19|
+---+
only showing top 20 rows



100

With the default autoBroadcastJoinThreshold (10485760 bytes, i.e. 10 MB), Spark automatically broadcasts the small `df2`:

In [16]:
# Default threshold in effect: the planner is free to broadcast the small side.
df1.join(df2, on=df1.id == df2.id, how="inner").explain()

== Physical Plan ==
*(2) BroadcastHashJoin [id#4L], [id#6L], Inner, BuildRight, false
:- *(2) Range (0, 100, step=1, splits=8)
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [id=#77]
   +- *(1) Range (0, 100, step=1, splits=8)




In [21]:
 # Register the DataFrame as a SQL temporary view
df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")
sqlDF1 = spark.sql("SELECT * FROM df1")
sqlDF2 = spark.sql("SELECT * FROM df2")

# MAPJOIN is an alias of the BROADCAST join hint: ask the planner to broadcast df2.
spark.sql(
    "SELECT  /*+ MAPJOIN(df2) */ * FROM df1 JOIN df2 ON df1.id = df2.id"
).explain()


== Physical Plan ==
*(2) BroadcastHashJoin [id#4L], [id#6L], Inner, BuildRight, false
:- *(2) Range (0, 100, step=1, splits=8)
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [id=#161]
   +- *(1) Range (0, 100, step=1, splits=8)




In [34]:
# Same request using the canonical BROADCAST hint name in SQL.
broadcast_hint_query = "SELECT  /*+ BROADCAST(df2) */ * FROM df1 JOIN df2 ON df1.id = df2.id"
spark.sql(broadcast_hint_query).explain()

== Physical Plan ==
*(2) BroadcastHashJoin [id#4L], [id#6L], Inner, BuildRight, false
:- *(2) Range (0, 100, step=1, splits=8)
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [id=#399]
   +- *(1) Range (0, 100, step=1, splits=8)




When we disable auto-broadcast (by setting the threshold to -1), Spark falls back to the standard SortMergeJoin:


In [29]:
# A threshold of -1 disables automatic broadcast joins, so without a hint
# the equi-join is planned as a SortMergeJoin. The setting stays in effect
# for the rest of the session (later cells rely on it).
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
inner_join = df1.join(df2, on=df1.id == df2.id, how="inner")
inner_join.explain()

== Physical Plan ==
*(5) SortMergeJoin [id#4L], [id#6L], Inner
:- *(2) Sort [id#4L ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(id#4L, 200), ENSURE_REQUIREMENTS, [id=#293]
:     +- *(1) Range (0, 100, step=1, splits=8)
+- *(4) Sort [id#6L ASC NULLS FIRST], false, 0
   +- ReusedExchange [id#6L], Exchange hashpartitioning(id#4L, 200), ENSURE_REQUIREMENTS, [id=#293]




In [30]:
# The same un-hinted join in SQL: with auto-broadcast disabled it also plans as SortMergeJoin.
plain_join_query = "SELECT   * FROM df1 JOIN df2 ON df1.id = df2.id"
spark.sql(plain_join_query).explain()


== Physical Plan ==
*(5) SortMergeJoin [id#4L], [id#6L], Inner
:- *(2) Sort [id#4L ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(id#4L, 200), ENSURE_REQUIREMENTS, [id=#339]
:     +- *(1) Range (0, 100, step=1, splits=8)
+- *(4) Sort [id#6L ASC NULLS FIRST], false, 0
   +- ReusedExchange [id#6L], Exchange hashpartitioning(id#4L, 200), ENSURE_REQUIREMENTS, [id=#339]




When we disable auto-broadcast, Spark uses the standard SortMergeJoin, **but it can be forced to use BroadcastHashJoin with a broadcast hint**

In [33]:
from pyspark.sql.functions import broadcast

# Mark df2 for broadcasting explicitly; the hint applies even though the
# auto-broadcast threshold is disabled.
hinted_df2 = broadcast(df2)
df1.join(hinted_df2, on=df1.id == df2.id, how="inner").explain()

== Physical Plan ==
*(2) BroadcastHashJoin [id#4L], [id#6L], Inner, BuildRight, false
:- *(2) Range (0, 100, step=1, splits=8)
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [id=#381]
   +- *(1) Range (0, 100, step=1, splits=8)


