# Performing Spatial Joins in Wherobots

This notebook will guide you through performing spatial joins in Wherobots using Python and the DataFrame API — giving you a hands-on understanding of how to combine datasets based on their spatial relationships.

## What you will learn

This notebook will teach you to:

* Perform **standard spatial joins**: find the features that fall within other geometries
* Execute **nearest neighbor joins**: find the closest features in another dataset
* Calculate **zonal statistics**: summarize values, such as point counts, within geographic zones
* Apply optimization techniques such as sorting data by GeoHash to improve spatial locality
* Visualize join results using interactive tools

> Spatial joins are a core operation in geospatial analysis, allowing you to merge datasets based on how their features relate in space.

This notebook focuses on practical workflows and scalable processing with Wherobots and Apache Sedona.


In [None]:
from sedona.spark import SedonaContext
from pyspark.sql.functions import expr
from pyspark.sql.functions import col

config = SedonaContext.builder().getOrCreate()
sedona = SedonaContext.create(config)

## Loading datasets for a spatial join

To perform a spatial join, we need two datasets: a **polygon layer** and a **point layer**.

In this example, we use publicly available datasets from the Wherobots Open Data Catalog:

* **Polygon data** — Administrative boundaries from the Overture Maps Foundation
* **Point data** — Places of interest from the Foursquare dataset

> The polygon query selects US localities (like cities or towns) from the `divisions_division_area` table.
> The points DataFrame loads all place records from the Foursquare dataset.

These two DataFrames will serve as the inputs for the spatial join operations in this notebook.

In [None]:
query = '''
SELECT 
    * 
FROM
    wherobots_open_data.overture_maps_foundation.divisions_division_area
WHERE
    subtype = 'locality'
    AND country = 'US'
'''

polygons_df = sedona.sql(query)
points_df = sedona.table("wherobots_open_data.foursquare.places")

In [None]:
print("Sample of the Polygon Dataset (Administrative Boundaries):")
polygons_df.show(5)

In [None]:
print("Sample of the Points Dataset (Facilities):")
points_df.show(5, truncate=False)

## Performing a standard spatial join

With both datasets loaded, we can now join them based on their spatial relationship.
In this case, we want to find which facilities (points) fall within each administrative boundary (polygons).

We use the `ST_Intersects` function to check if a facility's geometry intersects a boundary's geometry:

> We use DataFrame aliases for clarity when joining.
> The spatial join keeps only the pairs of points and polygons where their geometries intersect.

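The result contains one row for each (facility, boundary) pair whose geometries intersect, with the columns of both inputs. A facility that intersects several boundaries appears once per boundary, and facilities that intersect no boundary are dropped, as in any inner join.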
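Conceptually, an inner spatial join evaluates the predicate for each candidate (point, polygon) pair and keeps the pairs where it holds. Sedona avoids comparing every pair by spatially partitioning and indexing the inputs, so the plain-Python sketch below shows only the join semantics, not how Sedona executes it. It assumes hypothetical toy data and uses a ray-casting point-in-polygon test (hypothetical helper names) as a stand-in for `ST_Intersects`; points lying exactly on a boundary are not handled the way `ST_Intersects` handles them.

```python
from typing import Dict, List, Tuple

Point = Tuple[float, float]
Polygon = List[Point]  # ring of (x, y) vertices, first vertex not repeated


def point_in_polygon(pt: Point, ring: Polygon) -> bool:
    """Ray casting: count how often a horizontal ray from pt crosses the ring."""
    x, y = pt
    inside = False
    n = len(ring)
    for i in range(n):
        x1, y1 = ring[i]
        x2, y2 = ring[(i + 1) % n]
        if (y1 > y) != (y2 > y):  # edge straddles the ray's y coordinate
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def spatial_join(points: Dict[str, Point],
                 polygons: Dict[str, Polygon]) -> List[Tuple[str, str]]:
    """Inner join: one (point_id, polygon_id) row per matching pair."""
    return [
        (pid, gid)
        for pid, pt in points.items()
        for gid, ring in polygons.items()
        if point_in_polygon(pt, ring)
    ]


# Hypothetical toy data: two overlapping squares and three points.
polygons = {
    "A": [(0, 0), (4, 0), (4, 4), (0, 4)],
    "B": [(2, 2), (6, 2), (6, 6), (2, 6)],
}
points = {"p1": (1, 1), "p2": (3, 3), "p3": (10, 10)}

print(spatial_join(points, polygons))  # → [('p1', 'A'), ('p2', 'A'), ('p2', 'B')]
```

Point `p2` appears twice because it lies in both overlapping squares, and `p3` is dropped because it matches no polygon, mirroring how the DataFrame join above treats facilities.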

In [None]:
facilities = points_df.alias("f")
admin_boundaries = polygons_df.alias("poly")

In [None]:
spatial_join_df = facilities.join(
    admin_boundaries,
    expr("ST_Intersects(poly.geometry, f.geom)")
)

In [None]:
%%time
spatial_join_df.count()

In [None]:
print("Standard Spatial Join Results (Facilities within Administrative Boundaries):")
spatial_join_df.show(1)

## Counting points within polygons using a spatial join

After a spatial join, a common next step is to count how many points fall within each polygon, for example the number of facilities in each administrative boundary.

We can do this in a single chained expression by combining the spatial join with a `groupBy` and an aggregation:

> This query joins the polygons and points, groups the results by the polygon ID, and counts the matching points.
> The result is a DataFrame with one row per polygon ID and the number of points (facilities) within it. Because the join is an inner join, polygons that contain no points do not appear.
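The join-then-aggregate pattern can be sketched in plain Python as two steps: emit one zone ID per matching (point, zone) pair, then count the IDs. The zones below are hypothetical axis-aligned boxes so the containment test stays short; the last step shows how zones without any points, which the inner join drops, could be reported with a count of zero.

```python
from collections import Counter

# Hypothetical zones as closed boxes: id -> (min_x, min_y, max_x, max_y).
zones = {
    "Z1": (0, 0, 10, 10),
    "Z2": (10, 0, 20, 10),
    "Z3": (20, 0, 30, 10),
}
facilities = [(1, 1), (5, 5), (12, 3), (15, 8), (18, 2)]


def in_box(pt, box):
    """Closed-box test: points on an edge count as inside."""
    x, y = pt
    min_x, min_y, max_x, max_y = box
    return min_x <= x <= max_x and min_y <= y <= max_y


# Step 1: the spatial join yields one zone id per matching (point, zone) pair.
pairs = [zid for pt in facilities for zid, box in zones.items() if in_box(pt, box)]

# Step 2: the equivalent of groupBy(zone id) + COUNT(*).
point_count = Counter(pairs)
print(dict(point_count))  # → {'Z1': 2, 'Z2': 3}

# Zones with no matches never reach the groupBy; add them back explicitly.
all_counts = {zid: point_count.get(zid, 0) for zid in zones}
print(all_counts)  # → {'Z1': 2, 'Z2': 3, 'Z3': 0}
```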


In [None]:
points_count_efficient_df = polygons_df.alias("poly") \
    .join(points_df.alias("f"), expr("ST_Intersects(poly.geometry, f.geom)")) \
    .groupBy("poly.id") \
    .agg(expr("COUNT(*) as point_count"))

In [None]:
print("🔹 Efficient Count of Points in Each Polygon:")
points_count_efficient_df.show(10)

## Performing a nearest neighbor spatial join

In some cases, you may want to find the closest feature from another dataset — such as identifying the nearest administrative centroid for each facility.

This is called a **nearest neighbor join**, and Wherobots supports it using the `ST_AKNN` function.

We first calculate the centroids of the administrative boundaries.

> The `ST_Centroid` function returns the geometric center of each polygon.
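`ST_Centroid` returns the area-weighted center of a polygon. A minimal sketch of that computation for a single ring, using the shoelace formula on planar coordinates (the function name and the L-shaped example are illustrative, not part of Sedona), also shows a property worth keeping in mind: for a concave shape the centroid can fall outside the polygon.

```python
from typing import List, Tuple


def polygon_centroid(ring: List[Tuple[float, float]]) -> Tuple[float, float]:
    """Area-weighted centroid of a simple polygon via the shoelace formula.

    `ring` lists each vertex once, without repeating the first vertex.
    """
    a2 = 0.0  # twice the signed area
    cx = cy = 0.0
    n = len(ring)
    for i in range(n):
        x0, y0 = ring[i]
        x1, y1 = ring[(i + 1) % n]
        cross = x0 * y1 - x1 * y0
        a2 += cross
        cx += (x0 + x1) * cross
        cy += (y0 + y1) * cross
    return cx / (3 * a2), cy / (3 * a2)


# An L-shaped polygon: a 4x1 bottom arm joined to a 1x4 left arm.
l_shape = [(0, 0), (4, 0), (4, 1), (1, 1), (1, 4), (0, 4)]
print(polygon_centroid(l_shape))
```

Here the centroid is about (1.357, 1.357), which lies in the notch of the L rather than inside it. When a representative point that is guaranteed to lie inside the polygon is needed, `ST_PointOnSurface` is the usual alternative.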

In [None]:
centroids_df = polygons_df.selectExpr("id", "ST_Centroid(geometry) as centroid")

## Running the approximate k-nearest neighbor (AKNN) join

We join the points to the centroids using `ST_AKNN`, which performs an approximate k-nearest neighbor search.

In this example, we retrieve the **4 nearest centroids** for each point:

```python
aknn_df = points_df.alias("q").join(
    centroids_df.alias("o"),
    expr("ST_AKNN(q.geom, o.centroid, 4, false)")
)
```

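> The `ST_AKNN` function takes the query geometry, the object geometry, the number of neighbors (4), and a boolean that selects distance on the Earth's surface (`true`) or planar distance in coordinate units (`false`).

This join pairs each facility with its nearest administrative centroids, which is useful for proximity analysis and clustering. Because the search is approximate, a returned neighbor is not guaranteed to be among the exact four nearest.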
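For reference, the exact result that an approximate k-nearest-neighbor join aims to reproduce can be computed by brute force: measure the distance from every query point to every object and keep the `k` smallest. The sketch below does this in plain Python with planar (Euclidean) distance; the data and function name are hypothetical. Its cost grows with the product of the two dataset sizes, which is why approximate, index-based methods are attractive at scale.

```python
import heapq
import math
from typing import Dict, List, Tuple


def knn_join(queries: Dict[str, Tuple[float, float]],
             objects: Dict[str, Tuple[float, float]],
             k: int) -> List[Tuple[str, str, float]]:
    """Exact k-nearest-neighbor join by brute force (planar distance).

    Returns one (query_id, object_id, distance) row per neighbor.
    """
    rows = []
    for qid, (qx, qy) in queries.items():
        # (distance, object_id) tuples: nsmallest orders by distance, then id.
        nearest = heapq.nsmallest(
            k,
            ((math.hypot(ox - qx, oy - qy), oid) for oid, (ox, oy) in objects.items()),
        )
        rows.extend((qid, oid, dist) for dist, oid in nearest)
    return rows


# Hypothetical centroids and facilities.
centroids = {"c1": (0, 0), "c2": (5, 0), "c3": (0, 6), "c4": (9, 9)}
facilities = {"f1": (1, 1), "f2": (8, 8)}

for row in knn_join(facilities, centroids, k=2):
    print(row)
```

Each query point yields `k` rows, matching the shape of the `ST_AKNN` join result; ties at the k-th distance would be broken here by object ID, an arbitrary choice of this sketch.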


In [None]:
aknn_df = points_df.alias("q").join(
    centroids_df.alias("o"),
    expr("ST_AKNN(q.geom, o.centroid, 4, false)")
)

In [None]:
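aknn_result_df = aknn_df.select(
    expr("q.fsq_place_id as query_id"),  # facility identifier
    expr("q.geom as query_geom"),        # facility point
    expr("o.id as object_id"),           # administrative boundary identifier
    expr("o.centroid as object_geom")    # centroid of that boundary
)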

In [None]:
print("🔹 Nearest Neighbor Join using ST_AKNN:")
aknn_result_df.show(10, truncate=False)

## Optimizing spatial joins with GeoHash sorting

When working with large datasets, spatial joins can be computationally expensive. Sorting the data on a spatial key such as a **GeoHash** stores spatially nearby features close together, so the points and polygons that can match each other are more likely to be read and processed together during a join.

We start by creating a dedicated database to store the sorted tables. We then compute a **GeoHash** for each geometry and sort both DataFrames by this key before writing them out:

```python
points_df = points_df.withColumn("geohash", expr("ST_GeoHash(geom, 6)"))
polygons_df = polygons_df.withColumn("geohash", expr("ST_GeoHash(geometry, 6)"))

from pyspark.sql.functions import col

sorted_points = points_df.sort(col("geohash")).drop("geohash")
sorted_polys = polygons_df.sort(col("geohash")).drop("geohash")
```

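> A GeoHash precision of **6** corresponds to cells of roughly 1.2 km × 0.6 km near the equator, which balances granularity and performance.

Sorting alone, without explicit partitioning, already improves data locality for distributed processing. Partitioning by a spatial key can reduce join execution time further on large datasets. Compare the `%%time` output of the join on the sorted tables with the earlier one to see the effect in your environment.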
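A GeoHash encodes a location by repeatedly halving the longitude and latitude ranges, interleaving the resulting bits (longitude first) and writing every 5 bits as a base-32 character. Nearby locations therefore tend to share a prefix, so sorting by the string groups them together. The sketch below is a plain-Python implementation of the standard encoding, written for illustration rather than taken from Sedona; unlike `ST_GeoHash`, which takes a geometry, this helper takes latitude and longitude directly. The sample coordinates are hypothetical.

```python
BASE32 = "0123456789bcdefghjkmnpqrstuvwxyz"  # geohash alphabet (no a, i, l, o)


def geohash(lat: float, lon: float, precision: int = 6) -> str:
    """Encode a latitude/longitude pair as a geohash of `precision` characters."""
    lat_range = [-90.0, 90.0]
    lon_range = [-180.0, 180.0]
    chars = []
    value, n_bits = 0, 0
    even = True  # even bit positions encode longitude, odd ones latitude
    while len(chars) < precision:
        rng, coord = (lon_range, lon) if even else (lat_range, lat)
        mid = (rng[0] + rng[1]) / 2
        if coord >= mid:
            value = (value << 1) | 1  # upper half of the current range
            rng[0] = mid
        else:
            value <<= 1  # lower half of the current range
            rng[1] = mid
        even = not even
        n_bits += 1
        if n_bits == 5:  # every 5 bits become one base-32 character
            chars.append(BASE32[value])
            value, n_bits = 0, 0
    return "".join(chars)


# Hypothetical locations as (lat, lon): two nearby points and one far away.
places = {
    "atlanta_a": (33.7490, -84.3880),
    "paris": (48.8566, 2.3522),
    "atlanta_b": (33.7495, -84.3885),
}
ordered = sorted(places, key=lambda name: geohash(*places[name]))
print(ordered)  # → ['atlanta_a', 'atlanta_b', 'paris']
```

Two points close to each other can still fall on opposite sides of a cell boundary and get different hashes, so GeoHash sorting improves locality on average rather than guaranteeing it.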


In [None]:
database_name = 'joins'
sedona.sql(f"CREATE DATABASE IF NOT EXISTS wherobots.{database_name}")

In [None]:
points_df = points_df.withColumn("geohash", expr("ST_GeoHash(geom, 6)"))
polygons_df = polygons_df.withColumn("geohash", expr("ST_GeoHash(geometry, 6)"))

In [None]:
sorted_points = points_df.sort(col("geohash"))\
    .drop("geohash")

sorted_polys = polygons_df.sort(col("geohash"))\
    .drop("geohash")

sorted_points.writeTo(f"wherobots.{database_name}.points").createOrReplace()
sorted_polys.writeTo(f"wherobots.{database_name}.polygons").createOrReplace()

print(f"Sorted DataFrames written to wherobots.{database_name}.points and wherobots.{database_name}.polygons")

In [None]:
# Load the sorted tables and alias them for clarity
facilities = sedona.table(f"wherobots.{database_name}.points").alias("f")
admin_boundaries = sedona.table(f"wherobots.{database_name}.polygons").alias("poly")

In [None]:
spatial_join_df_partition = facilities.join(
    admin_boundaries,
    expr("ST_Intersects(poly.geometry, f.geom)")
)

In [None]:
%%time
spatial_join_df_partition.count()

## Visualizing spatial join results

Wherobots includes tools for interactive spatial visualization.
We’ll use **SedonaKepler** to explore the results of our spatial join directly in the notebook.

> SedonaKepler creates an interactive map that lets you explore your joined data visually.
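The next cells restrict the results to a rectangular area of interest written as a WKT `POLYGON`, keeping only boundaries that intersect it. Spatial engines commonly compare bounding boxes first as a cheap prefilter before running the exact geometric test. A minimal plain-Python sketch of that prefilter, assuming a single-ring WKT polygon without holes (the helper names and the two boundary extents are hypothetical):

```python
import re
from typing import Tuple

Box = Tuple[float, float, float, float]  # (min_x, min_y, max_x, max_y)


def wkt_polygon_bounds(wkt: str) -> Box:
    """Bounding box of a single-ring WKT POLYGON (holes are not supported)."""
    match = re.fullmatch(r"\s*POLYGON\s*\(\((.*)\)\)\s*", wkt, flags=re.IGNORECASE)
    if match is None:
        raise ValueError("expected a single-ring WKT POLYGON")
    coords = [tuple(map(float, pair.split())) for pair in match.group(1).split(",")]
    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    return min(xs), min(ys), max(xs), max(ys)


def boxes_intersect(a: Box, b: Box) -> bool:
    """True if two closed boxes share at least one point."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


wkt_polygon = "POLYGON((-84.656729 33.983118, -84.109483 33.983118, -84.109483 33.562116, -84.656729 33.562116, -84.656729 33.983118))"
box = wkt_polygon_bounds(wkt_polygon)
print(box)  # → (-84.656729, 33.562116, -84.109483, 33.983118)

# Hypothetical boundary extents: one overlapping the area, one east of it.
near = (-84.55, 33.65, -84.30, 33.90)
far = (-83.45, 33.88, -83.25, 34.02)
print(boxes_intersect(box, near), boxes_intersect(box, far))  # → True False
```

Overlapping bounding boxes are necessary but not sufficient for two polygons to intersect, so a check like this can only discard candidates; the exact `ST_Intersects` test still decides the result.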


In [None]:
# Area of interest as a WKT polygon: a rectangle around Atlanta, Georgia
wkt_polygon = "POLYGON((-84.656729 33.983118, -84.109483 33.983118, -84.109483 33.562116, -84.656729 33.562116, -84.656729 33.983118))"

In [None]:
detailed_facilities_df = spatial_join_df.filter(
    # Keep pairs whose administrative boundary intersects the area of interest
    expr(f"ST_Intersects(poly.geometry, ST_GeomFromText('{wkt_polygon}'))")
).select(
    "f.fsq_place_id",    # Unique facility identifier
    "f.name",            # Facility name
    "f.address",         # Facility address
    "f.locality",        # Locality (city or town) name
    "f.region",          # Region name
    "f.postcode",        # Postal code
    "f.admin_region",    # Administrative region
    "f.post_town",       # Post town
    "f.country",         # Country
    "f.geom",            # Facility geometry (point)
    "poly.names"         # Names of the administrative boundary (struct)
).selectExpr("*", "names.primary") \
.drop("names")           # Keep only the boundary's primary name

# Count the facilities in the area of interest before mapping them
print("Number of facilities in the area of interest:")
detailed_facilities_df.count()

In [None]:
from sedona.maps.SedonaKepler import SedonaKepler

# Create an interactive map from the filtered join results.
# Each facility is drawn as a point; the primary name of the administrative
# boundary it falls within is available as an attribute.
kepler_map = SedonaKepler.create_map(df=detailed_facilities_df, name="Facilities_Within_Zones")

In [None]:
kepler_map

## Visualizing zonal statistics with a choropleth map

We can also visualize summary statistics using a **choropleth map**, which colors each zone based on a value, such as the number of facilities it contains.

> SedonaPyDeck creates a choropleth map that highlights differences between zones based on your data.
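A choropleth maps each zone's value to a color. The sketch below shows one common scheme, a linear ramp between two RGB colors scaled by the minimum and maximum of the value column. The endpoint colors, function name, and counts are hypothetical, and `SedonaPyDeck` may use a different color scheme.

```python
from typing import Dict, Tuple

RGB = Tuple[int, int, int]


def linear_color_ramp(values: Dict[str, float],
                      low: RGB = (255, 255, 204),
                      high: RGB = (189, 0, 38)) -> Dict[str, RGB]:
    """Map each zone's value to a color interpolated between `low` and `high`."""
    v_min, v_max = min(values.values()), max(values.values())
    span = (v_max - v_min) or 1  # avoid dividing by zero when all values are equal
    colors = {}
    for zone, v in values.items():
        t = (v - v_min) / span  # 0.0 for the smallest value, 1.0 for the largest
        colors[zone] = tuple(round(lo + t * (hi - lo)) for lo, hi in zip(low, high))
    return colors


# Hypothetical point counts per zone.
point_count = {"zone_a": 10, "zone_b": 55, "zone_c": 100}
for zone, rgb in linear_color_ramp(point_count).items():
    print(zone, rgb)
```

With heavily skewed counts, a linear ramp leaves most zones near the low color; logarithmic scaling or quantile classes are common alternatives.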


In [None]:
points_count_efficient_df = polygons_df.alias("poly") \
    .filter(
        expr(f"ST_Intersects(geometry, ST_GeomFromText('{wkt_polygon}'))")
    ) \
    .join(points_df.alias("f"), expr("ST_Intersects(poly.geometry, f.geom)")) \
    .groupBy("poly.id", "poly.geometry") \
    .agg(expr("COUNT(*) as point_count"))

In [None]:
from sedona.maps.SedonaPyDeck import SedonaPyDeck

# Create a choropleth map from the per-boundary point counts.
# Each zone is colored by its 'point_count' value, highlighting variations across the area of interest.

choropleth_map = SedonaPyDeck.create_choropleth_map(
    df=points_count_efficient_df,
    plot_col="point_count"  # This column drives the color intensity
)

# Display the choropleth map in your Jupyter Notebook
choropleth_map.show()