## Python notebook for generating the STAC catalog from Google Earth Engine (GEE) assets with local style files

### Tools:
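1. PySTAC
2. Google Earth Engine (EE)
3. Rasterio (for reading the local sample raster used to render thumbnails)
4. GeoPandas (for reading the local sample vector used to render thumbnails)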
5. Matplotlib (for thumbnails)

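This notebook generates a STAC catalog from GEE assets. Item footprints and metadata are read from Earth Engine, while thumbnails are rendered locally from sample data files styled with local QGIS style (QML) files.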

### 1. Importing the required modules

In [2]:
import os
import json
import xml.etree.ElementTree as ET
from datetime import datetime

import rasterio
from rasterio.warp import transform_bounds
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, Normalize
import pystac
import sys
import ee

parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
import constants
import numpy as np
from shapely.geometry import mapping, box
from pystac.extensions.table import TableExtension
from pystac import Asset, MediaType
from pystac.extensions.classification import ClassificationExtension, Classification
from pystac.extensions.raster import RasterExtension,RasterBand
from pystac.extensions.projection import ProjectionExtension

# Connect to Google Earth Engine. A failure is reported but not raised, so the
# remaining cells still load; later `ee` calls are likely to fail until the
# session has been authenticated and initialized.
try:
    ee.Initialize()
except Exception as init_error:
    print(f"Initialization failed: {init_error}")
    # First-time setup: uncomment the next line and follow its instructions.
    # ee.Authenticate()
    print("You may need to run ee.Authenticate() and ee.Initialize()")
else:
    print("Google Earth Engine initialized successfully!")

### 2. Defining Constants and GEE Asset IDs with Local Style Files

`blocks_info` is a list with one dictionary per block; each entry holds the block's GEE asset IDs, the paths of its local style files, and the titles of its raster and vector layers.

In [3]:
# Local data root; style-file names below are joined with it when the
# catalogs are generated.
base_dir = "../data/"

# One entry per block to catalog. Replace the placeholder asset IDs with real
# Earth Engine asset IDs before running.
blocks_info = [
    dict(
        block="gobindpur",
        location="jharkhand",
        raster_asset_id="your_ee_raster_asset_id_here",
        vector_asset_id="your_ee_vector_asset_id_here",
        raster_style_file="style_file.qml",
        vector_style_file="swb_style.qml",
        raster_title="LULC_map",
        vector_title="Water_bodies",
    ),
    dict(
        block="mirzapur",
        location="uttar_pradesh",
        raster_asset_id="your_ee_raster_asset_id_here",
        vector_asset_id="your_ee_vector_asset_id_here",
        raster_style_file="style_file.qml",
        vector_style_file="swb_style.qml",
        raster_title="LULC_map",
        vector_title="Water_bodies",
    ),
]

### 3. Helper Functions (Unchanged from original notebook)

In [4]:
# Top of the generated catalog tree (root catalog -> location -> block).
corestack_dir = os.path.join(base_dir, "CorestackCatalogs")

def extract_raster_dates_from_filename(raster_filename):
    """Parse the start and end dates embedded in a raster filename.

    The name is split on ``_``; the 3rd and 4th fields must be dates in
    ``%Y-%m-%d`` form (e.g. ``lulc_block_2020-01-01_2021-06-30.tif``). A file
    extension attached to the 4th field is ignored.

    Args:
        raster_filename: Raster file name (or path) to parse.

    Returns:
        ``(start_date, end_date)`` as naive ``datetime`` objects.

    Raises:
        ValueError: If the name has too few fields or a field is not a valid
            date. The original error is chained as ``__cause__``.
    """
    try:
        print(raster_filename)
        parts = raster_filename.split('_')
        start_date = datetime.strptime(parts[2], "%Y-%m-%d")
        # Drop an extension such as ".tif" glued to the last date field. A dot
        # can never appear in a valid %Y-%m-%d date, so this only rescues names
        # that previously failed to parse.
        end_field = parts[3].split('.', 1)[0]
        end_date = datetime.strptime(end_field, "%Y-%m-%d")
        print(start_date)
        print(end_date)
    except Exception as e:
        raise ValueError(f"Failed to extract raster dates from filename '{raster_filename}': {e}") from e

    return start_date, end_date

def parse_qml_classes(qml_path):
    """Read the class palette from a QGIS QML style file.

    Each ``<paletteEntry>`` element becomes one dict of its XML attributes.
    When the file has no ``<paletteEntry>`` elements, ``<item>`` elements are
    used instead. A ``value`` attribute is converted to ``int`` when possible
    and otherwise kept as the raw string.

    Args:
        qml_path: Path to the ``.qml`` file.

    Returns:
        List of attribute dicts in document order (possibly empty).
    """
    def _attributes_of(element):
        attributes = dict(element.attrib)
        if "value" in attributes:
            try:
                attributes["value"] = int(attributes["value"])
            except ValueError:
                pass  # non-integer values stay as strings
        return attributes

    document = ET.parse(qml_path).getroot()

    palette = [_attributes_of(node) for node in document.findall(".//paletteEntry")]
    if palette:
        return palette
    return [_attributes_of(node) for node in document.findall(".//item")]

def _rgb_to_hex(rgb_string):
    """Convert a ``"R,G,B[,...]"`` colour string to ``"#rrggbb"``.

    Only the first three comma-separated components are used; anything after
    them (e.g. alpha) is ignored.

    Returns:
        The hex colour string, or ``None`` if *rgb_string* is not a string,
        has fewer than three components, a component is not an integer, or a
        component lies outside 0-255.
    """
    if not isinstance(rgb_string, str):
        return None
    try:
        # Unpacking raises ValueError for fewer than three parts as well.
        red, green, blue = (int(part) for part in rgb_string.split(',')[:3])
    except ValueError:
        return None
    # Out-of-range channels would otherwise format into invalid hex like "#-10000".
    if not all(0 <= channel <= 255 for channel in (red, green, blue)):
        return None
    return f'#{red:02x}{green:02x}{blue:02x}'

def parse_qml_style(qml_path):
    """Extract fill and outline colours for a vector layer from a QML file.

    Where the symbol comes from:

    * For a ``categorizedSymbol`` renderer, it is the symbol named by the
      first category.
    * Otherwise, it is the first symbol of a ``singleSymbol`` renderer.

    Colours are read from the symbol's ``color`` / ``outlineColor``
    attributes. Any still missing are looked up in nested
    ``<... k="color|outline_color" v="R,G,B,...">`` or
    ``<... name="color|outline_color" value="R,G,B,...">`` elements.

    Args:
        qml_path: Path to the ``.qml`` style file.

    Returns:
        ``(fill_color, outline_color)`` as ``#``-prefixed strings. Either
        element may be ``None`` if it could not be found; both are ``None``
        when the file cannot be read/parsed or no colour is found at all.
    """
    try:
        root = ET.parse(qml_path).getroot()
    except (ET.ParseError, OSError) as e:
        print(f"Error parsing QML file: {e}")
        return None, None

    symbol = None
    renderer = root.find(".//renderer-v2[@type='categorizedSymbol']")
    if renderer is not None:
        first_category = renderer.find(".//category")
        symbol_id = first_category.get("symbol") if first_category is not None else None
        if symbol_id:
            # Compare the name in Python rather than interpolating it into an
            # XPath predicate: a quote in the name would make the path invalid.
            symbol = next(
                (candidate for candidate in root.iterfind(".//symbols/symbol")
                 if candidate.get("name") == symbol_id),
                None,
            )

    if symbol is None:
        renderer = root.find(".//renderer-v2[@type='singleSymbol']")
        if renderer is not None:
            symbol = renderer.find(".//symbol")

    fill_color = outline_color = None
    if symbol is not None:
        # Empty attributes are treated as missing.
        fill_color = symbol.get("color") or None
        outline_color = symbol.get("outlineColor") or None

        if fill_color is None or outline_color is None:
            for child in symbol.iter():
                k_val = child.get('k')
                name_val = child.get('name')

                if k_val == 'color' and fill_color is None:
                    fill_color = _rgb_to_hex(child.get('v'))
                elif k_val == 'outline_color' and outline_color is None:
                    outline_color = _rgb_to_hex(child.get('v'))
                elif name_val == 'color' and fill_color is None:
                    fill_color = _rgb_to_hex(child.get('value'))
                elif name_val == 'outline_color' and outline_color is None:
                    outline_color = _rgb_to_hex(child.get('value'))

    if fill_color is None and outline_color is None:
        print("Could not find a recognized renderer or color properties in QML. Defaulting to standard colors.")
        return None, None

    # Return whatever was found; the caller substitutes defaults per colour.
    if fill_color is not None and not fill_color.startswith('#'):
        fill_color = f"#{fill_color}"
    if outline_color is not None and not outline_color.startswith('#'):
        outline_color = f"#{outline_color}"
    return fill_color, outline_color

def generate_vector_thumbnail(vector_path, out_path, qml_path):
    """Render a small PNG preview of a vector file, coloured from a QML style.

    Args:
        vector_path: Vector file readable by ``geopandas.read_file``.
        out_path: Destination PNG path.
        qml_path: QGIS style file providing fill/outline colours; a missing
            colour falls back to ``"lightblue"`` (fill) / ``"blue"`` (outline).

    If the vector file cannot be read, the error is printed and nothing is
    written.
    """
    try:
        gdf = gpd.read_file(vector_path)
    except Exception as e:
        print(f"Error reading vector file: {e}")
        return

    if gdf.crs is None:
        # to_crs() raises on data without a CRS, so label it as WGS84 instead.
        # NOTE(review): assumes CRS-less inputs are already lon/lat — confirm.
        gdf = gdf.set_crs(epsg=4326)
    elif gdf.crs.to_epsg() != 4326:
        gdf = gdf.to_crs(epsg=4326)

    fill_color, edge_color = parse_qml_style(qml_path)

    if fill_color is None:
        fill_color = "lightblue"
    if edge_color is None:
        edge_color = "blue"

    print(f"Parsed QML fill color: {fill_color}")
    print(f"Parsed QML edge color: {edge_color}")

    fig, ax = plt.subplots(figsize=(3, 3))
    try:
        fig.patch.set_facecolor("white")
        ax.set_facecolor("white")

        gdf.plot(ax=ax, color=fill_color, edgecolor=edge_color, linewidth=0.5)
        ax.axis('off')

        fig.savefig(out_path, dpi=150, bbox_inches='tight', pad_inches=0, facecolor=fig.get_facecolor())
    finally:
        # Close this specific figure even if plotting/saving fails, so repeated
        # calls in a notebook do not accumulate open figures.
        plt.close(fig)

def generate_raster_thumbnail(tif_path, out_path, qml_path):
    """Render a small PNG preview of a classified single-band raster.

    Band 1 is read, with nodata pixels masked. It is coloured with the QML
    palette, restricted to class values that actually occur in the raster. If
    no usable palette remains, a grayscale colormap is used.

    Args:
        tif_path: Raster readable by ``rasterio.open``.
        out_path: Destination PNG path.
        qml_path: QGIS style file with ``paletteEntry``/``item`` classes.
    """
    # Imported locally so this helper does not depend on the import cell;
    # BoundaryNorm gives each class value its own colour bin.
    from matplotlib.colors import BoundaryNorm

    with rasterio.open(tif_path) as src:
        arr = src.read(1)
        nodata = src.nodata
    if nodata is not None:
        arr = np.ma.masked_equal(arr, nodata)

    valid_pixels = arr.compressed() if isinstance(arr, np.ma.MaskedArray) else arr
    unique_raster_values = np.unique(valid_pixels)
    print(f"Unique values in raster data: {unique_raster_values}")

    style_info = parse_qml_classes(qml_path)

    # value -> colour for classes present in the raster. Keeping each value
    # paired with its own colour prevents a class that lacks one attribute from
    # shifting the colours of the others; the first entry wins on duplicates.
    palette = {}
    for cls in style_info:
        value = cls.get('value')
        if (
            value is not None
            and 'color' in cls
            and value not in palette
            and value in unique_raster_values
        ):
            palette[value] = cls['color']

    values = sorted(palette)
    colors = [palette[value] for value in values]

    print(f"Parsed QML values: {values}")
    print(f"Parsed QML colors: {colors}")

    try:
        if not values:
            raise ValueError("Invalid or insufficient palette information in QML file.")

        cmap = ListedColormap(colors)
        # One bin per class: inner edges are midpoints between neighbouring
        # class values, outer edges sit 0.5 beyond the extremes. A linear
        # Normalize over [min, max] would split the range into equal-width
        # bins and mis-colour classes whose values are not consecutive.
        edges = [values[0] - 0.5]
        edges += [(low + high) / 2 for low, high in zip(values, values[1:])]
        edges.append(values[-1] + 0.5)
        norm = BoundaryNorm(edges, cmap.N)

    except ValueError as e:
        print(f"Skipping palette generation due to error: {e}. Using a default colormap.")
        cmap = 'gray'
        norm = None

    fig = plt.figure(figsize=(3, 3), dpi=100)
    try:
        ax = fig.gca()
        ax.imshow(arr, cmap=cmap, norm=norm, interpolation='none')
        ax.axis('off')
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

def parse_vector_descriptions(qml_path):
    """Read attribute-field names and their aliases from a QGIS QML file.

    Each ``<alias>`` element with a ``field`` attribute yields one row:

    * ``field`` becomes ``column_name``;
    * ``name`` (the alias) becomes ``column_description``, or ``""`` when
      absent.

    Name and description are taken from the same element, so a missing alias
    cannot misalign the rows that follow.

    Args:
        qml_path: Path to the ``.qml`` file.

    Returns:
        DataFrame with columns ``column_name`` and ``column_description``;
        empty (with those columns) when the file has no aliases.
    """
    root = ET.parse(qml_path).getroot()
    rows = [
        (alias.get('field'), alias.get('name', ''))
        for alias in root.findall(".//alias")
        if alias.get('field') is not None
    ]
    return pd.DataFrame(rows, columns=['column_name', 'column_description'])

def create_raster_item(location, block, raster_asset_id, raster_thumbnail, raster_style_file, base_dir, data_url, title, stac_output_dir):
    """Build a STAC Item for a classified raster stored as a GEE image asset.

    The footprint and dates are read from Earth Engine. The thumbnail and
    legend JSON are produced locally from ``raster_style_file``.

    Args:
        location: Location name, used in the item description.
        block: Block name, used in the item description.
        raster_asset_id: Earth Engine image asset ID.
        raster_thumbnail: Local path where the PNG thumbnail is written.
        raster_style_file: Local QML style file (palette and legend source).
        base_dir: Local data root; asset hrefs are made relative to it.
        data_url: URL prefix under which ``base_dir``'s contents are published.
        title: Item title, also used as the item-ID prefix.
        stac_output_dir: Folder where the legend JSON is written.

    Returns:
        The ``pystac.Item``, or ``None`` if any step fails (error is printed).
    """
    from datetime import timezone  # only `datetime` is imported at file level

    try:
        ee_image = ee.Image(raster_asset_id)
        properties = ee_image.getInfo().get('properties', {})

        # Fetch the footprint once and derive the bbox from every ring vertex
        # instead of assuming which positions hold the min/max corners.
        geom = ee_image.geometry().bounds().getInfo()
        ring = geom['coordinates'][0]
        lons = [point[0] for point in ring]
        lats = [point[1] for point in ring]
        bbox = [min(lons), min(lats), max(lons), max(lats)]

        # GEE timestamps are epoch milliseconds. Convert in UTC so the "Z"
        # suffix written below is true; fromtimestamp() without tz would give
        # local time. A missing start falls back to the Unix epoch; a missing
        # end falls back to the start so the range is never inverted.
        start_ms = properties.get('system:time_start', 0)
        end_ms = properties.get('system:time_end', start_ms)
        start_date = datetime.fromtimestamp(start_ms / 1000, tz=timezone.utc)
        end_date = datetime.fromtimestamp(end_ms / 1000, tz=timezone.utc)

        # The asset is not a local file, so the thumbnail is rendered from a
        # local sample raster with the same classes, styled by the QML file.
        # That sample must exist at this path.
        dummy_raster_path = os.path.join(base_dir, 'dummy_raster_for_thumbnail.tif')
        generate_raster_thumbnail(dummy_raster_path, raster_thumbnail, raster_style_file)

        style_info = parse_qml_classes(raster_style_file)
        style_json_name = os.path.splitext(os.path.basename(raster_style_file))[0] + '.json'
        style_json_path = os.path.join(stac_output_dir, style_json_name)
        with open(style_json_path, "w") as f:
            json.dump(style_info, f, indent=2)

        item_id = f"{title}_{os.path.basename(raster_asset_id).replace(':', '_').replace('/', '_')}"

        item = pystac.Item(
            id=item_id,
            bbox=bbox,
            geometry=geom,
            datetime=start_date,
            properties={
                "title": title,
                "description": f"Raster data for {title} in {block} of {location}, sourced from GEE asset '{raster_asset_id}'",
                "start_datetime": start_date.isoformat().replace('+00:00', 'Z'),
                "end_datetime": end_date.isoformat().replace('+00:00', 'Z'),
                "gee:asset_id": raster_asset_id
            }
        )

        item.add_asset("data", Asset(
            href=f"https://earthengine.googleapis.com/v1alpha/{raster_asset_id}:getPixels",
            media_type=MediaType.GEOTIFF,
            roles=["data"],
            title="GEE Raster Layer"
        ))

        item.add_asset("thumbnail", Asset(
            href=os.path.join(data_url, os.path.relpath(raster_thumbnail, start=base_dir)),
            media_type=MediaType.PNG,
            roles=["thumbnail"],
            title="Raster Thumbnail"
        ))

        item.add_asset("legend", Asset(
            href=os.path.join(data_url, os.path.relpath(style_json_path, start=base_dir)),
            media_type=MediaType.JSON,
            roles=["metadata"],
            title="Legend JSON"
        ))

        item.add_asset("style", Asset(
            href=os.path.join(data_url, os.path.relpath(raster_style_file, start=base_dir)),
            media_type=MediaType.XML,
            roles=["metadata"],
            title="Raster Style (QML)"
        ))

        return item

    except Exception as e:
        print(f"Error creating raster item from {raster_asset_id}: {e}")
        return None

def create_vector_item(location, block, vector_asset_id, vector_desc_df, vector_thumbnail, vector_style_file, base_dir, data_url, title, stac_output_dir, qml_path):
    """Build a STAC Item for a GEE FeatureCollection (table) asset.

    Args:
        location: Location name, used in the item description.
        block: Block name, used in the item description.
        vector_asset_id: Earth Engine table asset ID.
        vector_desc_df: DataFrame with ``column_name`` and
            ``column_description`` columns, used for the table-extension
            column list.
        vector_thumbnail: Local path where the PNG thumbnail is written.
        vector_style_file: QML file exposed as the item's "style" asset.
        base_dir: Local data root; asset hrefs are made relative to it.
        data_url: URL prefix under which ``base_dir``'s contents are published.
        title: Item title, also used as the item-ID prefix.
        stac_output_dir: Not used here; kept for symmetry with
            ``create_raster_item``.
        qml_path: QML file used to colour the thumbnail (callers in this
            notebook pass the same path as ``vector_style_file``).

    Returns:
        The ``pystac.Item``, or ``None`` if any step fails (error is printed).
    """
    from datetime import timezone  # only `datetime` is imported at file level

    try:
        ee_feature_collection = ee.FeatureCollection(vector_asset_id)

        # Only the footprint is needed. Calling getInfo() on the collection
        # itself would download every feature (and errors on large tables),
        # so it is deliberately avoided. Fetch the bounds once and derive the
        # bbox from every ring vertex.
        geom = ee_feature_collection.geometry().bounds().getInfo()
        ring = geom['coordinates'][0]
        lons = [point[0] for point in ring]
        lats = [point[1] for point in ring]
        bbox = [min(lons), min(lats), max(lons), max(lats)]

        # No date information is read from the asset, so the item is stamped
        # with its generation time, taken in UTC to match the "Z" suffix.
        generated_at = datetime.now(timezone.utc)
        generated_at_str = generated_at.isoformat().replace('+00:00', 'Z')

        # The thumbnail is rendered from a local sample vector file, which must
        # exist at this path.
        dummy_vector_path = os.path.join(base_dir, 'dummy_vector_for_thumbnail.geojson')
        generate_vector_thumbnail(dummy_vector_path, vector_thumbnail, qml_path)

        item_id = f"{title}_{os.path.basename(vector_asset_id).replace(':', '_').replace('/', '_')}"

        item = pystac.Item(
            id=item_id,
            geometry=geom,
            bbox=bbox,
            datetime=generated_at,
            properties={
                "title": title,
                "description": f"Vector data for {title} in {block} of {location}, sourced from GEE asset '{vector_asset_id}'",
                "start_datetime": generated_at_str,
                "end_datetime": generated_at_str,
                "gee:asset_id": vector_asset_id
            }
        )

        # Column names/descriptions come from the QML aliases, not from the
        # asset's schema, and every column is typed as "string".
        # NOTE(review): check that this pystac version's TableExtension.columns
        # setter accepts plain dicts; if it raises, the except below turns it
        # into a missing item.
        table_ext = TableExtension.ext(item, add_if_missing=True)
        table_ext.columns = [
            {
                "name": row.column_name,
                "type": "string",
                "description": row.column_description
            }
            for row in vector_desc_df.itertuples(index=False)
        ]

        # Table assets are listed via the REST `listFeatures` method;
        # `getPixels` only applies to image assets.
        item.add_asset("data", Asset(
            href=f"https://earthengine.googleapis.com/v1alpha/{vector_asset_id}:listFeatures",
            media_type=MediaType.GEOJSON,
            roles=["data"],
            title="GEE Vector Layer"
        ))

        item.add_asset("thumbnail", Asset(
            href=os.path.join(data_url, os.path.relpath(vector_thumbnail, start=base_dir)),
            media_type=MediaType.PNG,
            roles=["thumbnail"],
            title="Vector Thumbnail"
        ))

        item.add_asset("style", Asset(
            href=os.path.join(data_url, os.path.relpath(vector_style_file, start=base_dir)),
            media_type=MediaType.XML,
            roles=["metadata"],
            title="Vector Style"
        ))

        return item

    except Exception as e:
        print(f"Error creating vector item from {vector_asset_id}: {e}")
        return None

def generate_stac_for_block(info, base_dir='../data/'):
    """Create the STAC catalog for one block and link it into its location catalog.

    The block catalog is written under
    ``<base_dir>/CorestackCatalogs/<location>/<block>/``. Thumbnails and the
    legend JSON go to ``<base_dir>/STAC_output``. The location catalog
    ``<base_dir>/CorestackCatalogs/<location>/catalog.json`` is created if
    missing and re-saved whenever the block is newly linked into it.

    Args:
        info: One ``blocks_info`` entry. ``block`` and ``location`` are
            required. The raster layer is processed when ``raster_asset_id``
            is present, the vector layer when ``vector_asset_id`` is present.
        base_dir: Local data root (defaults to the previously hard-coded path).
    """
    corestack_dir = os.path.join(base_dir, 'CorestackCatalogs')

    location = info['location']
    block = info['block']

    location_dir = os.path.join(corestack_dir, location)
    block_dir = os.path.join(location_dir, block)
    os.makedirs(block_dir, exist_ok=True)

    stac_output_dir = os.path.join(base_dir, 'STAC_output')
    os.makedirs(stac_output_dir, exist_ok=True)

    block_catalog = pystac.Catalog(
        id=block,
        title=f"STAC for {block}",
        description=f"STAC catalog for {block} block data in {location}"
    )

    def _thumbnail_path(asset_id):
        # <block>_<asset name without extension>_thumbnail.png in STAC_output.
        asset_name = os.path.splitext(os.path.basename(asset_id))[0]
        return os.path.join(stac_output_dir, f'{block}_{asset_name}_thumbnail.png')

    # Process Raster Layer
    if 'raster_asset_id' in info:
        raster_asset_id = info['raster_asset_id']
        item = create_raster_item(
            location, block, raster_asset_id, _thumbnail_path(raster_asset_id),
            os.path.join(base_dir, info['raster_style_file']), base_dir,
            constants.data_url, info['raster_title'], stac_output_dir,
        )
        if item:
            block_catalog.add_item(item)

    # Process Vector Layer
    if 'vector_asset_id' in info:
        vector_asset_id = info['vector_asset_id']
        vector_style_path = os.path.join(base_dir, info['vector_style_file'])
        vector_desc_df = parse_vector_descriptions(vector_style_path)
        item = create_vector_item(
            location, block, vector_asset_id, vector_desc_df,
            _thumbnail_path(vector_asset_id), vector_style_path, base_dir,
            constants.data_url, info['vector_title'], stac_output_dir,
            vector_style_path,
        )
        if item:
            block_catalog.add_item(item)

    # Save the catalog and its items together, after the items are attached,
    # so item files carry parent/root links. Normalizing here also puts item
    # files where the later normalize_and_save of the location/root catalogs
    # puts them, instead of leaving stale copies at a different path.
    block_catalog.normalize_and_save(block_dir, catalog_type=pystac.CatalogType.SELF_CONTAINED)
    print(f" STAC catalog created for block: {block} in {location}")

    location_catalog_path = os.path.join(location_dir, 'catalog.json')
    location_catalog_modified = False

    if os.path.exists(location_catalog_path):
        location_catalog = pystac.read_file(location_catalog_path)
        print(f"Loaded existing location catalog: {location}")
    else:
        os.makedirs(location_dir, exist_ok=True)
        location_catalog = pystac.Catalog(
            id=location,
            title=f"STAC for {location}",
            description=f"STAC catalog for data in {location}"
        )
        location_catalog.set_self_href(location_catalog_path)
        print(f"Created new location catalog: {location}")
        location_catalog_modified = True

    existing_child_ids = {child.id for child in location_catalog.get_children()}

    if block_catalog.id not in existing_child_ids:
        child_to_add = pystac.read_file(os.path.join(block_dir, 'catalog.json'))
        location_catalog.add_child(child_to_add)
        location_catalog_modified = True
        print(f"Added block '{block}' to location catalog '{location}'.")
    else:
        print(f"Block '{block}' already exists in location catalog '{location}'")

    if location_catalog_modified:
        location_catalog.normalize_and_save(location_dir, catalog_type=pystac.CatalogType.SELF_CONTAINED)
        print(f"Updated location catalog for: {location}")

def generate_root_catalog(blocks_info, base_dir, corestack_dir):
    """Create or update the root catalog and link every location catalog on disk.

    For each distinct ``location`` in *blocks_info*, the catalog at
    ``<corestack_dir>/<location>/catalog.json`` is linked as a child unless it
    is already linked; missing location catalogs only produce a warning. The
    root is then saved as a self-contained catalog in *corestack_dir*.

    Args:
        blocks_info: Block entries; only their ``location`` keys are read.
        base_dir: Accepted for interface compatibility; not used.
        corestack_dir: Directory holding the root ``catalog.json``.
    """
    root_path = os.path.join(corestack_dir, "catalog.json")

    if not os.path.exists(root_path):
        root_catalog = pystac.Catalog(
            id="corestack",
            title="CorestackCatalogs",
            description="Root catalog containing all location-based sub-catalogs"
        )
        root_catalog.set_self_href(root_path)
        print("Created new root catalog.")
    else:
        root_catalog = pystac.read_file(root_path)
        print("Loaded existing root catalog.")

    linked_locations = {child.id for child in root_catalog.get_children()}

    for location in (entry["location"] for entry in blocks_info):
        location_path = os.path.join(corestack_dir, location, "catalog.json")

        if not os.path.exists(location_path):
            print(f"Warning: Location catalog not found for {location} at {location_path}")
            continue
        if location in linked_locations:
            print(f"Location catalog '{location}' already linked in root catalog.")
            continue

        root_catalog.add_child(pystac.read_file(location_path))
        linked_locations.add(location)
        print(f"Added location catalog '{location}' to root catalog.")

    root_catalog.set_self_href(root_path)
    root_catalog.normalize_and_save(corestack_dir, catalog_type=pystac.CatalogType.SELF_CONTAINED)
    print(f"Root catalog generated at {root_path}")

### 4. Execution

This cell runs the functions to generate all the catalogs from GEE assets.

In [5]:
base_dir = "../data/"
corestack_dir = os.path.join(base_dir, "CorestackCatalogs")

# Build each block catalog (linking it into its location catalog), then link
# every location catalog under the root catalog.
for block_info in blocks_info:
    print(f"Processing block: {block_info['block']}")
    generate_stac_for_block(block_info)

generate_root_catalog(blocks_info, base_dir=base_dir, corestack_dir=corestack_dir)
