In [0]:
from pyspark.sql import functions as f
from pyspark.sql.functions import expr, explode, col
from sedona.spark import *
from pyspark.sql.window import Window
from pyspark.sql import functions as F
import math
from pyspark.sql import SparkSession
import os
import tempfile
import numpy as np
from osgeo import gdal, gdalconst
import boto3
from botocore.exceptions import ClientError
import logging
import os


In [0]:
# Bucket that holds the source orthophotos.
dataset_bucket_name = "revodata-databricks-geospatial"

# One 4-band orthophoto GeoTIFF per survey year, keyed by the year as a string.
file_urls = {
    "2022": f"s3://{dataset_bucket_name}/geospatial-dataset/raster/orthophoto/soma/2022/2022_4BandImagery_SanFranciscoCA_J1191044.tif",
    "2025": f"s3://{dataset_bucket_name}/geospatial-dataset/raster/orthophoto/soma/2025/2025_4BandImagery_SanFranciscoCA_J1191043.tif",
}

# Catalog name used for the notebook's tables.
catalog_name = "geospatial"


In [0]:
# Copy the S3 credentials from the "aws_geospatial_s3" secret scope into the
# environment variables that get_s3_client() reads via os.getenv.
for _env_name, _secret_name in (
    ("AWS_ACCESS_KEY_ID", "access_key"),
    ("AWS_SECRET_ACCESS_KEY", "secret_key"),
):
    os.environ[_env_name] = dbutils.secrets.get(scope="aws_geospatial_s3", key=_secret_name)

os.environ["AWS_DEFAULT_REGION"] = "eu-west-2"  # Match your bucket region

In [0]:
# Build (or reuse) the Spark session through Sedona's builder.
# NOTE(review): the per-bucket credentials provider below targets the
# "wherobots-examples" bucket, not the dataset bucket used in this notebook —
# confirm it is still needed.
config = (
    SedonaContext.builder()
    .config(
        "spark.hadoop.fs.s3a.bucket.wherobots-examples.aws.credentials.provider",
        "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
    )
    .getOrCreate()
)

# Sedona-enabled session; later cells read files and run RS_* SQL through it.
sedona = SedonaContext.create(config)

In [0]:
# Logging: INFO-level root configuration plus a logger for this notebook,
# used by the S3 upload and TIFF merge helpers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# GDAL: raise Python exceptions on failure instead of returning error codes,
# and set the GDAL_DISABLE_READDIR_ON_OPEN config option.
gdal.UseExceptions()
gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "YES")

def get_s3_client():
    """Create a boto3 S3 client from the AWS_* environment variables.

    Reads AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment;
    the region comes from AWS_DEFAULT_REGION and falls back to 'eu-west-2'.

    Returns:
        The client object returned by ``boto3.client('s3', ...)``.
    """
    client_kwargs = {
        "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        "region_name": os.getenv("AWS_DEFAULT_REGION", "eu-west-2"),
    }
    return boto3.client("s3", **client_kwargs)

def upload_to_s3(local_path, s3_path):
    """Upload a local file to S3 using boto3.

    Args:
        local_path: Path of the file on the local filesystem.
        s3_path: Destination in the form ``s3://<bucket>/<key>`` (the
            ``s3://`` prefix is optional).

    Returns:
        True when the upload succeeded. False when ``s3_path`` has no bucket
        or key, or when boto3 reports a failure; the reason is logged.
    """
    # boto3's managed transfer reports upload failures as S3UploadFailedError,
    # which is not a ClientError, so it has to be caught explicitly.
    from boto3.exceptions import S3UploadFailedError

    scheme = "s3://"
    # Strip only a leading scheme; str.replace would also rewrite an "s3://"
    # that happens to occur inside the key.
    location = s3_path[len(scheme):] if s3_path.startswith(scheme) else s3_path
    bucket, _, key = location.partition("/")
    if not bucket or not key:
        logger.error(f"Invalid S3 path (expected s3://bucket/key): {s3_path}")
        return False

    s3 = get_s3_client()
    try:
        s3.upload_file(local_path, bucket, key)
    except (ClientError, S3UploadFailedError) as e:
        logger.error(f"Failed to upload to S3: {e}")
        return False

    logger.info(f"Successfully uploaded to {s3_path}")
    return True

def merge_tiffs(binary_df, output_path):
    """
    Merges multiple TIFF binary data from a Spark DataFrame into a single TIFF using GDAL

    Every row's ``raster_binary`` value is written to a temporary file, the
    files are mosaicked through a VRT, and the mosaic is translated into a
    tiled, DEFLATE-compressed GeoTIFF. The result is uploaded to S3 or moved
    to a local destination. When the DataFrame has no rows a warning is
    logged and nothing is written.

    Args:
        binary_df: PySpark DataFrame with a ``raster_binary`` column of TIFF bytes
        output_path: Path to save the merged TIFF file (can be local or S3 path)

    Raises:
        Exception: any failure during mosaic, translate, upload or move is
            logged and re-raised; a failed S3 upload raises RuntimeError.
    """
    import shutil  # only needed for the local-destination branch

    # Collect all TIFF binary data to driver
    tiff_rows = binary_df.select("raster_binary").collect()
    if not tiff_rows:
        logger.warning("No TIFF files found to merge")
        return

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_files = []

        # Write each binary TIFF to a temporary file so GDAL can open it by path
        for i, row in enumerate(tiff_rows):
            temp_file = os.path.join(temp_dir, f"temp_{i}.tif")
            with open(temp_file, "wb") as tile_file:
                tile_file.write(row["raster_binary"])
            temp_files.append(temp_file)

        try:
            # Create VRT mosaic. Releasing the returned dataset handle closes
            # it, which is what flushes the VRT to disk before it is read back.
            vrt_file = os.path.join(temp_dir, "mosaic.vrt")
            vrt_dataset = gdal.BuildVRT(vrt_file, temp_files)
            vrt_dataset = None

            # Create local output file
            local_output = os.path.join(temp_dir, "merged_output.tif")
            translate_options = gdal.TranslateOptions(
                creationOptions=['COMPRESS=DEFLATE', 'TILED=YES', 'BIGTIFF=IF_SAFER']
            )
            merged_dataset = gdal.Translate(local_output, vrt_file, options=translate_options)
            # Close the output so the GeoTIFF is fully written before upload/move.
            merged_dataset = None

            # Handle output destination
            if output_path.startswith('s3://'):
                if not upload_to_s3(local_output, output_path):
                    raise RuntimeError("Failed to upload to S3")
            else:
                # shutil.move falls back to copy+delete when the destination
                # is on a different filesystem than the temp dir, where
                # os.rename would fail.
                shutil.move(local_output, output_path)
            logger.info(f"Merged {len(temp_files)} TIFFs to {output_path}")

        except Exception as e:
            logger.error(f"Error during merge: {e}")
            raise

In [0]:
# Load the 2025 orthophoto as a single binary row and decode its bytes into a
# Sedona raster column.
df_image_2025 = (
    sedona.read.format("binaryFile")
    .load(file_urls["2025"])
    .withColumn("raster", f.expr("RS_FromGeoTiff(content)"))
)

# Expose the 2025 raster to SQL cells as image_new_vw.
df_image_2025.createOrReplaceTempView("image_new_vw")
display(df_image_2025)

In [0]:
# Inspect the 2025 raster (image_new_vw): metadata, number of bands, summary
# statistics, band pixel type and the RS_Count value.
# NOTE(review): this cell uses the ambient `spark` session while neighbouring
# cells query the same view through `sedona` — presumably the same session; confirm.
display(spark.sql("""SELECT RS_MetaData(raster) AS metadata, 
                  RS_NumBands(raster) AS num_bands,
                  RS_SummaryStatsAll(raster) AS summary_stat,
                  RS_BandPixelType(raster) AS band_pixel_type,
                  RS_Count(raster) AS count 
                  FROM image_new_vw"""))



In [0]:
# Render the 2025 raster as an image in the notebook output.
# The second RS_AsImage argument (500) is presumably the rendered image width
# in pixels — confirm against the Sedona docs.
htmlDf = sedona.sql("SELECT RS_AsImage(raster, 500) FROM image_new_vw")
SedonaUtils.display_image(htmlDf)

In [0]:
# Load the 2022 orthophoto as a single binary row and decode its bytes into a
# Sedona raster column.
df_image_2022 = (
    sedona.read.format("binaryFile")
    .load(file_urls["2022"])
    .withColumn("raster", f.expr("RS_FromGeoTiff(content)"))
)

# Expose the 2022 raster to SQL cells as image_old_vw.
df_image_2022.createOrReplaceTempView("image_old_vw")
display(df_image_2022)

In [0]:
def compute_tile_grid(w, h, x, y):
    """
    Compute the number of columns and rows to divide an image of size (w x h)
    into tiles such that each tile is at least size x and at most size y in
    both width and height.

    For each dimension, the tile count whose exact tile size (dim / count) is
    closest to the midpoint of [x, y] is chosen; on a tie the smaller count
    (larger tile) wins. The returned tile size is that exact size rounded up.

    Args:
        w (int): image width in pixels, must be > 0
        h (int): image height in pixels, must be > 0
        x (int): minimum tile size, must be > 0
        y (int): maximum tile size, must be >= x

    Returns:
        cols (int): number of tiles along width
        rows (int): number of tiles along height
        tile_width (int): actual width of each tile
        tile_height (int): actual height of each tile

    Raises:
        ValueError: if an argument is not positive, if x > y, or if a
            dimension cannot be split into tiles within [x, y].
    """
    # Reject degenerate inputs up front; otherwise they surface as an opaque
    # ZeroDivisionError deep inside the search.
    if w <= 0 or h <= 0:
        raise ValueError(f"Image dimensions must be positive, got {w} x {h}")
    if x <= 0 or y <= 0 or x > y:
        raise ValueError(f"Tile size bounds must satisfy 0 < min <= max, got {x}-{y}")

    def get_tile_count(dim, min_tile_size, max_tile_size):
        # Get range of possible tile counts satisfying both size constraints
        min_count = math.ceil(dim / max_tile_size)
        max_count = math.floor(dim / min_tile_size)
        if min_count > max_count:
            raise ValueError(f"Cannot tile dimension {dim} within bounds {min_tile_size}-{max_tile_size}")

        candidates = [
            count
            for count in range(min_count, max_count + 1)
            if min_tile_size <= dim / count <= max_tile_size
        ]
        if not candidates:
            raise ValueError(f"Cannot tile dimension {dim} within bounds {min_tile_size}-{max_tile_size}")

        # Choose the tile count that makes tiles closest to the center of the
        # allowed range; min() keeps the first candidate on ties.
        target_size = (min_tile_size + max_tile_size) / 2
        best_count = min(candidates, key=lambda count: abs(dim / count - target_size))

        return best_count, math.ceil(dim / best_count)

    cols, tile_width = get_tile_count(w, x, y)
    rows, tile_height = get_tile_count(h, x, y)

    return cols, rows, tile_width, tile_height

In [0]:
# Fetch both raster dimensions with a single query (one Spark job instead of
# two) from the 2025 raster view.
size_row = sedona.sql(
    "SELECT RS_Width(raster) AS width, RS_Height(raster) AS height FROM image_new_vw"
).first()
if size_row is None:
    raise ValueError("image_new_vw returned no rows; load the 2025 raster first")
w, h = size_row["width"], size_row["height"]

print(w)
print(h)

min_tile_size = 128  # minimum tile size
max_tile_size = 256  # maximum tile size
cols, rows, tile_w, tile_h = compute_tile_grid(w, h, min_tile_size, max_tile_size)
print(f"Image will be divided into {cols} columns and {rows} rows")
print(f"Each tile will be approximately {tile_w} x {tile_h} pixels")

In [0]:
# Split the 2025 raster into tiles of tile_w x tile_h pixels and record each
# tile's grid position and actual size.
tiled_df_2025 = (
    df_image_2025
    .selectExpr(f"RS_TileExplode(raster, {tile_w}, {tile_h})")
    .withColumnRenamed("x", "tile_x")
    .withColumnRenamed("y", "tile_y")
    .withColumn("width", expr("RS_Width(tile)"))
    .withColumn("height", expr("RS_height(tile)"))
)

# Number the tiles by grid position; a later cell partitions by "rn" to pair
# the 2022 and 2025 tiles. The window has no partitionBy, so the numbering is
# global across all tiles.
window_spec = Window.orderBy("tile_x", "tile_y")
tiled_df_2025 = (
    tiled_df_2025
    .withColumn("rn", F.row_number().over(window_spec))
    # F.lit: a bare `lit` is never imported explicitly in this notebook.
    .withColumn("year", F.lit(2025))
)


In [0]:
# Split the 2022 raster into tiles of tile_w x tile_h pixels and record each
# tile's grid position and actual size.
tiled_df_2022 = (
    df_image_2022
    .selectExpr(f"RS_TileExplode(raster, {tile_w}, {tile_h})")
    .withColumnRenamed("x", "tile_x")
    .withColumnRenamed("y", "tile_y")
    .withColumn("width", expr("RS_Width(tile)"))
    .withColumn("height", expr("RS_height(tile)"))
)

# Number the tiles by grid position; a later cell partitions by "rn" to pair
# the 2022 and 2025 tiles. The window has no partitionBy, so the numbering is
# global across all tiles.
window_spec = Window.orderBy("tile_x", "tile_y")
tiled_df_2022 = (
    tiled_df_2022
    .withColumn("rn", F.row_number().over(window_spec))
    # F.lit: a bare `lit` is never imported explicitly in this notebook.
    .withColumn("year", F.lit(2022))
)


In [0]:
# Visual sanity check: take one tile from the 2025 tiling, expose it as a
# temp view and render it as an image in the notebook output.
# NOTE: limit(1) does not sort, so which tile is shown is not guaranteed.
first_tile = tiled_df_2025.limit(1)
first_tile.createOrReplaceTempView("first_tile_vw")
htmlDf = sedona.sql("SELECT RS_AsImage(tile) FROM first_tile_vw")
SedonaUtils.display_image(htmlDf)

In [0]:
# Stack both years' tiles into one DataFrame. Within each tile number ("rn"),
# rank the rows newest year first: index 1 is 2025, index 2 is 2022.
window_spec = Window.partitionBy("rn").orderBy(F.desc("year"))
union_raster = (
    tiled_df_2022
    .unionByName(tiled_df_2025, allowMissingColumns=False)
    .withColumn("index", F.row_number().over(window_spec))
)

In [0]:
# NDVI = (NIR - Red) / (NIR + Red), computed per pixel from band 1 and band 4.
# The formula above implies band 1 is Red and band 4 is NIR.
# NOTE(review): the subtraction passes (band 1, band 4); this matches the
# formula only if RS_Subtract(a, b) returns b - a — confirm against the Sedona
# docs for the version in use.
ndvi_sql = (
    "RS_Divide("
    "  RS_Subtract(RS_BandAsArray(tile, 1), RS_BandAsArray(tile, 4)), "
    "  RS_Add(RS_BandAsArray(tile, 1), RS_BandAsArray(tile, 4))"
    ")"
)
union_raster = union_raster.withColumn("ndvi", expr(ndvi_sql))

In [0]:
# NDWI = (Green - NIR) / (Green + NIR), computed per pixel from band 2 and band 4.
# The formula above implies band 2 is Green and band 4 is NIR.
# NOTE(review): the subtraction passes (band 4, band 2); this matches the
# formula only if RS_Subtract(a, b) returns b - a — confirm against the Sedona
# docs for the version in use.
ndwi_sql = (
    "RS_Divide("
    "  RS_Subtract(RS_BandAsArray(tile, 4), RS_BandAsArray(tile, 2)), "
    "  RS_Add(RS_BandAsArray(tile, 4), RS_BandAsArray(tile, 2))"
    ")"
)
union_raster = union_raster.withColumn("ndwi", expr(ndwi_sql))

# Preview one row to check the new index columns.
display(union_raster.limit(1))

In [0]:
# Extract band 1 (red) and band 2 (green) as array columns so the
# classification step can read the raw values next to NDVI/NDWI.
union_raster = (
    union_raster
    .withColumn("red", expr("RS_BandAsArray(tile, 1)"))
    .withColumn("green", expr("RS_BandAsArray(tile, 2)"))
)

# Preview one row to check the new band columns.
display(union_raster.limit(1))

In [0]:
# Classification tree based on Red and Green bands and NDVI, NDWI.
# arrays_zip lines up the four per-tile arrays element by element and
# transform() maps each element to an integer class code. The thresholds are
# tuned separately per survey year (2022 vs 2025). Within a year the first
# matching WHEN wins, so the order of the rules matters.
# 999 is the fallback for elements matching no rule or rows with any other
# year; a later cell registers 999 as the band's no-data value.
# NOTE(review): the meaning of class codes 1-4 is not documented in this
# notebook — add a legend once confirmed.
union_raster = union_raster.withColumn(
    "classification",
    F.expr("""
        transform(
            arrays_zip(ndvi, ndwi, red, green),
            x -> 
                CASE 
                    WHEN year = 2022 THEN
                        CASE 
                            WHEN x.red < 15 AND x.green < 15 THEN 4
                            WHEN x.ndvi > 0.35 AND x.ndwi < -0.35 THEN 2
                            WHEN (x.ndvi < -0.2 AND x.ndwi > 0.35) OR (x.red < 15 AND x.ndwi > 0.35) OR (x.ndwi > 0.45) THEN 3
                            WHEN x.ndvi >= -0.3 AND x.ndvi <= 0.3 AND x.ndwi >= -0.3 AND x.ndwi <= 0.3 THEN 1
                            ELSE 999
                        END
                    WHEN year = 2025 THEN
                        CASE 
                            WHEN x.red < 15 AND x.green < 15 THEN 4
                            WHEN x.ndvi > 0.3 AND x.ndwi < -0.15 THEN 2
                            WHEN x.ndvi < -0.35 AND x.ndwi > 0.55 THEN 3
                            WHEN (x.ndvi >= -0.5 AND x.ndvi <= 0.5 AND x.ndwi >= -0.5 AND x.ndwi <= 0.5) OR (x.ndvi > 0.8 AND x.ndwi > 0.3) THEN 1                           
                            ELSE 999
                        END
                    ELSE 999
                END
        )
    """)
)


In [0]:
# Preview two rows to check the new "classification" array column.
display(union_raster.limit(2))

In [0]:
# Turn each classification array into a raster built from its source tile,
# then register 999 as the no-data value of band 1.
tile_key_columns = ["tile_x", "tile_y", "rn", "year", "index"]

# Step 1: classification array -> raster ('I' is the pixel-type argument).
classified_tiles = union_raster.select(
    *tile_key_columns,
    expr("RS_MakeRaster(tile, 'I', classification) AS tile").alias("tile"),
)

# Steps 2-3: set no-data to 999, first with the last argument false, then true.
# NOTE(review): the second call looks redundant after the first — confirm
# whether both are required.
nodata_tagged_tiles = classified_tiles.select(
    *tile_key_columns,
    expr("RS_SetBandNoDataValue(tile,1, 999, false)").alias("tile"),
)
classification_df = nodata_tagged_tiles.select(
    *tile_key_columns,
    expr("RS_SetBandNoDataValue(tile,1, 999, true)").alias("tile"),
)

display(classification_df.limit(10))

In [0]:
# Persist the classified tiles as GeoTIFF bytes. The catalog comes from the
# catalog_name config constant instead of being hard-coded a second time.
(
    classification_df
    .withColumn("raster_binary", expr("RS_AsGeoTiff(tile)"))
    .select("tile_x", "tile_y", "rn", "year", "index", "raster_binary")
    .write.mode("overwrite")
    .saveAsTable(f"{catalog_name}.soma.classification")
)

In [0]:
# Preview ten classified tiles.
# NOTE(review): this repeats the display at the end of the cell that builds
# classification_df and re-evaluates the DataFrame — consider removing.
display(classification_df.limit(10))

In [0]:
# Read the persisted tiles back and range-partition them by tile number into
# 10 partitions. The catalog comes from the catalog_name config constant
# instead of being hard-coded a second time.
geotiff_df = (
    spark.table(f"{catalog_name}.soma.classification")
    .repartitionByRange(10, "rn")
)
display(geotiff_df.limit(10))

In [0]:
# Merge TIFFs
# Destination objects for the merged per-year classification rasters.
output_tiff_2025 = f"s3://{dataset_bucket_name}/geospatial-dataset/raster/orthophoto/soma/classification/classification_2025.tif"
output_tiff_2022 = f"s3://{dataset_bucket_name}/geospatial-dataset/raster/orthophoto/soma/classification/classification_2022.tif"

# Split the stored tiles by survey year.
class_2025 = geotiff_df.where(F.col("year") == 2025)
class_2022 = geotiff_df.where(F.col("year") == 2022)


In [0]:
# Mosaic each year's classified tiles into a single GeoTIFF, 2025 first.
for _year_tiles, _destination in (
    (class_2025, output_tiff_2025),
    (class_2022, output_tiff_2022),
):
    merge_tiffs(_year_tiles, _destination)