This skew optimization is performed on a different dataset

In [None]:
import warnings
warnings.filterwarnings("ignore")

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

from pyspark.storagelevel import StorageLevel
from pyspark.sql.types import *
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

import os
# Set environment variables (local paths)
os.environ["JAVA_HOME"] = "D:/Programs/Java"
os.environ["HADOOP_HOME"] = "D:/Programs/hadoop"
os.environ["SPARK_HOME"] = "D:/Programs/spark/spark-3.5.6-bin-hadoop3"  # Adjust if different

import findspark
findspark.init("D:/Programs/spark/spark-3.5.6-bin-hadoop3")

In [3]:
def prepare_trips_data(spark: SparkSession) -> DataFrame:
    pu_loc_to_change = [
        236, 132, 161, 186, 142, 141, 48, 239, 170, 162, 230, 163, 79, 234, 263, 140, 238, 107, 68, 138, 229, 249,
        237, 164, 90, 43, 100, 246, 231, 262, 113, 233, 143, 137, 114, 264, 148, 151
    ]

    res_df = spark.read\
        .parquet("data/trips/*.parquet")\
        .withColumn(
            "PULocationID",
            F.when(F.col("PULocationID").isin(pu_loc_to_change), F.lit(237))
            .otherwise(F.col("PULocationID"))
        )
    return res_df

In [None]:
def join_on_skewed_data(spark: SparkSession):
    trips_data = prepare_trips_data(spark=spark)
    location_details_data = spark.read.option("header", True).csv("data/taxi+_zone_lookup.csv")

    trips_with_pickup_location_details = trips_data\
        .join(location_details_data, F.col("PULocationID") == F.col("LocationID"), "inner")

    # .groupBy("Zone") \

    trips_with_pickup_location_details \
        .groupBy("Borough") \
        .agg(F.avg("trip_distance").alias("avg_trip_distance")) \
        .sort(F.col("avg_trip_distance").desc()) \
        .show(truncate=False, n=1000)

In [None]:
# aqe disabled
def create_spark_session_with_aqe_disabled() -> SparkSession:
    conf = SparkConf() \
        .set("spark.driver.memory", "4G") \
        .set("spark.sql.autoBroadcastJoinThreshold", "-1") \
        .set("spark.sql.shuffle.partitions", "201") \
        .set("spark.sql.adaptive.enabled", "false")

    spark_session = SparkSession\
        .builder\
        .master("local[8]")\
        .config(conf=conf)\
        .appName("Read from JDBC tutorial") \
        .getOrCreate()

    return spark_session

if __name__ == '__main__':
    start_time = time.time()
    spark = create_spark_session_with_aqe_disabled()

    join_on_skewed_data(spark=spark)

    print(f"Elapsed_time: {(time.time() - start_time)} seconds")
   # time.sleep(10000)

+-------------+------------------+
|Borough      |avg_trip_distance |
+-------------+------------------+
|Brooklyn     |54.084378566790875|
|Bronx        |50.5792505718606  |
|Queens       |14.573507646515498|
|Staten Island|11.273981415296637|
|Unknown      |7.245918024787724 |
|Manhattan    |5.424279480643658 |
|EWR          |0.8988827712778211|
+-------------+------------------+

Elapsed_time: 24.47364377975464 seconds


just 24secs with only borough


takes 35 secs with aqe disabled 

In [None]:
# aqe skew join enabled

def create_spark_session_with_aqe_skew_join_enabled() -> SparkSession:
    conf = SparkConf() \
        .set("spark.driver.memory", "4G") \
        .set("spark.sql.autoBroadcastJoinThreshold", "-1") \
        .set("spark.sql.shuffle.partitions", "201") \
        .set("spark.sql.adaptive.enabled", "true") \
        .set("spark.sql.adaptive.coalescePartitions.enabled", "false") \
        .set("spark.sql.adaptive.skewJoin.enabled", "true") \
        .set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "3") \
        .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256K")

    spark_session = SparkSession\
        .builder\
        .master("local[8]")\
        .config(conf=conf)\
        .appName("Read from JDBC tutorial") \
        .getOrCreate()

    return spark_session

if __name__ == '__main__':
    start_time = time.time()
    spark = create_spark_session_with_aqe_skew_join_enabled()

    join_on_skewed_data(spark=spark)

    print(f"Elapsed_time: {(time.time() - start_time)} seconds")
   # time.sleep(10000)

+-------------+------------------+
|Borough      |avg_trip_distance |
+-------------+------------------+
|Brooklyn     |54.084378566790875|
|Bronx        |50.5792505718606  |
|Queens       |14.573507646515498|
|Staten Island|11.273981415296637|
|Unknown      |7.245918024787724 |
|Manhattan    |5.424279480644278 |
|EWR          |0.8988827712778211|
+-------------+------------------+

Elapsed_time: 15.890237808227539 seconds


15 secs with only borough

Takes 25 secs to perform join operation with aqe enabled 

In [None]:
def join_with_bucketing(spark: SparkSession):
    """Performs a join using bucketed tables and measures execution time."""
    spark.conf.set("spark.sql.adaptive.enabled", "false")
    num_buckets = 10

    # Prepare and write data (we won't time this part)
    trips_data = prepare_trips_data(spark=spark)
    trips_data.write.bucketBy(num_buckets, "PULocationID").sortBy("PULocationID").mode("overwrite").saveAsTable("bucketed_trips_table")

    location_details_data = spark.read.option("header", True).csv("data/taxi+_zone_lookup.csv")
    location_details_data.write.bucketBy(num_buckets, "LocationID").sortBy("LocationID").mode("overwrite").saveAsTable("bucketed_locations_table")

    # Read from the bucketed tables
    bucketed_trips = spark.table("bucketed_trips_table")
    bucketed_locations = spark.table("bucketed_locations_table")

    # --- Start Timer ---
    start_time = time.time()

    # Perform the join on the bucketed tables
    trips_with_pickup_location_details = bucketed_trips.join(
        bucketed_locations,
        bucketed_trips["PULocationID"] == bucketed_locations["LocationID"],
        "inner"
    )

    # Trigger action to evaluate the join and aggregations
    # Using .collect() or .show() will trigger the computation
    result_df = trips_with_pickup_location_details \
        .groupBy("Borough") \
        .agg(F.avg("trip_distance").alias("avg_trip_distance")) \
        .sort(F.col("avg_trip_distance").desc())

    result_df.show(truncate=False, n=1000) # This action triggers the job

    # --- Stop Timer ---
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"✅ The join and aggregation took: {elapsed_time:.2f} seconds")

# (SparkSession initialization and function call would go here)
if __name__ == '__main__':
    spark = SparkSession.builder \
        .appName("BucketingTimingExample") \
        .config("spark.sql.warehouse.dir", "spark-warehouse") \
        .enableHiveSupport() \
        .getOrCreate()

    join_with_bucketing(spark=spark)

    #spark.stop()

    time.sleep(10000)

+-------------+------------------+
|Borough      |avg_trip_distance |
+-------------+------------------+
|Brooklyn     |54.08437856679085 |
|Bronx        |50.57925057186062 |
|Queens       |14.573507646515482|
|Staten Island|11.27398141529664 |
|Unknown      |7.245918024787716 |
|Manhattan    |5.424279480643807 |
|EWR          |0.8988827712778201|
+-------------+------------------+

✅ The join and aggregation took: 17.64 seconds


Took 17.64 settings for performing join on PULocationID (Bucketing )

In [None]:
# salting
def join_with_salting(spark: SparkSession):
    """Performs a join using the salting technique to handle skew and tracks execution time."""
    import time
    start_time = time.time()

    # --- 1. Define Skewed Key and Salt Range ---
    SKEWED_KEY = 237
    SALT_RANGE = 10

    # --- 2. Prepare the DataFrames ---
    trips_data = prepare_trips_data(spark=spark)
    location_details_data = spark.read.option("header", True).csv("data/taxi+_zone_lookup.csv")

    # --- 3. Salt the Large (skewed) DataFrame ---
    salted_trips_data = trips_data.withColumn(
        "salted_PULocationID",
        F.when(
            F.col("PULocationID") == SKEWED_KEY,
            F.concat(F.col("PULocationID"), F.lit("_"), (F.rand() * SALT_RANGE).cast("int"))
        ).otherwise(
            F.concat(F.col("PULocationID"), F.lit("_"), F.lit(0))
        )
    )

    # --- 4. Explode the Small DataFrame ---
    salt_array = F.array([F.lit(i) for i in range(SALT_RANGE)])

    exploded_locations_data = location_details_data.withColumn(
        "salt",
        F.when(F.col("LocationID") == SKEWED_KEY, salt_array).otherwise(F.array(F.lit(0)))
    ).withColumn("salt", F.explode("salt"))

    exploded_locations_data = exploded_locations_data.withColumn(
        "salted_LocationID",
        F.concat(F.col("LocationID"), F.lit("_"), F.col("salt"))
    )

    # --- 5. Perform the Join on the Salted Keys ---
    print("Performing join on salted keys...")
    trips_with_pickup_location_details = salted_trips_data.join(
        exploded_locations_data,
        salted_trips_data["salted_PULocationID"] == exploded_locations_data["salted_LocationID"],
        "inner"
    )

    # --- 6. Aggregations ---
    # trips_with_pickup_location_details \
    #     .groupBy("Zone") \
    #     .agg(F.avg("trip_distance").alias("avg_trip_distance")) \
    #     .sort(F.col("avg_trip_distance").desc()) \
    #     .show(truncate=False, n=1000)

    trips_with_pickup_location_details \
        .groupBy("Borough") \
        .agg(F.avg("trip_distance").alias("avg_trip_distance")) \
        .sort(F.col("avg_trip_distance").desc()) \
        .show(truncate=False, n=1000)

    # --- 7. End Time ---
    end_time = time.time()
    print(f"Total execution time: {end_time - start_time:.2f} seconds")


if __name__ == '__main__':
    spark = SparkSession.builder.appName("SaltingExample").getOrCreate()
    join_with_salting(spark=spark)
    spark.stop()

Performing join on salted keys...
+-------------+------------------+
|Borough      |avg_trip_distance |
+-------------+------------------+
|Brooklyn     |54.08437856679081 |
|Bronx        |50.57925057186058 |
|Queens       |14.573507646515468|
|Staten Island|11.273981415296642|
|Unknown      |7.245918024787745 |
|Manhattan    |5.424279480643289 |
|EWR          |0.8988827712778152|
+-------------+------------------+

Total execution time: 8.13 seconds


took 8.13 secs just for borough


took approx 35.96 secs