In [0]:
from pyspark.sql import functions as f
from pyspark.sql.functions import expr, explode, col
from sedona.spark import *
from pyspark.sql.window import Window
from pyspark.sql import functions as F
import math
from pyspark.sql import SparkSession
import os
import tempfile
import numpy as np
from osgeo import gdal, gdalconst
import boto3
from botocore.exceptions import ClientError
import logging
import os


In [0]:
# S3 bucket that the merged GeoTIFF outputs of this notebook are uploaded to.
dataset_bucket_name = "revodata-databricks-geospatial"
# Presumably the catalog prefix of the `<catalog>.soma.<table>` tables this
# notebook reads and writes — confirm before pointing it at another catalog.
catalog_name = "geospatial"

In [0]:
# Export the S3 credentials held in the Databricks secret scope as the standard
# AWS_* environment variables, one variable at a time and in a fixed order.
for env_var, secret_key in (
    ('AWS_ACCESS_KEY_ID', "access_key"),
    ('AWS_SECRET_ACCESS_KEY', "secret_key"),
):
    os.environ[env_var] = dbutils.secrets.get(scope="aws_geospatial_s3", key=secret_key)

# Match your bucket region
os.environ['AWS_DEFAULT_REGION'] = 'eu-west-2'

In [0]:
# Build (or reuse) the Spark session through Sedona's builder. The
# "wherobots-examples" bucket is read with anonymous credentials.
config = (
    SedonaContext.builder()
    .config(
        "spark.hadoop.fs.s3a.bucket.wherobots-examples.aws.credentials.provider",
        "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
    )
    .getOrCreate()
)

# Sedona-enabled session used for the RS_* SQL queries further down.
sedona = SedonaContext.create(config)

In [0]:
# Configure GDAL: make the Python bindings raise exceptions on failure, and set
# GDAL_DISABLE_READDIR_ON_OPEN so opening a dataset does not scan its directory.
gdal.UseExceptions()
gdal.SetConfigOption('GDAL_DISABLE_READDIR_ON_OPEN', 'YES')

# Configure logging: INFO level via basicConfig, plus the module-level logger
# that the S3 and merge helpers in this cell write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_s3_client():
    """Create a boto3 S3 client configured from the AWS_* environment variables.

    The region falls back to 'eu-west-2' when AWS_DEFAULT_REGION is not set.
    """
    client_kwargs = {
        'aws_access_key_id': os.getenv('AWS_ACCESS_KEY_ID'),
        'aws_secret_access_key': os.getenv('AWS_SECRET_ACCESS_KEY'),
        'region_name': os.getenv('AWS_DEFAULT_REGION', 'eu-west-2'),
    }
    return boto3.client('s3', **client_kwargs)

def upload_to_s3(local_path, s3_path):
    """Upload a local file to S3 using boto3 and report whether it succeeded.

    Args:
        local_path: Path of the file on the local filesystem.
        s3_path: Destination in the form ``s3://bucket/key`` (the ``s3://``
            prefix is optional).

    Returns:
        True when the upload completed. False when the destination has no
        bucket or key, or when boto3 reported an upload failure; the reason is
        logged in both cases.
    """
    # boto3's managed transfer re-raises service errors as S3UploadFailedError,
    # which does not derive from ClientError, so it must be caught explicitly.
    from boto3.exceptions import S3UploadFailedError

    # Strip only a leading scheme; str.replace would also remove "s3://" if it
    # happened to occur later in the key.
    prefix = "s3://"
    location = s3_path[len(prefix):] if s3_path.startswith(prefix) else s3_path
    bucket, _, key = location.partition("/")
    if not bucket or not key:
        logger.error(f"Failed to upload to S3: invalid destination {s3_path!r}")
        return False

    s3 = get_s3_client()
    try:
        s3.upload_file(local_path, bucket, key)
    except (ClientError, S3UploadFailedError) as e:
        logger.error(f"Failed to upload to S3: {e}")
        return False

    logger.info(f"Successfully uploaded to {s3_path}")
    return True

def merge_tiffs(binary_df, output_path):
    """
    Merge the GeoTIFF binaries held in a Spark DataFrame into a single GeoTIFF using GDAL.

    Each row's ``raster_binary`` value is written to a temporary file on the
    driver, the files are mosaicked through a VRT, and the mosaic is translated
    into a tiled, DEFLATE-compressed GeoTIFF. The result is uploaded to S3 or
    moved to a local path. When the DataFrame has no rows, a warning is logged
    and nothing is written.

    Args:
        binary_df: PySpark DataFrame with a ``raster_binary`` column of TIFF bytes
        output_path: Path to save the merged TIFF file (local path or ``s3://`` URI)

    Raises:
        RuntimeError: If the upload to S3 reports failure.
        Exception: Any GDAL or filesystem error is logged and re-raised unchanged.
    """
    import shutil

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_files = []

        # Stream rows partition by partition instead of collect()-ing them, so
        # the driver never holds every tile in memory at the same time.
        tiff_rows = binary_df.select(col("raster_binary")).toLocalIterator()
        for i, row in enumerate(tiff_rows):
            temp_file = os.path.join(temp_dir, f"temp_{i}.tif")
            with open(temp_file, 'wb') as tiff_handle:
                tiff_handle.write(row['raster_binary'])
            temp_files.append(temp_file)

        if not temp_files:
            logger.warning("No TIFF files found to merge")
            return

        try:
            # Create VRT mosaic; dropping the dataset reference makes GDAL
            # flush the VRT to disk before it is read back.
            vrt_file = os.path.join(temp_dir, "mosaic.vrt")
            vrt_dataset = gdal.BuildVRT(vrt_file, temp_files)
            vrt_dataset = None

            # Create local output file
            local_output = os.path.join(temp_dir, "merged_output.tif")
            translate_options = gdal.TranslateOptions(
                creationOptions=['COMPRESS=DEFLATE', 'TILED=YES', 'BIGTIFF=IF_SAFER']
            )
            merged_dataset = gdal.Translate(local_output, vrt_file, options=translate_options)
            # Close the output so every block is on disk before upload/move.
            merged_dataset = None

            # Handle output destination
            if output_path.startswith('s3://'):
                if not upload_to_s3(local_output, output_path):
                    raise RuntimeError("Failed to upload to S3")
            else:
                # shutil.move falls back to copy + delete when the temporary
                # directory and the destination are on different filesystems,
                # where os.rename would raise OSError.
                shutil.move(local_output, output_path)
            logger.info(f"Merged {len(temp_files)} TIFFs to {output_path}")

        except Exception as e:
            logger.error(f"Error during merge: {e}")
            raise

In [0]:
# Load the classified tiles, decode each GeoTIFF binary into a Sedona raster,
# and record the maximum pixel value of band 1 for every tile.
classification_df = (
    spark.table(f"{catalog_name}.soma.classification")
    .withColumn("tile", expr("RS_FromGeoTiff(raster_binary)"))
    .repartitionByRange(20, "rn")
    # Single quotes keep 'max' a string literal even when the session treats
    # double-quoted tokens as identifiers.
    .withColumn("maxValue", expr("RS_SummaryStats(tile, 'max', 1, false)"))
)

In [0]:
# Tiles whose band maximum is not 999 are passed through unchanged
# (999 is presumably the marker for pixels that need interpolation — confirm).
no_interpolation = (
    classification_df
    .where(col("maxValue") != 999)
    .select("tile_x", "tile_y", "rn", "year", "index", "raster_binary")
)

In [0]:
# Tiles whose band maximum equals 999 are run through RS_Interpolate and
# re-encoded as GeoTIFF, so the output has the same columns as the tiles that
# are passed through untouched.
# NOTE(review): the numeric arguments of RS_Interpolate are taken as-is; confirm
# their meaning against the Sedona documentation before tuning them.
interpolated_df = (
    classification_df
    .where(col("maxValue") == 999)
    .select(
        "tile_x",
        "tile_y",
        "rn",
        "year",
        "index",
        expr("RS_Interpolate(tile, 2.0, 'variable', 48.0, 6.0)").alias("tile"),
    )
    .withColumn("raster_binary", expr("RS_AsGeoTiff(tile)"))
    .select("tile_x", "tile_y", "rn", "year", "index", "raster_binary")
)


In [0]:
# Build the table name once from the configured catalog so the write and the
# read-back below cannot drift apart.
interpolation_table = f"{catalog_name}.soma.interpolation"

# Combine interpolated and untouched tiles and persist them.
union_df = interpolated_df.unionByName(no_interpolation, allowMissingColumns=False)
union_df.write.mode("overwrite").saveAsTable(interpolation_table)

# Re-read from the saved table so later cells start from the persisted data
# instead of recomputing the interpolation, and decode the rasters again.
union_df = (
    spark.table(interpolation_table)
    .withColumn("tile", expr("RS_FromGeoTiff(raster_binary)"))
    .repartitionByRange(10, "rn")
)

In [0]:
# 2025 mosaic: destination and the tiles that belong to that year.
output_tiff_2025 = f"s3://{dataset_bucket_name}/geospatial-dataset/raster/orthophoto/soma/interpolation/interpolation_2025.tif"
union_df_2025 = union_df.where(col("year") == 2025)

# 2022 mosaic: destination and the tiles that belong to that year.
output_tiff_2022 = f"s3://{dataset_bucket_name}/geospatial-dataset/raster/orthophoto/soma/interpolation/interpolation_2022.tif"
union_df_2022 = union_df.where(col("year") == 2022)

# Merge each year's tiles into one GeoTIFF and upload it (2025 first, then 2022).
merge_tiffs(union_df_2025, output_tiff_2025)
merge_tiffs(union_df_2022, output_tiff_2022)

In [0]:
# Expose the tiles to SQL so the raster aggregate can be applied per tile id.
union_df.createOrReplaceTempView("union_df_vw")

# One output raster per `rn`, built by RS_Union_Aggr from that tile's rasters
# using the `index` column.
merged_raster = (
    sedona.sql("""
    SELECT rn, RS_Union_Aggr(tile, index) AS raster
    FROM union_df_vw
    GROUP BY rn
""")
    .repartitionByRange(10, "rn")
)

# Register the aggregated rasters as a view as well.
merged_raster.createOrReplaceTempView("merged_raster_vw")

In [0]:
# Compare band 1 with band 2 of every merged raster.
# NOTE(review): which year each band holds depends on the `index` values fed to
# RS_Union_Aggr upstream — confirm before interpreting the difference.
# No repartition here: `merged_raster` is already range-partitioned on `rn`
# and withColumn keeps that layout, so another one only adds a shuffle.
diff_raster = merged_raster.withColumn(
    "diff_band",
    expr(
        "RS_LogicalDifference("
        "RS_BandAsArray(raster, 1), RS_BandAsArray(raster, 2)"
        ")"
    ),
)

# Add the difference array to the raster as a band and encode the result as
# GeoTIFF bytes. The output column is aliased once (the SQL `AS raster` that
# duplicated `.alias("raster")` was dropped).
result_df = (
    diff_raster
    .select("rn", expr("RS_AddBandFromArray(raster, diff_band)").alias("raster"))
    .withColumn("raster_binary", expr("RS_AsGeoTiff(raster)"))
    .repartitionByRange(10, "rn")
)

In [0]:
# Persist only the tile id and the encoded GeoTIFF. The catalog comes from the
# notebook configuration instead of being hard-coded in the table name.
(
    result_df
    .select("rn", "raster_binary")
    .write
    .mode("overwrite")
    .saveAsTable(f"{catalog_name}.soma.change_detection")
)

In [0]:
# Merge TIFFs
# Read the persisted change-detection tiles back (catalog taken from the
# notebook configuration), decode them, and mosaic them into one GeoTIFF on S3.
result_df = (
    spark.table(f"{catalog_name}.soma.change_detection")
    .withColumn("tile", expr("RS_FromGeoTiff(raster_binary)"))
    .repartitionByRange(10, "rn")
)
output_tiff = f"s3://{dataset_bucket_name}/geospatial-dataset/raster/orthophoto/soma/difference/difference_output.tif"
merge_tiffs(result_df, output_tiff)