In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = (SparkSession.builder 
    .appName("Spark broadcast join") 
    .master("local[*]")
    # .config("spark.executor.cores", "5")
    # .config("spark.driver.memory" , "10g")
    .config("spark.sql.catalogImplementation" , "hive")
    .getOrCreate())

# spark.conf.set("spark.sql.shuffle.partitions",200)

In [3]:
spark

In [4]:
# Disabling AQE (Adaptive query execution) to test the broadcast joins

spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", False)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

In [5]:
trip_data = spark.read.csv("data/mibici_2014-2024/mibici_2014-2024.csv", header=True, inferSchema=True)

In [6]:
location_data = spark.read.csv('data/nomenclature_2024.csv', header=True, inferSchema=True)

In [7]:
trip_data.printSchema()

root
 |-- _c0: integer (nullable = true)
 |-- Trip_Id: integer (nullable = true)
 |-- User_Id: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Birth_year: integer (nullable = true)
 |-- Trip_start: timestamp (nullable = true)
 |-- Trip_end: timestamp (nullable = true)
 |-- Origin_Id: integer (nullable = true)
 |-- Destination_Id: integer (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Duration: string (nullable = true)



In [8]:
location_data.printSchema()

root
 |-- id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- obcn: string (nullable = true)
 |-- location: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- status: string (nullable = true)



In [17]:
# code without broadcast join

from pyspark.sql.functions import col      # importing col from the pyspark sql funtions 

start_location = location_data.alias("location_data")
end_location = location_data.alias("location_data")

enriched_data = (
    trip_data
    .join(start_location.alias('s'), trip_data["Origin_Id"] == col('s.id'))
    .join(end_location.alias('e'), trip_data["Destination_Id"] == col('e.id'))
    .select(trip_data["trip_id"], col('s.name').alias("start_location"), col('e.name').alias("drop_location"))
)

enriched_data.explain()

enriched_data.write.format('noop').mode("overwrite").save()

== Physical Plan ==
*(9) Project [trip_id#18, name#57 AS start_location#427, name#370 AS drop_location#428]
+- *(9) SortMergeJoin [Destination_Id#25], [id#369], Inner
   :- *(6) Sort [Destination_Id#25 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(Destination_Id#25, 200), ENSURE_REQUIREMENTS, [plan_id=1063]
   :     +- *(5) Project [Trip_Id#18, Destination_Id#25, name#57]
   :        +- *(5) SortMergeJoin [Origin_Id#24], [id#56], Inner
   :           :- *(2) Sort [Origin_Id#24 ASC NULLS FIRST], false, 0
   :           :  +- Exchange hashpartitioning(Origin_Id#24, 200), ENSURE_REQUIREMENTS, [plan_id=1047]
   :           :     +- Coalesce 10
   :           :        +- *(1) Filter (isnotnull(Origin_Id#24) AND isnotnull(Destination_Id#25))
   :           :           +- FileScan csv [Trip_Id#18,Origin_Id#24,Destination_Id#25] Batched: false, DataFilters: [isnotnull(Origin_Id#24), isnotnull(Destination_Id#25)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/C:/Users/

In [10]:
# code with the broadcast join

from pyspark.sql.functions import col, broadcast # need to import the col and broadcast functiions from pyspark sql functions

start_location = location_data.alias("location_data")
end_location = location_data.alias("location_data")

enriched_data = (
    trip_data
    .join(broadcast(start_location.alias('s')), trip_data["Origin_Id"] == col('s.id'))
    .join(broadcast(end_location.alias('e')), trip_data["Destination_Id"] == col('e.id'))
    .select(trip_data["trip_id"], col('s.name').alias("start_location"), col('e.name').alias("drop_location"))
)



enriched_data.explain()


# Writing the data as "noop" to benchmark the performance without writting the data anywhere
enriched_data.write.format('noop').mode("overwrite").save()  

== Physical Plan ==
*(3) Project [trip_id#18, name#57 AS start_location#273, name#216 AS drop_location#274]
+- *(3) BroadcastHashJoin [Destination_Id#25], [id#215], Inner, BuildRight, false
   :- *(3) Project [Trip_Id#18, Destination_Id#25, name#57]
   :  +- *(3) BroadcastHashJoin [Origin_Id#24], [id#56], Inner, BuildRight, false
   :     :- *(3) Filter (isnotnull(Origin_Id#24) AND isnotnull(Destination_Id#25))
   :     :  +- FileScan csv [Trip_Id#18,Origin_Id#24,Destination_Id#25] Batched: false, DataFilters: [isnotnull(Origin_Id#24), isnotnull(Destination_Id#25)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/C:/Users/prana/Downloads/spark books/spark-learning/data/mibici_..., PartitionFilters: [], PushedFilters: [IsNotNull(Origin_Id), IsNotNull(Destination_Id)], ReadSchema: struct<Trip_Id:int,Origin_Id:int,Destination_Id:int>
   :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)),false), [plan_id=379]
   :        +- *(1) Filter

In [11]:
# creating the temp views to test the same using Spark SQL

location_data.createOrReplaceTempView('location_view') # view for Locations
trip_data.createOrReplaceTempView('trip_view') # view for Trips

In [12]:
# SQL query for joining trips and location table to get the start and drop locations

enriched_data = spark.sql(
    """
    select t.trip_id, s.name as start_location, e.name as drop_location
    from trip_view t
    join location_view s on t.Origin_Id = s.id
    join location_view e on t.Destination_Id = e.id
    """
)

enriched_data.explain()

enriched_data.write.format('noop').mode("overwrite").save()

== Physical Plan ==
*(9) Project [trip_id#18, name#57 AS start_location#286, name#289 AS drop_location#287]
+- *(9) SortMergeJoin [Destination_Id#25], [id#288], Inner
   :- *(6) Sort [Destination_Id#25 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(Destination_Id#25, 200), ENSURE_REQUIREMENTS, [plan_id=583]
   :     +- *(5) Project [Trip_Id#18, Destination_Id#25, name#57]
   :        +- *(5) SortMergeJoin [Origin_Id#24], [id#56], Inner
   :           :- *(2) Sort [Origin_Id#24 ASC NULLS FIRST], false, 0
   :           :  +- Exchange hashpartitioning(Origin_Id#24, 200), ENSURE_REQUIREMENTS, [plan_id=567]
   :           :     +- *(1) Filter (isnotnull(Origin_Id#24) AND isnotnull(Destination_Id#25))
   :           :        +- FileScan csv [Trip_Id#18,Origin_Id#24,Destination_Id#25] Batched: false, DataFilters: [isnotnull(Origin_Id#24), isnotnull(Destination_Id#25)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/C:/Users/prana/Downloads/spark books/spark-learning/d

In [13]:
# enriched_data.filter(enriched_data['trip_id'] == 32244893).show(10)

In [14]:
# trip_data.count()

In [15]:
# trip_data.show(10)

In [16]:
from pyspark.sql.functions import spark_partition_id
(
    enriched_data
    .withColumn("partition", spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show()
)

+---------+------+
|partition| count|
+---------+------+
|        0| 52582|
|        1| 27733|
|        2|121565|
|        3|231475|
|        4|241822|
|        5|323180|
|        6|126923|
|        7|170736|
|        8| 16886|
|        9|   580|
|       10|118644|
|       11|450350|
|       12| 66117|
|       13|116180|
|       14|325244|
|       18| 91449|
|       19|336680|
|       20|   564|
|       21|465929|
|       22| 84681|
+---------+------+
only showing top 20 rows

