<a href="https://colab.research.google.com/github/ramayer/google-colab-examples/blob/main/Efficient_spark_range_joins_in_Databricks_vs_Apache_Spark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Efficient spark range joins in Databricks vs Apache Spark

## Databricks spark has efficient range joins through hints.

https://docs.databricks.com/en/optimizations/range-join.html


## Apache Spark seems not to, yet.

* https://github.com/apache/spark/pull/7379
* https://issues.apache.org/jira/browse/SPARK-8682

Zach Moshe describes a workaround.

* http://zachmoshe.com/2016/09/26/efficient-range-joins-with-spark.html

This notebook implements something simlar to each of the above.


In [1]:
try:
  import pyspark, findspark, delta
except:
   %pip install -q --upgrade pyspark==3.5

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.9/316.9 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone


In [2]:
import pyspark

MAX_MEMORY="8g"
maven_coords = [
    'io.delta:delta-spark_2.12:3.2.0',
]
spark = (pyspark.sql.SparkSession.builder.appName("MyApp")
    .config("spark.jars.packages", ",".join(maven_coords))
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.executor.memory", MAX_MEMORY)
    .config("spark.driver.memory", MAX_MEMORY)
    .enableHiveSupport()
    .getOrCreate()
    )
spark

In [3]:
spark.range(1000 * 1000).createOrReplaceTempView("a_million_rows")

In [4]:
df1 = spark.sql("""
  select
    cast('2024-01-01' as timestamp) + interval '1 second' * id /4 as dttm,
    sin(id/60 / 4) as sin,
    cos(id/60 / 4) as cos
  from a_million_rows
  """)
df1.write.format("delta").mode('overwrite').saveAsTable("df1")
df1.sort('dttm').limit(3).pandas_api()



Unnamed: 0,dttm,sin,cos
0,2024-01-01 00:00:00.000,0.0,1.0
1,2024-01-01 00:00:00.250,0.004167,0.999991
2,2024-01-01 00:00:00.500,0.008333,0.999965


In [5]:
df2 = spark.sql("""
  select
    cast('2024-01-01' as timestamp) + interval '1 second' * id as dttm,
    cos(id/60) as val
  from a_million_rows
  """)
df2.write.format("delta").mode('overwrite').saveAsTable("df2")
df2.sort('dttm').limit(3).pandas_api()

Unnamed: 0,dttm,val
0,2024-01-01 00:00:00,1.0
1,2024-01-01 00:00:01,0.999861
2,2024-01-01 00:00:02,0.999444


In [6]:
df1 = spark.table("df1")
df2 = spark.table("df2")

In [7]:
from pyspark.sql.functions import current_timestamp, lag
from pyspark.sql.window import Window
df2_ranges = (spark.sql("select * from df2")
                .withColumn("prev_dttm", lag("dttm").over(Window.orderBy("dttm")))
                .withColumn("prev_val", lag("val").over(Window.orderBy("dttm")))
                .selectExpr("prev_dttm","prev_val", "dttm as next_dttm", "val as next_val")
                )
df2_ranges.createOrReplaceTempView("df2_ranges")
df2_ranges.sort('prev_dttm').limit(3).pandas_api()

Unnamed: 0,prev_dttm,prev_val,next_dttm,next_val
0,NaT,,2024-01-01 00:00:00,1.0
1,2024-01-01 00:00:00,1.0,2024-01-01 00:00:01,0.999861
2,2024-01-01 00:00:01,0.999861,2024-01-01 00:00:02,0.999444


In [8]:
# Don't run this on F/OSS Apache Spark, it takes forever
# CartesianProduct and/or BroadcastNestedLoopJoin
# (depending on spark configs) are both painful ways to do joins.

spark.sql("""
  SELECT df1.dttm as df1_dttm,
         df1.sin  as df1_sin,
         df1.cos  as df1_cos,
         df2_ranges.prev_dttm,
         df2_ranges.next_dttm,
         df2_ranges.prev_val,
         df2_ranges.next_val
  FROM df1
  JOIN df2_ranges ON (df1.dttm >= df2_ranges.prev_dttm
                  AND df1.dttm <= df2_ranges.next_dttm)
""").explain()



== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [dttm#1522 AS df1_dttm#1519, sin#1523 AS df1_sin#1520, cos#1524 AS df1_cos#1521, prev_dttm#604, next_dttm#613, prev_val#608, next_val#614]
   +- CartesianProduct ((dttm#1522 >= prev_dttm#604) AND (dttm#1522 <= next_dttm#613))
      :- Filter isnotnull(dttm#1522)
      :  +- FileScan parquet spark_catalog.default.df1[dttm#1522,sin#1523,cos#1524] Batched: true, DataFilters: [isnotnull(dttm#1522)], Format: Parquet, Location: PreparedDeltaFileIndex(1 paths)[file:/content/spark-warehouse/df1], PartitionFilters: [], PushedFilters: [IsNotNull(dttm)], ReadSchema: struct<dttm:timestamp,sin:double,cos:double>
      +- Project [prev_dttm#604, prev_val#608, dttm#599 AS next_dttm#613, val#600 AS next_val#614]
         +- Filter (isnotnull(prev_dttm#604) AND isnotnull(dttm#599))
            +- Window [lag(dttm#599, -1, null) windowspecdefinition(dttm#599 ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, -1)) AS prev_dttm#604, lag(v

In [9]:
# python API does no better, even with a hint
df1.hint("range_join",6).join(df2_ranges,
                               on=[
                                   df1.dttm >= df2_ranges.prev_dttm,
                                   df1.dttm <= df2_ranges.next_dttm
                               ]).explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- CartesianProduct ((dttm#589 >= prev_dttm#604) AND (dttm#589 <= next_dttm#613))
   :- Filter isnotnull(dttm#589)
   :  +- FileScan parquet spark_catalog.default.df1[dttm#589,sin#590,cos#591] Batched: true, DataFilters: [isnotnull(dttm#589)], Format: Parquet, Location: PreparedDeltaFileIndex(1 paths)[file:/content/spark-warehouse/df1], PartitionFilters: [], PushedFilters: [IsNotNull(dttm)], ReadSchema: struct<dttm:timestamp,sin:double,cos:double>
   +- Project [prev_dttm#604, prev_val#608, dttm#599 AS next_dttm#613, val#600 AS next_val#614]
      +- Filter (isnotnull(prev_dttm#604) AND isnotnull(dttm#599))
         +- Window [lag(dttm#599, -1, null) windowspecdefinition(dttm#599 ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, -1)) AS prev_dttm#604, lag(val#600, -1, null) windowspecdefinition(dttm#599 ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, -1)) AS prev_val#608], [dttm#599 ASC NULLS FIRST]
            +- Sort [dttm

In [10]:
# This should be fast on Databricks
# https://docs.databricks.com/en/optimizations/range-join.html
# but is annoyingly not on Apache Spark
spark.sql("""
  SELECT  /*+ RANGE_JOIN(dttm, 10) */
         df1.dttm as df1_dttm,
         df1.sin  as df1_sin,
         df1.cos  as df1_cos,
         df2_ranges.prev_dttm,
         df2_ranges.next_dttm,
         df2_ranges.prev_val,
         df2_ranges.next_val
  FROM df1
  JOIN df2_ranges ON (df1.dttm >= df2_ranges.prev_dttm
                  AND df1.dttm <= df2_ranges.next_dttm)
""").explain()



== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [dttm#2205 AS df1_dttm#2202, sin#2206 AS df1_sin#2203, cos#2207 AS df1_cos#2204, prev_dttm#604, next_dttm#613, prev_val#608, next_val#614]
   +- CartesianProduct ((dttm#2205 >= prev_dttm#604) AND (dttm#2205 <= next_dttm#613))
      :- Filter isnotnull(dttm#2205)
      :  +- FileScan parquet spark_catalog.default.df1[dttm#2205,sin#2206,cos#2207] Batched: true, DataFilters: [isnotnull(dttm#2205)], Format: Parquet, Location: PreparedDeltaFileIndex(1 paths)[file:/content/spark-warehouse/df1], PartitionFilters: [], PushedFilters: [IsNotNull(dttm)], ReadSchema: struct<dttm:timestamp,sin:double,cos:double>
      +- Project [prev_dttm#604, prev_val#608, dttm#599 AS next_dttm#613, val#600 AS next_val#614]
         +- Filter (isnotnull(prev_dttm#604) AND isnotnull(dttm#599))
            +- Window [lag(dttm#599, -1, null) windowspecdefinition(dttm#599 ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, -1)) AS prev_dttm#604, lag(v

## Manually emulate the Databricks optimization

In [11]:
# Choose a minute for the bin size.
# Smaller bins are faster, but less forgiving of missing data.
spark.sql("""
  create or replace temp view df1b as
  select *,
    floor(unix_timestamp(dttm)/60) as bin
  from df1
""")
spark.sql("""
  create or replace temp view df2b as
  select *,
    floor(unix_timestamp(prev_dttm)/60) as prev_bin,
    floor(unix_timestamp(next_dttm)/60) as next_bin
  from df2_ranges
""")


DataFrame[]

In [12]:
# Much better plan.
#
# SortMergeJoin & hashpartitioning should be reasonable.

manually_binned_join = spark.sql("""
WITH a as (
   SELECT df1b.dttm as df1_dttm,
         df1b.sin  as df1_sin,
         df1b.cos  as df1_cos,
         df2b.prev_dttm,
         df2b.next_dttm,
         df2b.prev_val,
         df2b.next_val
   FROM df1b
   JOIN df2b ON (df1b.bin = df2b.next_bin)
  ),
  b as (
  SELECT df1b.dttm as df1_dttm,
         df1b.sin  as df1_sin,
         df1b.cos  as df1_cos,
         df2b.prev_dttm,
         df2b.next_dttm,
         df2b.prev_val,
         df2b.next_val
  FROM df1b
  JOIN df2b ON (df1b.bin <> df2b.next_bin and df1b.bin = df2b.prev_bin)
  ),
  c as (
    SELECT * FROM a
    UNION ALL
    SELECT * FROM b
  )
  SELECT * FROM c
  where (df1_dttm >= prev_dttm and df1_dttm < next_dttm)
""")
manually_binned_join.explain()


== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Union
   :- Project [dttm#2476 AS df1_dttm#2449, sin#2477 AS df1_sin#2450, cos#2478 AS df1_cos#2451, prev_dttm#604 AS prev_dttm#2452, next_dttm#2425 AS next_dttm#2453, prev_val#608 AS prev_val#2454, next_val#2426 AS next_val#2455]
   :  +- SortMergeJoin [bin#2416L], [next_bin#2428L], Inner, ((dttm#2476 >= prev_dttm#604) AND (dttm#2476 < next_dttm#2425))
   :     :- Sort [bin#2416L ASC NULLS FIRST], false, 0
   :     :  +- Exchange hashpartitioning(bin#2416L, 200), ENSURE_REQUIREMENTS, [plan_id=865]
   :     :     +- Project [dttm#2476, sin#2477, cos#2478, FLOOR((cast(unix_timestamp(dttm#2476, yyyy-MM-dd HH:mm:ss, Some(Etc/UTC), false) as double) / 60.0)) AS bin#2416L]
   :     :        +- Filter (isnotnull(dttm#2476) AND isnotnull(FLOOR((cast(unix_timestamp(dttm#2476, yyyy-MM-dd HH:mm:ss, Some(Etc/UTC), false) as double) / 60.0))))
   :     :           +- FileScan parquet spark_catalog.default.df1[dttm#2476,sin#2477,cos#2478] B

In [13]:
import time
import pandas
t0 = time.time()
manually_binned_join.createOrReplaceTempView("manually_binned_join")
result = spark.sql("select * from manually_binned_join order by df1_dttm").limit(10).toPandas()
t1 = time.time()
print(f"took {t1-t0:.2f} seconds")
result

took 12.32 seconds


Unnamed: 0,df1_dttm,df1_sin,df1_cos,prev_dttm,next_dttm,prev_val,next_val
0,2024-01-01 00:00:00.000,0.0,1.0,2024-01-01 00:00:00,2024-01-01 00:00:01,1.0,0.999861
1,2024-01-01 00:00:00.250,0.004167,0.999991,2024-01-01 00:00:00,2024-01-01 00:00:01,1.0,0.999861
2,2024-01-01 00:00:00.500,0.008333,0.999965,2024-01-01 00:00:00,2024-01-01 00:00:01,1.0,0.999861
3,2024-01-01 00:00:00.750,0.0125,0.999922,2024-01-01 00:00:00,2024-01-01 00:00:01,1.0,0.999861
4,2024-01-01 00:00:01.000,0.016666,0.999861,2024-01-01 00:00:01,2024-01-01 00:00:02,0.999861,0.999444
5,2024-01-01 00:00:01.250,0.020832,0.999783,2024-01-01 00:00:01,2024-01-01 00:00:02,0.999861,0.999444
6,2024-01-01 00:00:01.500,0.024997,0.999688,2024-01-01 00:00:01,2024-01-01 00:00:02,0.999861,0.999444
7,2024-01-01 00:00:01.750,0.029163,0.999575,2024-01-01 00:00:01,2024-01-01 00:00:02,0.999861,0.999444
8,2024-01-01 00:00:02.000,0.033327,0.999444,2024-01-01 00:00:02,2024-01-01 00:00:03,0.999444,0.99875
9,2024-01-01 00:00:02.250,0.037491,0.999297,2024-01-01 00:00:02,2024-01-01 00:00:03,0.999444,0.99875
