In [None]:
%pip install pystac rasterio stac-geoparquet shapely pyarrow==17.0.0

In [None]:
import pyarrow
print(pyarrow.__version__)


In [None]:
dbutils.library.restartPython()

In [None]:
import json
import os
import re
import subprocess
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
from pathlib import Path

import pystac
import rasterio
import rasterio.warp
import stac_geoparquet.arrow as sg_arrow
from pystac.extensions.eo import Band, EOExtension
from pystac.extensions.projection import ProjectionExtension
from shapely.geometry import Polygon, mapping


In [None]:
# Monthly ARCSI Sentinel-2 directories on the EODS mount to scan (January-June 2024)
eods_dirs = [
    f'/mnt/eods/sentinel2/ARCSI/2024/{month:02d}/'
    for month in range(1, 7)
]

def list_files_recursively(directories, extension):
    """
    Collect the paths of all files under `directories` whose name ends with `extension`.

    Each starting directory is walked breadth-first using `dbutils.fs.ls`; entries
    reporting `isDir()` are queued for a later listing, all other entries are kept
    when their `name` ends with `extension`.

    Args:
        directories: Iterable of directory paths accepted by `dbutils.fs.ls`.
        extension (str): Filename suffix to match (e.g. '_meta.xml').

    Returns:
        list: The `path` attribute of every matching entry, in discovery order.
    """
    matches = []
    for root in directories:
        # `pending` doubles as the visit queue: `cursor` marks the next directory to list
        pending = [root]
        cursor = 0
        while cursor < len(pending):
            entries = dbutils.fs.ls(pending[cursor])
            cursor += 1
            for entry in entries:
                if entry.isDir():
                    pending.append(entry.path)
                    continue
                if entry.name.endswith(extension):
                    matches.append(entry.path)
    return matches

# Discover every 10-band surface reflectance GeoTIFF under the configured directories
image_files = list_files_recursively(eods_dirs, '_vmsk_sharp_rad_srefdem_stdsref.tif')
print(f"Found {len(image_files)} 10-band geotiff(s)")

# Rewrite 'dbfs:/' URIs as '/dbfs/' paths so the `enable_file_access` 'path warming' function can work on them
image_paths = [dbfs_uri.replace('dbfs:/', '/dbfs/') for dbfs_uri in image_files]

# Preview the first few image paths
for path in image_paths[:5]:
    print(path)

# Discover every metadata .xml file under the same directories
metadata_files = list_files_recursively(eods_dirs, '_meta.xml')
print(f"Found {len(metadata_files)} metadata file(s)")

metadata_paths = [dbfs_uri.replace('dbfs:/', '/dbfs/') for dbfs_uri in metadata_files]

# Preview the first few metadata paths
for path in metadata_paths[:5]:
    print(path)

In [None]:
# EO band definitions for the 10-band EODS Sentinel-2 product, in the order they are applied to items/assets.
# `center_wavelength` and `full_width_half_max` are passed as floats: the STAC EO extension defines both
# as numbers (micrometres), and passing strings would serialise them as JSON strings.
sentinel2_eods_bands = [
    Band.create(name='B02', common_name='blue', description='Band 2 - Blue', center_wavelength=0.49, full_width_half_max=0.098),
    Band.create(name='B03', common_name='green', description='Band 3 - Green', center_wavelength=0.56, full_width_half_max=0.045),
    Band.create(name='B04', common_name='red', description='Band 4 - Red', center_wavelength=0.665, full_width_half_max=0.038),
    Band.create(name='B05', common_name='rededge071', description='Band 5 - Vegetation red edge 1', center_wavelength=0.704, full_width_half_max=0.019),
    Band.create(name='B06', common_name='rededge075', description='Band 6 - Vegetation red edge 2', center_wavelength=0.74, full_width_half_max=0.018),
    Band.create(name='B07', common_name='rededge078', description='Band 7 - Vegetation red edge 3', center_wavelength=0.783, full_width_half_max=0.028),
    Band.create(name='B08', common_name='nir', description='Band 8 - NIR', center_wavelength=0.842, full_width_half_max=0.145),
    Band.create(name='B8A', common_name='nir08', description='Band 8A - Vegetation red edge 4', center_wavelength=0.865, full_width_half_max=0.033),
    Band.create(name='B11', common_name='swir16', description='Band 11 - SWIR (1.6)', center_wavelength=1.61, full_width_half_max=0.143),
    Band.create(name='B12', common_name='swir22', description='Band 12 - SWIR (2.2)', center_wavelength=2.19, full_width_half_max=0.242)
]


In [None]:
# Root STAC catalog that the generated items are added to
catalog = pystac.Catalog(
    id='Earth Observation Data Service',
    description='This catalog is a proof of concept using Sentinel-2 scenes from EODS.',
)


In [None]:
def enable_file_access(pth):
    """
    Warm up a file path by listing its parent directories with `ls`.

    Listing each parent directory seems to cache the path within the Databricks
    environment, allowing IO operations on it.

    `ls` is invoked through `subprocess.run` with an argument list rather than an
    interpolated shell line, so directory names containing spaces or shell
    metacharacters are passed through as a single, literal argument.

    Args:
        pth (str | os.PathLike): Path of the file to make accessible
            (e.g. '/dbfs/mnt/eods/sentinel2/.../scene.tif').

    Returns:
        None. A failing `ls` is not treated as an error (`check=False`).
    """
    item = Path(pth)

    # `Path.parents` runs from the nearest parent up to the root; reverse it to walk
    # root -> leaf, then skip the first four entries (for a '/dbfs/mnt/eods/<sensor>/...'
    # path these are '/', '/dbfs', '/dbfs/mnt' and '/dbfs/mnt/eods').
    parent_pths = list(reversed(item.parents))[4:]

    # List each remaining parent directory, discarding the listing itself
    for p in parent_pths:
        subprocess.run(['ls', str(p)], stdout=subprocess.DEVNULL, check=False)

# Warm every image path first, then every metadata path, so later IO on them succeeds
for warm_path in [*image_paths, *metadata_paths]:
    enable_file_access(warm_path)

In [None]:
"""
The following main workflow functions to create STAC items from .tif files in the EODS mount 
have been developed based on this tutorial: https://stacspec.org/en/tutorials/3-create-stac-item-with-extension/
"""

def get_bbox_and_footprint(raster):
    """
    Return the bounding box and footprint of a raster in WGS 84 (EPSG:4326).

    STAC Items carry GeoJSON geometry, which is longitude/latitude on WGS 84, so the
    raster bounds are reprojected from the dataset's own CRS before being returned.
    If the dataset reports no CRS the native bounds are returned unchanged.

    Args:
        raster: Path (or other input accepted by `rasterio.open`) of the raster.

    Returns:
        tuple: `(bbox, footprint)` where `bbox` is `[left, bottom, right, top]` and
        `footprint` is the GeoJSON-like mapping of the polygon through the four
        raster corners.
    """
    with rasterio.open(raster) as r:
        bounds = r.bounds
        src_crs = r.crs

    left, bottom, right, top = bounds.left, bounds.bottom, bounds.right, bounds.top

    # Corner order: lower-left, upper-left, upper-right, lower-right
    xs = [left, left, right, right]
    ys = [bottom, top, top, bottom]

    if src_crs is not None:
        # transform_bounds densifies the edges, so the envelope still contains the
        # whole raster after reprojection; the footprint uses the reprojected corners.
        left, bottom, right, top = rasterio.warp.transform_bounds(
            src_crs, 'EPSG:4326', left, bottom, right, top
        )
        xs, ys = rasterio.warp.transform(src_crs, 'EPSG:4326', xs, ys)

    bbox = [left, bottom, right, top]
    footprint = Polygon(list(zip(xs, ys)))
    return (bbox, mapping(footprint))

def create_basic_item(img_path):
    """
    Build a basic STAC item for a raster image.

    The item id is the image file name without its extension, the geometry and bbox
    come from `get_bbox_and_footprint`, and the item datetime is the file's
    modification time expressed in UTC. The image itself is attached as a GeoTIFF
    asset under the key "image".

    Args:
        img_path (str): Path to the raster image.

    Returns:
        pystac.Item: The new item with its "image" asset attached.
    """
    extent, outline = get_bbox_and_footprint(img_path)

    file_name = os.path.basename(img_path)
    stem = os.path.splitext(file_name)[0]

    # NOTE(review): this is the file's mtime, not necessarily the acquisition time - confirm
    modified_utc = datetime.fromtimestamp(os.path.getmtime(img_path), tz=timezone.utc)

    item = pystac.Item(
        id=stem,
        geometry=outline,
        bbox=extent,
        datetime=modified_utc,
        properties={},
    )
    item.add_asset(
        "image",
        pystac.Asset(href=img_path, media_type=pystac.MediaType.GEOTIFF),
    )
    return item

def apply_eo_extension(item, asset, metadata_path):
    """
    Apply the EO extension to an item and its asset, and set cloud cover from the metadata XML.

    The band list `sentinel2_eods_bands` is applied to both the item and the asset.
    The metadata file is then parsed and the first `gmd:supplementalInformation`
    character string is searched for an `ARCSI_CLOUD_COVER: <number>` entry; when one
    is found its value is stored as the item's cloud cover.

    Args:
        item: Item to extend (the EO extension is added if missing).
        asset: Asset of `item` that receives the same band list.
        metadata_path (str): Path to the XML metadata file for the image.

    Raises:
        OSError / xml.etree.ElementTree.ParseError: If the metadata file cannot be
            read or parsed (propagated from `ET.parse`).
    """
    eo = EOExtension.ext(item, add_if_missing=True)
    eo.apply(bands=sentinel2_eods_bands)

    eo_on_asset = EOExtension.ext(asset)
    eo_on_asset.apply(sentinel2_eods_bands)

    tree = ET.parse(metadata_path)
    root = tree.getroot()

    supplemental_info = root.find(".//gmd:supplementalInformation/gco:CharacterString",
                                  namespaces={'gmd': 'http://www.isotc211.org/2005/gmd',
                                              'gco': 'http://www.isotc211.org/2005/gco'})

    if supplemental_info is not None and supplemental_info.text:
        # Capture one well-formed decimal number only. The previous `[\d.]+` pattern also
        # swallowed a sentence-ending period or extra dots (e.g. '0.25.'), which made
        # float() raise ValueError.
        match = re.search(r'ARCSI_CLOUD_COVER:\s*(\d+(?:\.\d+)?|\.\d+)', supplemental_info.text)
        if match:
            # NOTE(review): eo:cloud_cover is a percentage (0-100); confirm whether
            # ARCSI_CLOUD_COVER is a fraction that needs scaling.
            eo.cloud_cover = float(match.group(1))

def apply_projection_extension(item, epsg=27700):
    """
    Add the projection extension to an item and record its EPSG code.

    Args:
        item: Item to extend (the projection extension is added if missing).
        epsg (int): EPSG code to record. Defaults to 27700, the value this
            workflow previously hard-coded.
    """
    proj_ext = ProjectionExtension.ext(item, add_if_missing=True)
    proj_ext.epsg = epsg

def add_common_metadata(item):
    """
    Set the platform, instruments and ground sample distance on an item's common metadata.

    Args:
        item: Item whose `common_metadata` is updated in place.
    """
    item.common_metadata.platform = "Sentinel-2"
    # STAC common metadata defines `instruments` as a list of strings, not a single string
    item.common_metadata.instruments = ["msi"]
    item.common_metadata.gsd = 10

def process_image(img_path, catalog):
    """
    Create a STAC item for one image, enrich it, and add it to the catalog.

    The item gets the EO extension (bands and cloud cover from the image's metadata
    XML), the projection extension, and common metadata before being added.

    Args:
        img_path (str): Path to the image file.
        catalog: Catalog that receives the new item via `add_item`.
    """
    item = create_basic_item(img_path)
    asset = item.assets["image"]

    # Swap only the file extension for '_meta.xml'. `str.replace('.tif', ...)` rewrote
    # every '.tif' in the path (directory names, '.tiff') and missed '.TIF'.
    # NOTE(review): assumes the metadata file sits next to the image as
    # '<image stem>_meta.xml' - confirm against the files on the mount.
    metadata_path = os.path.splitext(img_path)[0] + '_meta.xml'

    apply_eo_extension(item, asset, metadata_path)
    apply_projection_extension(item)
    add_common_metadata(item)

    catalog.add_item(item)

def process_all_images(image_paths, catalog):
    """
    Run `process_image` for every path in `image_paths`, in order, against `catalog`.

    Args:
        image_paths: Iterable of image file paths.
        catalog: Catalog passed through to each `process_image` call.
    """
    for scene_path in image_paths:
        process_image(scene_path, catalog)

In [None]:
# Build one STAC item per discovered image and add each to `catalog` (mutates `catalog` in place)
process_all_images(image_paths, catalog)

In [None]:
# FileStore location for the test STAC: the DBFS path, and the same location via the '/dbfs' mount
output_dir = '/FileStore/tomkdefra/test_stac'
output_dir_dbfs = '/dbfs' + output_dir

# Make sure the target directory exists
dbutils.fs.mkdirs(output_dir)

# Point the catalog's HREFs at the '/dbfs' form of the output directory
catalog.normalize_hrefs(output_dir_dbfs)

In [None]:
# Save the catalog as a self-contained catalog, at the HREFs assigned by `normalize_hrefs`
catalog.save(catalog_type=pystac.CatalogType.SELF_CONTAINED)


In [None]:
# Path of the root catalog JSON inside the '/dbfs' output directory
catalog_file = os.path.join(output_dir_dbfs, 'catalog.json')


In [None]:
# Read the saved catalog back from disk and pretty-print its JSON as a sanity check
read_catalog = pystac.Catalog.from_file(catalog_file)
catalog_json = read_catalog.to_dict()
print(json.dumps(catalog_json, indent=2))


In [None]:
# Access and print the JSON for the first n items of the re-read catalog
# (n is currently 1: only `items[:1]` is converted)
items = list(read_catalog.get_items())
n_items_json = [item.to_dict() for item in items[:1]]
print("\nn items in the catalog:")
print(json.dumps(n_items_json, indent=2))


In [None]:
def convert_stac_to_geoparquet(catalog, output_file):
    """
    Write every item of a STAC catalog to a single GeoParquet file.

    All items are gathered with `catalog.get_all_items()`, converted to Arrow with
    `stac_geoparquet.arrow.parse_stac_items_to_arrow`, and written with
    `stac_geoparquet.arrow.to_parquet`. Failures are reported rather than raised:
    the error message, error type and stack trace are printed and the function
    returns normally.

    Args:
        catalog (pystac.Catalog): The STAC catalog containing items to convert.
        output_file (str): The path where the output GeoParquet file will be saved.

    Note:
        stac-geoparquet appears to need pyarrow > 16.0.0. The version installed on
        the clusters (at least cluster 1c) is 8.0.0, which results in an error.
        After pip installing pyarrow==17.0.0, run dbutils.library.restartPython()
        so the cluster picks up the newer version.

    Reference:
        Based on the example from: https://stac-utils.github.io/stac-geoparquet/latest/examples/naip/#loading-to-arrow
    """
    import traceback

    try:
        stac_items = list(catalog.get_all_items())
        arrow_reader = sg_arrow.parse_stac_items_to_arrow(stac_items)
        sg_arrow.to_parquet(arrow_reader, output_file)
        print(f"Successfully created GeoParquet file: {output_file}")
    except Exception as exc:
        # Report and carry on: message and type go to stdout, the traceback to stderr
        print(f"Error: {exc}")
        print(f"Error type: {type(exc)}")
        traceback.print_exc()

In [None]:
# Write the in-memory `catalog` (not the re-read `read_catalog`) to GeoParquet via the '/dbfs' mount path
output_file = "/dbfs/FileStore/tomkdefra/test_stac_geoparquet.parquet"
convert_stac_to_geoparquet(catalog, output_file)


In [None]:
# Move the GeoParquet file from FileStore to the lab mount, renaming it to 'eods_s2.parquet'.
# The third positional argument is presumably the `recurse` flag - confirm against the dbutils docs.
# NOTE(review): `convert_stac_to_geoparquet` prints errors instead of raising, so this move
# fails on a missing source file if the conversion above did not succeed.
dbutils.fs.mv("FileStore/tomkdefra/test_stac_geoparquet.parquet", "mnt/lab/unrestricted/eods_stac/eods_s2.parquet", True)