In [1]:
from pyspark.sql import SparkSession

In [2]:
# Create (or reuse) the SparkSession for this notebook, running against a
# local master. Adaptive Query Execution is switched off when the session is
# built so that the join experiments below are not affected by it.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Spark broadcast join")
    .config("spark.sql.adaptive.enabled", "false")
    .getOrCreate()
)

In [3]:
# Echo the session object so the notebook renders it as this cell's output.
spark

In [4]:
# Runtime settings for the broadcast-join experiment:
#   * Adaptive Query Execution (AQE) and its partition coalescing are off.
#   * The automatic broadcast threshold is -1, which disables size-based
#     automatic broadcasting, so a broadcast join is only used when it is
#     requested explicitly.
for conf_key, conf_value in (
    ("spark.sql.adaptive.enabled", False),
    ("spark.sql.adaptive.coalescePartitions.enabled", False),
    ("spark.sql.autoBroadcastJoinThreshold", -1),
):
    spark.conf.set(conf_key, conf_value)

In [5]:
# Load the trips CSV: the first row is used as the header and the column
# types are inferred from the data.
trip_data = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv("data/mibici_2014-2024/mibici_2014-2024.csv")
)

In [6]:
# Load the locations CSV: the first row is used as the header and the column
# types are inferred from the data.
location_data = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv('data/nomenclature_2024.csv')
)

In [7]:
# Print the inferred schema of the trips DataFrame (column names, types, nullability).
trip_data.printSchema()

root
 |-- _c0: integer (nullable = true)
 |-- Trip_Id: integer (nullable = true)
 |-- User_Id: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Birth_year: integer (nullable = true)
 |-- Trip_start: timestamp (nullable = true)
 |-- Trip_end: timestamp (nullable = true)
 |-- Origin_Id: integer (nullable = true)
 |-- Destination_Id: integer (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Duration: string (nullable = true)



In [8]:
# Print the inferred schema of the locations DataFrame (column names, types, nullability).
location_data.printSchema()

root
 |-- id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- obcn: string (nullable = true)
 |-- location: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- status: string (nullable = true)



In [9]:
# Baseline: join trips to locations WITHOUT a broadcast hint.

from pyspark.sql.functions import col  # col() resolves alias-qualified column names such as 's.id'

# The locations table is joined twice, once for the trip origin and once for
# the trip destination, so each side gets its own alias ('s' and 'e').
start_location = location_data.alias("location_data")
end_location = location_data.alias("location_data")

origin_side = start_location.alias('s')
destination_side = end_location.alias('e')

# Step 1: attach the origin location; step 2: attach the destination location.
trips_with_origin = trip_data.join(
    origin_side,
    trip_data["Origin_Id"] == col('s.id'),
)
trips_with_both_ends = trips_with_origin.join(
    destination_side,
    trip_data["Destination_Id"] == col('e.id'),
)

# Keep only the trip id and the two location names.
enriched_data = trips_with_both_ends.select(
    trip_data["trip_id"],
    col('s.name').alias("start_location"),
    col('e.name').alias("drop_location"),
)

# enriched_data.show(10)

# Write to the "noop" sink to benchmark the query without writing the data anywhere.
noop_writer = enriched_data.write.format('noop').mode("overwrite")
noop_writer.save()

In [10]:
# Same join as the baseline, but with an explicit broadcast() hint on both
# copies of the locations table.

from pyspark.sql.functions import col, broadcast  # col() for alias-qualified columns, broadcast() for the join hint

# The locations table is joined twice, once for the trip origin and once for
# the trip destination, so each side gets its own alias ('s' and 'e').
start_location = location_data.alias("location_data")
end_location = location_data.alias("location_data")

origin_side = broadcast(start_location.alias('s'))
destination_side = broadcast(end_location.alias('e'))

# Step 1: attach the origin location; step 2: attach the destination location.
trips_with_origin = trip_data.join(
    origin_side,
    trip_data["Origin_Id"] == col('s.id'),
)
trips_with_both_ends = trips_with_origin.join(
    destination_side,
    trip_data["Destination_Id"] == col('e.id'),
)

# Keep only the trip id and the two location names.
enriched_data = trips_with_both_ends.select(
    trip_data["trip_id"],
    col('s.name').alias("start_location"),
    col('e.name').alias("drop_location"),
)

# enriched_data.show(10)


# Write to the "noop" sink to benchmark the query without writing the data anywhere.
noop_writer = enriched_data.write.format('noop').mode("overwrite")
noop_writer.save()

In [11]:
# Register both DataFrames as temporary views so the same join can be
# expressed in Spark SQL.

trip_data.createOrReplaceTempView('trip_view')          # view over the trips
location_data.createOrReplaceTempView('location_view')  # view over the locations

In [12]:
# Spark SQL version of the join: trips are joined to the location view twice
# (origin as `s`, destination as `e`) to get the start and drop location names.

join_query = """
    select t.trip_id, s.name as start_location, e.name as drop_location
    from trip_view t
    join location_view s on t.Origin_Id = s.id
    join location_view e on t.Destination_Id = e.id
    """
enriched_data = spark.sql(join_query)

# Run the query against the "noop" sink.
noop_writer = enriched_data.write.format('noop').mode("overwrite")
noop_writer.save()

In [13]:
# enriched_data.filter(enriched_data['trip_id'] == 32244893).show(10)

In [14]:
# trip_data.count()

In [15]:
# trip_data.show(10)