In [4]:
import re
import json
import requests
from datetime import datetime
from typing import Dict, TypedDict, Any, Tuple
import numpy as np
import boto3
import rasterio
from rasterio.io import MemoryFile
from rasterio.mask import mask
from rasterio.session import AWSSession
from shapely.geometry import shape
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda
from langchain.tools import tool


def parse_user_message(state: dict) -> dict:
    """
    Parse user message from `state["message"]`, returning aoi, date_range, and thresholds.

    The message is expected to contain:
      * a GeoJSON object (the first JSON object found in the text),
      * ``date range: YYYY-MM-DD to YYYY-MM-DD``,
      * ``confidence: <number>``,
      * ``intensity: <number>``.

    Args:
        state: Graph state; only ``state["message"]`` is read.

    Returns:
        A dict with:
          * ``"aoi"``: the decoded GeoJSON object,
          * ``"date_range"``: ``(start_date, end_date)`` as ``YYYY-MM-DD`` strings,
          * ``"thresholds"``: ``{"confidence": float, "intensity": int}``
            (a fractional intensity is truncated toward zero).

    Raises:
        KeyError: if ``state`` has no ``"message"`` entry.
        ValueError: if the GeoJSON, date range, confidence or intensity is
            missing or malformed.
    """
    message = state["message"]

    def _require(pattern: str, label: str) -> "re.Match":
        # Fail with a message naming the missing field instead of an
        # AttributeError on ``None.group`` further down.
        found = re.search(pattern, message)
        if found is None:
            raise ValueError(f"Could not find {label} in message")
        return found

    # Decode exactly one JSON object starting at the first "{" so that any
    # braces appearing later in the message do not corrupt the GeoJSON.
    json_start = message.find("{")
    if json_start == -1:
        raise ValueError("Could not find a GeoJSON object in message")
    try:
        geojson, _ = json.JSONDecoder().raw_decode(message[json_start:])
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid GeoJSON in message: {exc}") from exc

    date_match = _require(
        r'date range:\s*(\d{4}-\d{2}-\d{2})\s*to\s*(\d{4}-\d{2}-\d{2})',
        "'date range: YYYY-MM-DD to YYYY-MM-DD'",
    )
    start_date, end_date = date_match.groups()

    conf_match = _require(r'confidence:\s*([0-9.]+)', "'confidence: <number>'")
    intensity_match = _require(r'intensity:\s*([0-9.]+)', "'intensity: <number>'")

    return {
        "aoi": geojson,
        "date_range": (start_date, end_date),
        "thresholds": {
            "confidence": float(conf_match.group(1)),
            "intensity": int(float(intensity_match.group(1)))
        }
    }


In [5]:
def filter_and_count_pixels(inputs: Dict) -> Dict:
    """Filter COG pixels and count results per land cover class using direct COG file.

    Reads the alert raster, the alert intensity raster and the grassland
    raster clipped to the AOI, keeps the pixels that pass the confidence,
    date and intensity thresholds, and counts those pixels per distinct
    value of the grassland raster.

    Args:
        inputs: Graph state with:
            * ``"aoi"``: GeoJSON feature; only ``aoi["geometry"]`` is used.
            * ``"date_range"``: ``(start, end)`` strings in ``YYYY-MM-DD`` format.
            * ``"thresholds"``: dict with ``"confidence"`` and ``"intensity"``
              minimum values (both inclusive).

    Returns:
        ``{"result": {land_cover_value: pixel_count}}`` with ``int`` keys and values.

    Raises:
        KeyError: if a required key is missing from ``inputs``.
        ValueError: if a date does not match ``%Y-%m-%d``.
    """
    aoi = inputs["aoi"]
    date_range = inputs["date_range"]
    thresholds = inputs["thresholds"]

    # Alert dates are encoded as days elapsed since 2015-01-01.
    reference_date = datetime(2015, 1, 1)
    start_days = (datetime.strptime(date_range[0], "%Y-%m-%d") - reference_date).days
    end_days = (datetime.strptime(date_range[1], "%Y-%m-%d") - reference_date).days

    default_path = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/default.tif"
    intensity_path = "s3://gfw-data-lake/umd_glad_dist_alerts/v20250329/raster/epsg-4326/cog/intensity.tif"

    session = boto3.Session(profile_name="zeno_internal_sso")
    geometry = [shape(aoi["geometry"]).__geo_interface__]

    with rasterio.Env(AWSSession(session), AWS_REQUEST_PAYER="requester"):
        with rasterio.open(default_path) as src1:
            default_data, _ = mask(src1, geometry, crop=True)
        with rasterio.open(intensity_path) as src2:
            intensity_data, _ = mask(src2, geometry, crop=True)

    print("requested start and end days since 2015: ", start_days, end_days)

    # Load grassland COG from direct URL
    grassland_url = "https://s3.eu-central-1.wasabisys.com/arco/gpw_cultiv.grassland_rf.savgol_p_30m_20000101_20001231_go_epsg.4326_v1.tif"

    with rasterio.open(grassland_url) as src3:
        land_cover_data, _ = mask(src3, geometry, crop=True)

    # Only the first band of each raster is used.
    encoded = default_data[0]
    intensities = intensity_data[0]
    land_cover = land_cover_data[0]

    # Crop every band to the common top-left window BEFORE deriving any
    # per-pixel fields; otherwise the derived arrays keep the uncropped shape
    # and cannot be combined with the land cover array.
    # NOTE(review): top-left cropping assumes the rasters share pixel size and
    # grid origin, so the windows differ only by trailing rows/cols — confirm.
    min_rows = min(encoded.shape[0], intensities.shape[0], land_cover.shape[0])
    min_cols = min(encoded.shape[1], intensities.shape[1], land_cover.shape[1])

    encoded = encoded[:min_rows, :min_cols]
    intensities = intensities[:min_rows, :min_cols]
    land_cover = land_cover[:min_rows, :min_cols]

    # Encoded alert value = confidence * 10000 + days since 2015-01-01.
    confidence = encoded // 10000
    days_since_2015 = encoded % 10000

    # Apply thresholds; alerts dated today or later are discarded.
    max_days = (datetime.today() - reference_date).days

    valid_mask = (
        (confidence >= thresholds["confidence"]) &
        (days_since_2015 >= start_days) &
        (days_since_2015 <= end_days) &
        (days_since_2015 < max_days) &
        (intensities >= thresholds["intensity"])
    )

    # Count disturbance pixels per land cover class
    unique_classes, counts = np.unique(land_cover[valid_mask], return_counts=True)
    results = {int(cls): int(count) for cls, count in zip(unique_classes, counts)}

    return {"result": results}

In [6]:
class GraphState(TypedDict):
    """State dictionary passed between the nodes of the graph in this notebook."""
    # Raw user text containing the GeoJSON AOI, date range and thresholds.
    message: str
    # GeoJSON object decoded from the message.
    aoi: dict
    # (start, end) dates as "YYYY-MM-DD" strings.
    date_range: Tuple[str, str]
    # Minimum "confidence" and "intensity" values a pixel must reach.
    thresholds: Dict[str, float]
    # Pixel count per land cover value; keys and values are ints, matching
    # what filter_and_count_pixels stores under "result".
    result: Dict[int, int]

# Assemble the two-step pipeline: parse the user message, then filter and
# count pixels. Nodes and edges are registered from small tables so the
# topology can be read at a glance.
builder = StateGraph(GraphState)

for _node_name, _node_fn in (
    ("parse_message", parse_user_message),
    ("filter_and_count", filter_and_count_pixels),
):
    builder.add_node(_node_name, RunnableLambda(_node_fn))

# Execution starts at the parsing step.
builder.set_entry_point("parse_message")

for _source, _target in (
    ("parse_message", "filter_and_count"),
    ("filter_and_count", END),
):
    builder.add_edge(_source, _target)

graph = builder.compile()


In [7]:
# Confidence codes: 2 = low confidence, 3 = high confidence (1 means nothing), so all encoded values should be >= 20,000 and < 40,000.

In [8]:
# Test AOI over the Peruvian Amazon (Loreto and Ucayali regions)
#
# The message bundles everything parse_user_message looks for: a GeoJSON
# feature, a "date range: <start> to <end>" line, and the "confidence:" and
# "intensity:" thresholds.

message = """
{
  "type": "Feature",
  "geometry": {
    "type": "Polygon",
    "coordinates": [
      [
        [-75.0, -3.0],
        [-70.0, -3.0],
        [-70.0, -6.0],
        [-75.0, -6.0],
        [-75.0, -3.0]
      ]
    ]
  }
}
date range: 2016-01-01 to 2020-12-31
confidence: 3
intensity: 50
"""

# Run the compiled graph end to end; the returned state holds the parsed
# inputs plus the per-class pixel counts under "result".
result = graph.invoke({"message": message})
print(result)



requested start and end days since 2015:  365 2191
{'message': '\n{\n  "type": "Feature",\n  "geometry": {\n    "type": "Polygon",\n    "coordinates": [\n      [\n        [-75.0, -3.0],\n        [-70.0, -3.0],\n        [-70.0, -6.0],\n        [-75.0, -6.0],\n        [-75.0, -3.0]\n      ]\n    ]\n  }\n}\ndate range: 2016-01-01 to 2020-12-31\nconfidence: 3\nintensity: 50\n', 'aoi': {'type': 'Feature', 'geometry': {'type': 'Polygon', 'coordinates': [[[-75.0, -3.0], [-70.0, -3.0], [-70.0, -6.0], [-75.0, -6.0], [-75.0, -3.0]]]}}, 'date_range': ('2016-01-01', '2020-12-31'), 'thresholds': {'confidence': 3.0, 'intensity': 50}, 'result': {0: 172110, 1: 54304, 2: 35150, 3: 22439, 4: 15326, 5: 11630, 6: 9324, 7: 7886, 8: 7214, 9: 6056, 10: 5394, 11: 4751, 12: 4172, 13: 3519, 14: 3244, 15: 2852, 16: 2648, 17: 2241, 18: 2039, 19: 1786, 20: 1584, 21: 1360, 22: 1120, 23: 955, 24: 855, 25: 732, 26: 592, 27: 510, 28: 441, 29: 397, 30: 398, 31: 336, 32: 331, 33: 288, 34: 293, 35: 256, 36: 237, 37: 215,