
# Goal

This notebook demonstrates how to load, process, and analyze geospatial raster imagery using Databricks and Apache Sedona. It reads orthophoto TIFF files from Unity Catalog volumes, extracts raster metadata and statistics, and prepares the data for further geospatial analysis, namely per-pixel classification using NDVI and NDWI.

In [0]:
%run ../get_user


In [0]:
%run ./merge_images

In [0]:
# Look up the e-mail of the user running the notebook and derive the short
# user name that is embedded in table names and output paths further down.
# `get_username_from_email` is not defined in this notebook; presumably it is
# provided by the `../get_user` notebook executed above (verify).
current_user_rows = spark.sql("SELECT current_user()").collect()
user_email = current_user_rows[0][0]
username = get_username_from_email(user_email)
print(username)

In [0]:
from pyspark.sql import functions as f
from pyspark.sql.functions import expr, explode, col
from sedona.spark import *
from pyspark.sql.window import Window
from pyspark.sql import functions as F
import math
from pyspark.sql import SparkSession
import os

In [0]:
# Unity Catalog coordinates (catalog / schema / volume) of the input imagery.
catalog_name, schema_name, volume_name = (
    "geospatial",
    "inputs",
    "geospatial_dataset",
)

In [0]:
# S3 bucket used for the merged output GeoTIFFs written at the end of the notebook.
dataset_bucket_name = "revodata-databricks-geospatial"

# Input orthophoto GeoTIFFs inside the Unity Catalog volume, keyed by year (as a string).
file_urls = {
    "2022": f"/Volumes/{catalog_name}/{schema_name}/{volume_name}/raster/orthophoto/soma/2022/2022_4BandImagery_SanFranciscoCA_J1191044.tif",
    "2025": f"/Volumes/{catalog_name}/{schema_name}/{volume_name}/raster/orthophoto/soma/2025/2025_4BandImagery_SanFranciscoCA_J1191043.tif",
}


In [0]:

# Build (or reuse) a Spark session configured with the Sedona and GeoTools
# packages, then wrap it in a Sedona context used as `sedona` below.
# NOTE(review): `spark.jars.packages` presumably only takes effect when a new
# session is actually created; on a cluster with a running session the jars
# likely need to be installed as cluster libraries instead - confirm.
config = (
    SedonaContext.builder()
    .config(
        "spark.jars.packages",
        "org.apache.sedona:sedona-spark-shaded-3.3_2.12:1.8.0,"
        "org.datasyslab:geotools-wrapper:1.8.0-33.1",
    )
    .getOrCreate()
)

sedona = SedonaContext.create(config)


In [0]:
# Read each orthophoto GeoTIFF as a binary file, decode its bytes into a Sedona
# raster column, and expose the result as a temp view for the SQL cells below.

# 2025 image -> view `image_new_vw`
df_image_2025 = (
    sedona.read.format("binaryFile")
    .load(file_urls["2025"])
    .withColumn("raster", f.expr("RS_FromGeoTiff(content)"))
)
df_image_2025.createOrReplaceTempView("image_new_vw")
display(df_image_2025)

# 2022 image -> view `image_old_vw`
df_image_2022 = (
    sedona.read.format("binaryFile")
    .load(file_urls["2022"])
    .withColumn("raster", f.expr("RS_FromGeoTiff(content)"))
)
df_image_2022.createOrReplaceTempView("image_old_vw")
display(df_image_2022)

In [0]:
# Explore the 2025 raster (view `image_new_vw`): metadata, number of bands,
# summary statistics, band pixel type and RS_Count, all from a single query.
display(spark.sql("""SELECT RS_MetaData(raster) AS metadata, 
                  RS_NumBands(raster) AS num_bands,
                  RS_SummaryStatsAll(raster) AS summary_stat,
                  RS_BandPixelType(raster) AS band_pixel_type,
                  RS_Count(raster) AS count 
                  FROM image_new_vw"""))

In [0]:
# Render the 2025 raster (view `image_new_vw`) as an inline image preview.
preview_query = "SELECT RS_AsImage(raster, 500) FROM image_new_vw"
htmlDf = sedona.sql(preview_query)
SedonaUtils.display_image(htmlDf)

In [0]:
def compute_tile_grid(w, h, x, y):
    """
    Choose a tile grid for an image of ``w`` x ``h`` pixels.

    Width and height are handled independently. For each dimension, every tile
    count whose (fractional) tile size ``dim / count`` lies within ``[x, y]`` is
    a candidate; the candidate whose tile size is closest to the midpoint
    ``(x + y) / 2`` is selected. On a tie the smaller count wins.

    Args:
        w: image width in pixels (must be > 0).
        h: image height in pixels (must be > 0).
        x: minimum allowed tile size (must be > 0).
        y: maximum allowed tile size (must be >= x).

    Returns:
        cols (int): number of tiles along width
        rows (int): number of tiles along height
        tile_width (int): ``ceil(w / cols)``; rounded up to whole pixels, so
            ``cols * tile_width`` can exceed ``w``
        tile_height (int): ``ceil(h / rows)``; same rounding as ``tile_width``

    Raises:
        ValueError: if a dimension or bound is not positive, if ``y < x``, or if
            no tile count satisfies both bounds for one of the dimensions.
    """
    # Reject bounds that would otherwise surface as ZeroDivisionError (x == 0)
    # or as a meaningless negative tile count (negative bounds).
    if x <= 0 or y < x:
        raise ValueError(f"Invalid tile size bounds {x}-{y}: expected 0 < min <= max")

    def get_tile_count(dim, min_tile_size, max_tile_size):
        # A zero-sized dimension would lead to a division by a tile count of 0.
        if dim <= 0:
            raise ValueError(f"Image dimension must be positive, got {dim}")

        # Range of tile counts that can satisfy both size constraints.
        min_count = math.ceil(dim / max_tile_size)
        max_count = math.floor(dim / min_tile_size)
        if min_count > max_count:
            raise ValueError(f"Cannot tile dimension {dim} within bounds {min_tile_size}-{max_tile_size}")

        # Keep the explicit bounds check as a guard against float rounding.
        candidates = [
            count
            for count in range(min_count, max_count + 1)
            if min_tile_size <= dim / count <= max_tile_size
        ]
        if not candidates:
            raise ValueError(f"Cannot tile dimension {dim} within bounds {min_tile_size}-{max_tile_size}")

        # min() returns the first minimum, so ties resolve to the smaller count.
        target_size = (min_tile_size + max_tile_size) / 2
        best_count = min(candidates, key=lambda count: abs(dim / count - target_size))

        return best_count, math.ceil(dim / best_count)

    cols, tile_width = get_tile_count(w, x, y)
    rows, tile_height = get_tile_count(h, x, y)

    return cols, rows, tile_width, tile_height

In [0]:
# Read the 2025 raster's width and height in one query (a single pass over the
# raster instead of one Spark job per dimension), then determine the tile grid
# from the min/max tile size constraints.
# NOTE(review): the grid is derived from the 2025 image only and reused for the
# 2022 image below; this assumes both rasters have the same dimensions - confirm.

dims_row = spark.sql(
    "select RS_Width(raster), RS_Height(raster) from image_new_vw"
).first()
if dims_row is None:
    # Fail with a clear message instead of an IndexError when nothing was loaded.
    raise ValueError("image_new_vw contains no raster; cannot compute the tile grid")
w, h = dims_row[0], dims_row[1]

print(w)
print(h)

min_tile_size = 128  # minimum tile size
max_tile_size = 256  # maximum tile size

cols, rows, tile_w, tile_h = compute_tile_grid(w, h, min_tile_size, max_tile_size)
print(f"Image will be divided into {cols} columns and {rows} rows")
print(f"Each tile will be approximately {tile_w} x {tile_h} pixels")

In [0]:
# Explode the 2025 and 2022 rasters into tiles, add tile metadata, number the
# tiles, and tag each tile with its year.

def explode_into_tiles(image_df, year, tile_width, tile_height, order_window):
    """
    Split the `raster` column of `image_df` into tiles and annotate each tile.

    Args:
        image_df: DataFrame with a `raster` column.
        year: value stored in the `year` column of every resulting tile row.
        tile_width: tile width passed to RS_TileExplode.
        tile_height: tile height passed to RS_TileExplode.
        order_window: window specification used to assign the `rn` row number.

    Returns:
        DataFrame with the RS_TileExplode output (`x`/`y` renamed to
        `tile_x`/`tile_y`) plus `width`, `height`, `rn` and `year` columns.
    """
    tiles = (
        image_df
        .selectExpr(f"RS_TileExplode(raster, {tile_width}, {tile_height})")
        .withColumnRenamed("x", "tile_x")
        .withColumnRenamed("y", "tile_y")
        .withColumn("width", expr("RS_Width(tile)"))
        .withColumn("height", expr("RS_height(tile)"))
        .repartitionByRange("tile_x")
    )
    return (
        tiles
        .withColumn("rn", F.row_number().over(order_window))
        # F.lit rather than a bare `lit`: `lit` is never imported explicitly in
        # this notebook, so the bare name only worked if it leaked in through a
        # wildcard import or a %run notebook.
        .withColumn("year", F.lit(year))
    )


# `rn` numbers tiles by (tile_x, tile_y); it is used later to pair the tiles of
# both years. NOTE(review): a window without partitionBy presumably pulls all
# rows into a single partition for the numbering - acceptable for one image,
# but confirm for larger inputs.
window_spec = Window.orderBy("tile_x", "tile_y")

tiled_df_2025 = explode_into_tiles(df_image_2025, 2025, tile_w, tile_h, window_spec)
tiled_df_2022 = explode_into_tiles(df_image_2022, 2022, tile_w, tile_h, window_spec)

In [0]:
# Sanity check on the tiling: take one 2025 tile and render it inline.

first_tile = tiled_df_2025.limit(1)
first_tile.createOrReplaceTempView("first_tile_vw")

tile_preview_query = "SELECT RS_AsImage(tile) FROM first_tile_vw"
htmlDf = sedona.sql(tile_preview_query)
SedonaUtils.display_image(htmlDf)

In [0]:
# Union the 2022 and 2025 tiled rasters (column sets must match exactly), then
# number the rows sharing the same tile id `rn` by descending year, so the most
# recent year gets index 1. Finally compute per-pixel NDVI/NDWI arrays per tile.
union_raster = tiled_df_2022.unionByName(tiled_df_2025, allowMissingColumns=False)
window_spec = Window.partitionBy("rn").orderBy(F.desc("year"))
union_raster = union_raster.withColumn("index", F.row_number().over(window_spec))

# Calculating NDVI using Red and NIR bands as NDVI = (NIR - Red) / (NIR + Red)
# Calculating NDWI using Green and NIR bands as NDWI = (Green - NIR) / (Green + NIR)
# Band numbering used below: 1 = Red, 2 = Green (see the `red`/`green` columns
# further down); band 4 is assumed to be NIR - TODO confirm against the imagery.
# NOTE(review): per the Sedona docs RS_Subtract(a, b) returns b - a, which is
# why the operands look reversed relative to the formulas above - confirm for
# the Sedona version in use.
union_raster = (
    union_raster
    .withColumn(
        "ndvi",
        expr(
            "RS_Divide("
            "  RS_Subtract(RS_BandAsArray(tile, 1), RS_BandAsArray(tile, 4)), "
            "  RS_Add(RS_BandAsArray(tile, 1), RS_BandAsArray(tile, 4))"
            ")"
        )
    )
    .withColumn(
        "ndwi",
        expr(
            "RS_Divide("
            "  RS_Subtract(RS_BandAsArray(tile, 4), RS_BandAsArray(tile, 2)), "
            "  RS_Add(RS_BandAsArray(tile, 4), RS_BandAsArray(tile, 2))"
            ")"
        )
    )
)

# Expose the first (red) and second (green) bands as arrays in new columns;
# the classification step reads them alongside ndvi/ndwi.
union_raster = union_raster.withColumn(
    "red",
    expr(
        "RS_BandAsArray(tile, 1)"
    )
).withColumn(
    "green",
    expr(
        "RS_BandAsArray(tile, 2)"
    )
)

display(union_raster.limit(1))

In [0]:
# Classification tree for raster tiles based on Red, Green bands, NDVI, and NDWI.
# Applies different classification rules for years 2022 and 2025 using arrays_zip and transform.
# Assigns a class value per pixel: 1, 2, 3, 4, or 999 (no data).
#
# How it works:
# - arrays_zip pairs up the ndvi/ndwi/red/green arrays element by element, so
#   each `x` in the lambda carries the four values of a single pixel.
# - Within each year's CASE the WHEN branches are evaluated top to bottom and
#   the first match wins, so rule order matters:
#     4   -> red < 15 AND green < 15 (checked first)
#     2   -> high NDVI together with low NDWI
#     3   -> low NDVI / high NDWI combinations
#     1   -> both indices inside a band around zero (2025 adds one extra rule)
#     999 -> no rule matched, or the row's year is neither 2022 nor 2025
# - Thresholds differ per year; the class meanings are not stated in this
#   notebook (presumably 2 = vegetation, 3 = water - TODO confirm).

union_raster = union_raster.withColumn(
    "classification",
    F.expr("""
        transform(
            arrays_zip(ndvi, ndwi, red, green),
            x -> 
                CASE 
                    WHEN year = 2022 THEN
                        CASE 
                            WHEN x.red < 15 AND x.green < 15 THEN 4
                            WHEN x.ndvi > 0.35 AND x.ndwi < -0.35 THEN 2
                            WHEN (x.ndvi < -0.2 AND x.ndwi > 0.35) OR (x.red < 15 AND x.ndwi > 0.35) OR (x.ndwi > 0.45) THEN 3
                            WHEN x.ndvi >= -0.3 AND x.ndvi <= 0.3 AND x.ndwi >= -0.3 AND x.ndwi <= 0.3 THEN 1
                            ELSE 999
                        END
                    WHEN year = 2025 THEN
                        CASE 
                            WHEN x.red < 15 AND x.green < 15 THEN 4
                            WHEN x.ndvi > 0.3 AND x.ndwi < -0.15 THEN 2
                            WHEN x.ndvi < -0.35 AND x.ndwi > 0.55 THEN 3
                            WHEN (x.ndvi >= -0.5 AND x.ndvi <= 0.5 AND x.ndwi >= -0.5 AND x.ndwi <= 0.5) OR (x.ndvi > 0.8 AND x.ndwi > 0.3) THEN 1                           
                            ELSE 999
                        END
                    ELSE 999
                END
        )
    """)
)
display(union_raster.limit(2))

In [0]:
# Turn the per-pixel classification arrays into single-band rasters and persist them.
# Steps:
# 1. Keep the tile key columns and build a new raster from the classification array.
# 2. Set 999 as the band's no-data value; RS_SetBandNoDataValue is applied twice,
#    first with its last argument false and then with it true.
# 3. Display a sample for inspection.
# 4. Serialize each tile as GeoTIFF bytes and overwrite the per-user Unity Catalog table.

tile_key_cols = ["tile_x", "tile_y", "rn", "year", "index"]

classification_df = (
    union_raster
    .select(
        *tile_key_cols,
        expr("RS_MakeRaster(tile, 'I', classification) AS tile").alias("tile"),
    )
    .select(
        *tile_key_cols,
        expr("RS_SetBandNoDataValue(tile, 1, 999, false)").alias("tile"),
    )
    .select(
        *tile_key_cols,
        expr("RS_SetBandNoDataValue(tile, 1, 999, true)").alias("tile"),
    )
)

display(classification_df.limit(10))

(
    classification_df
    .withColumn("raster_binary", expr("RS_AsGeoTiff(tile)"))
    .select(*tile_key_cols, "raster_binary")
    .write.mode("overwrite")
    .saveAsTable(f"geospatial.soma.classification_{username}")
)

In [0]:
# Load the classified raster tiles from Unity Catalog, range-repartition them by
# `rn` into 10 partitions, display a sample, then split by year and merge each
# year's tiles into one output GeoTIFF.
# `merge_tiffs` is not defined in this notebook; presumably it is provided by the
# `./merge_images` notebook executed at the top (verify).

geotiff_df = spark.table(f"geospatial.soma.classification_{username}") \
    .repartitionByRange(10, "rn")
display(geotiff_df.limit(10))

# Fixed: the 2022 file name previously ended in "classification_2022geotiff_df.tif"
# (a variable name pasted into the literal); both years now follow the same
# "classification_<year>.tif" pattern.
output_dir = f"s3://{dataset_bucket_name}/outputs/geotiff/{username}"
output_tiff_2025 = f"{output_dir}/classification_2025.tif"
output_tiff_2022 = f"{output_dir}/classification_2022.tif"

class_2025 = geotiff_df.filter(geotiff_df["year"] == 2025)
class_2022 = geotiff_df.filter(geotiff_df["year"] == 2022)

merge_tiffs(class_2025, output_tiff_2025)
merge_tiffs(class_2022, output_tiff_2022)