In [1]:
import re
import json
import requests
from datetime import datetime
from typing import Dict, TypedDict, Any, Tuple
import numpy as np
import boto3
import rasterio
from rasterio.io import MemoryFile
from rasterio.mask import mask
from rasterio.merge import merge
from rasterio.session import AWSSession
from pystac_client import Client
import planetary_computer
from planetary_computer import sign
from shapely.geometry import shape
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda
from langchain.tools import tool

In [2]:
# Global constants
ESA_CLASS_NAMES = {
    0: "No Data", 10: "Tree cover", 20: "Shrubland", 30: "Grassland",
    40: "Cropland", 50: "Built-up", 60: "Bare / sparse vegetation",
    70: "Snow and ice", 80: "Permanent water bodies", 90: "Herbaceous wetland",
    95: "Mangroves", 100: "Moss and lichen"
}

DEFAULT_COG_PATH = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/default.tif"
INTENSITY_COG_PATH = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/intensity.tif"
STAC_API_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
AWS_PROFILE = "zeno_internal_sso"
REFERENCE_DATE = datetime(2015, 1, 1)

In [3]:
def parse_user_message(state: dict) -> dict:
    """
    Parse user message from `state["message"]`, returning aoi, date_range, and thresholds.
    """
    message = state["message"]
    geojson_match = re.search(r'\{.*\}', message, re.DOTALL)
    geojson_str = geojson_match.group(0) if geojson_match else None
    geojson = json.loads(geojson_str)

    date_match = re.search(r'date range:\s*(\d{4}-\d{2}-\d{2})\s*to\s*(\d{4}-\d{2}-\d{2})', message)
    start_date, end_date = date_match.groups()

    conf_match = re.search(r'confidence:\s*([0-9.]+)', message)
    intensity_match = re.search(r'intensity:\s*([0-9.]+)', message)

    return {
        "aoi": geojson,
        "date_range": (start_date, end_date),
        "thresholds": {
            "confidence": float(conf_match.group(1)),
            "intensity": int(float(intensity_match.group(1)))
        }
    }

In [4]:
def convert_date_range_to_days(date_range):
    start_date = datetime.strptime(str(date_range[0])[:10], "%Y-%m-%d")
    end_date = datetime.strptime(str(date_range[1])[:10], "%Y-%m-%d")
    return (start_date - REFERENCE_DATE).days, (end_date - REFERENCE_DATE).days, start_date, end_date

In [5]:
def load_alert_data(geometry):
    session = boto3.Session(profile_name=AWS_PROFILE)
    with rasterio.Env(AWSSession(session), AWS_REQUEST_PAYER="requester"):
        with rasterio.open(DEFAULT_COG_PATH) as src1:
            default_data, _ = mask(src1, geometry, crop=True)
        with rasterio.open(INTENSITY_COG_PATH) as src2:
            intensity_data, _ = mask(src2, geometry, crop=True)
    return default_data[0], intensity_data[0]

In [6]:
def load_esa_worldcover_mosaic(aoi_geom, start_date, end_date):
    stac = Client.open(STAC_API_URL)
    search = stac.search(
        collections=["esa-worldcover"],
        intersects=aoi_geom,
        datetime=f"{start_date.date().isoformat()}/{end_date.date().isoformat()}",
        max_items=10
    )
    items = list(search.get_items())
    if not items:
        raise ValueError("No ESA WorldCover items found for AOI.")

    session = boto3.Session(profile_name=AWS_PROFILE)
    urls = [sign(item.assets["map"].href) for item in items if "map" in item.assets]

    with rasterio.Env(AWSSession(session), AWS_REQUEST_PAYER="requester"):
        datasets = [rasterio.open(url) for url in urls]
        mosaic, out_transform = merge(datasets)
        profile = datasets[0].profile.copy()
        profile.update({
            "height": mosaic.shape[1], "width": mosaic.shape[2],
            "transform": out_transform
        })
        with rasterio.io.MemoryFile() as memfile:
            with memfile.open(**profile) as ds:
                ds.write(mosaic)
                land_cover_data, _ = mask(ds, [aoi_geom.__geo_interface__], crop=True)
        for ds in datasets:
            ds.close()
    return land_cover_data[0]

In [7]:
def filter_and_count_pixels(inputs: Dict) -> Dict:
    aoi = inputs["aoi"]
    date_range = inputs["date_range"]
    thresholds = inputs["thresholds"]

    start_days, end_days, start_date, end_date = convert_date_range_to_days(date_range)
    geometry = [shape(aoi["geometry"])]

    encoded, intensities = load_alert_data(geometry)
    confidence = encoded // 10000
    days_since_2015 = encoded % 10000

    land_cover = load_esa_worldcover_mosaic(shape(aoi["geometry"]), start_date, end_date)

    # Align shapes
    min_rows = min(encoded.shape[0], land_cover.shape[0])
    min_cols = min(encoded.shape[1], land_cover.shape[1])
    confidence = confidence[:min_rows, :min_cols]
    days_since_2015 = days_since_2015[:min_rows, :min_cols]
    intensities = intensities[:min_rows, :min_cols]
    land_cover = land_cover[:min_rows, :min_cols]

    if land_cover.shape != confidence.shape:
        raise ValueError(f"Shape mismatch: {land_cover.shape} vs {confidence.shape}")

    max_days = (datetime.today() - REFERENCE_DATE).days
    # Replace no-data value (9999) with the max valid day
    days_since_2015[days_since_2015 == 9999] = days_since_2015[days_since_2015 < 9999].max()

    #print(f"days_since_2015 range in data: {days_since_2015.min()} - {days_since_2015.max()}")
    #print(f"Date filter: {start_days} - {end_days}, Max allowed: {max_days}")

    #print(f"Unique confidence values: {np.unique(confidence)}")
    #print(f"Confidence threshold: {thresholds['confidence']}")
    #print(f"Intensity threshold: {thresholds['intensity']}")
    #print(f"Intensity range: {intensities.min()} - {intensities.max()}")

    #print(f"Data shape after cropping: {min_rows}x{min_cols}")

    valid_mask = (
        (confidence >= thresholds["confidence"]) &
        (days_since_2015 >= start_days) &
        (days_since_2015 <= end_days) &
        (days_since_2015 < max_days) &
        (intensities >= thresholds["intensity"])
    )
    
    #print(f"Valid pixels after mask: {np.sum(valid_mask)}")

    # Count pixels per class
    results = {}
    valid_lc = land_cover[valid_mask]
    unique_classes, counts = np.unique(valid_lc, return_counts=True)
    for cls, count in zip(unique_classes, counts):
        results[int(cls)] = {
            "class": ESA_CLASS_NAMES.get(int(cls), "Unknown"),
            "count": int(count)
        }

    return {"result": results}

In [8]:
class GraphState(TypedDict):
    message: str
    aoi: dict
    date_range: Tuple[str, str]
    thresholds: Dict[str, float]
    result: Dict[str, Any]

builder = StateGraph(GraphState)

builder.add_node("parse_message", RunnableLambda(parse_user_message))
builder.add_node("filter_and_count", RunnableLambda(filter_and_count_pixels))

builder.set_entry_point("parse_message")
builder.add_edge("parse_message", "filter_and_count")
builder.add_edge("filter_and_count", END)

graph = builder.compile()


In [9]:
#2 is low confidence, 3 is high confidence (1 means nothing and the values should all be >=20,000 and <40,000)

In [10]:
# Test AOI over the Peruvian Amazon (Loreto and Ucayali regions)

message = """
{
  "type": "Feature",
  "geometry": {
    "type": "Polygon",
    "coordinates": [
      [
        [-75.0, -3.0],
        [-70.0, -3.0],
        [-70.0, -6.0],
        [-75.0, -6.0],
        [-75.0, -3.0]
      ]
    ]
  }
}
date range: 2019-01-01 to 2020-12-31
confidence: 3
intensity: 50
"""

result = graph.invoke({"message": message})
print(result)





{'message': '\n{\n  "type": "Feature",\n  "geometry": {\n    "type": "Polygon",\n    "coordinates": [\n      [\n        [-75.0, -3.0],\n        [-70.0, -3.0],\n        [-70.0, -6.0],\n        [-75.0, -6.0],\n        [-75.0, -3.0]\n      ]\n    ]\n  }\n}\ndate range: 2019-01-01 to 2020-12-31\nconfidence: 3\nintensity: 50\n', 'aoi': {'type': 'Feature', 'geometry': {'type': 'Polygon', 'coordinates': [[[-75.0, -3.0], [-70.0, -3.0], [-70.0, -6.0], [-75.0, -6.0], [-75.0, -3.0]]]}}, 'date_range': ('2019-01-01', '2020-12-31'), 'thresholds': {'confidence': 3.0, 'intensity': 50}, 'result': {10: {'class': 'Tree cover', 'count': 5454}, 30: {'class': 'Grassland', 'count': 6}, 80: {'class': 'Permanent water bodies', 'count': 2}}}
